Data processing

Set up graph dataset builder

In order to train a model or run batched inference, the data must first be processed into objects of type GraphDataset. This can be achieved with the GraphDatasetBuilder class, which is instantiated from its associated pydantic config and a tuple of datasets derived from the Dataset base class:

from dipm.data import GraphDatasetBuilder

# datasets is a tuple of train, validation and test datasets
datasets = _get_datasets()  # this is a placeholder for the moment
builder_config = GraphDatasetBuilder.Config(
    graph_cutoff_angstrom=5.0,
    max_n_node=None,
    max_n_edge=None,
    batch_size=16,
)
graph_dataset_builder = GraphDatasetBuilder(datasets, builder_config)

In the example above, we set some example values for the settings in the GraphDatasetBuilderConfig. For simpler code, this config class can be accessed directly via GraphDatasetBuilder.Config. See the class's API reference for the full set of configurable values and for which of them defaults are available.

The datasets argument is a tuple of Dataset instances. This class reads a dataset into lists of ChemicalSystem objects via its __getitem__ method. You can either implement your own derived class for your custom dataset format, or use the built-in create_datasets function:

from dipm.data import create_datasets, ChemicalDatasetsConfig

datasets_config = ChemicalDatasetsConfig(
    train_dataset_paths="...",
    valid_dataset_paths="...",
    test_dataset_paths="...",
)

# If data is stored locally
datasets = create_datasets(datasets_config)
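
If none of the built-in readers fit your data, a derived class only needs to expose systems via __getitem__. The following is a minimal sketch, assuming Dataset can be imported from dipm.data; the helper functions and the exact conversion to ChemicalSystem objects are hypothetical, so consult the API reference for the real interface.

from dipm.data import Dataset

class MyCustomDataset(Dataset):
    """Hypothetical reader for a custom file format (sketch only)."""

    def __init__(self, path):
        # _read_raw_systems is a placeholder for your own parsing logic
        self._raw = _read_raw_systems(path)

    def __len__(self):
        return len(self._raw)

    def __getitem__(self, idx):
        # Convert raw entries into the ChemicalSystem objects that the
        # builder expects; the conversion depends on your format
        return _to_chemical_systems(self._raw[idx])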

If you have multiple datasets in different formats and would like to combine them, or you want to slice a dataset without loading all of it, you can do so by instead using the ConcatDataset and Subset classes.

from dipm.data import ConcatDataset, Subset

datasets = _get_list_of_individual_chemical_datasets()  # placeholder
combined_dataset = ConcatDataset(datasets)

start, length = 0, 1000  # example slice
subset_dataset = Subset(combined_dataset, start, length)

The resulting dataset can then also be used as an input to the GraphDatasetBuilder.
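
For example, the subset could take the place of one of the splits passed to the builder; valid_dataset and test_dataset below are hypothetical stand-ins for datasets obtained as shown earlier:

# valid_dataset and test_dataset are placeholders for separately
# constructed validation and test datasets
datasets = (subset_dataset, valid_dataset, test_dataset)
graph_dataset_builder = GraphDatasetBuilder(datasets, builder_config)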

Built-in dataset: data formats

We only provide one built-in core dataset: Hdf5Dataset.

To train a force model, we need a dataset of atomic systems with the following per-system features, in specific units (illustrated in the sketch after this list):

  • the positions (i.e., coordinates) of the atoms in the structure in Angstrom

  • the element numbers of the atoms

  • the forces on the atoms in eV / Angstrom

  • the energy of the structure in eV

  • (optional) the stress of the structure in eV / Angstrom³

  • (optional) the periodic boundary conditions
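
As a rough illustration of these features (this is not the on-disk format that the Hdf5Dataset expects, just example values for a single three-atom system):

import numpy as np

positions = np.array([[0.0, 0.0, 0.0],
                      [0.96, 0.0, 0.0],
                      [-0.24, 0.93, 0.0]])  # Angstrom
atomic_numbers = np.array([8, 1, 1])        # element numbers (O, H, H)
forces = np.zeros((3, 3))                   # eV / Angstrom
energy = -14.2                              # eV (example value)
stress = np.zeros((3, 3))                   # optional, eV / Angstrom³
pbc = np.array([False, False, False])       # optional periodic boundary conditions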

For a detailed description of the data format that the Hdf5Dataset requires, see here.

If your data is in a different format or uses different units, we recommend using our dataset conversion tool to convert it into the required format. You can also implement your own derived class to read your custom dataset format.

Start preprocessing

Once you have the graph_dataset_builder set up, you can start the preprocessing and fetch the resulting datasets:

graph_dataset_builder.prepare_datasets()

splits = graph_dataset_builder.get_splits()
train_set, validation_set, test_set = splits

The resulting datasets are of type GraphDataset as mentioned above. For example, to process the batches in the training set, one can execute:

num_graphs = len(train_set.graphs)
num_batches = len(train_set)

for batch in train_set:
    _process_batch_in_some_way(batch)

Get sharded batches

If one wants to generate batches that are sharded across devices and prefetched, the following arguments must be passed to the get_splits() method of the GraphDatasetBuilder:

import jax

splits = graph_dataset_builder.get_splits(
    prefetch=True, devices=jax.local_devices()
)
train_set, valid_set, test_set = splits

The datasets are then no longer of type GraphDataset, but of type PrefetchIterator, which implements batch prefetching on top of the ParallelGraphDataset class. It can be iterated over to obtain the sharded batches in the same way; note, however, that it does not have a graphs member that can be accessed directly.
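
For example, iterating over the prefetched training set looks the same as before:

for batch in train_set:
    _process_batch_in_some_way(batch)  # placeholder, as above

# Unlike a GraphDataset, accessing train_set.graphs would fail here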

Get dataset info

Furthermore, the builder class also populates a dataclass of type DatasetInfo, which contains metadata about the dataset that is relevant to the models during training and must be stored together with a model for it to remain usable. The populated instance of this dataclass can be accessed like this:

# Note: accessing this property before prepare_datasets() has been run
# will call prepare_datasets() implicitly and emit a warning
dataset_info = graph_dataset_builder.dataset_info
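
Since this metadata must travel with the model, one option is to serialize it next to your checkpoints. The pickle-based snippet below is a minimal sketch, not a dipm API; the library may provide its own persistence utilities.

import pickle

# Store the metadata alongside the trained model so that the model
# remains usable later (sketch only)
with open("dataset_info.pkl", "wb") as f:
    pickle.dump(dataset_info, f)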