Model training
Run the following command to train a model:
python scripts/train.py scripts/train.yaml
The scripts/train.yaml file contains the configuration for the training process. It is
divided into three sections:
- model: Configuration for the model architecture and hyperparameters.
- dataset: Configuration for the dataset loading and processing.
- train: Configuration for the training loop, including the optimizer, loss function, and learning rate schedule.
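As a structural sketch, the top-level layout of the file therefore looks like this (the comments only paraphrase the section descriptions above; concrete keys are covered in the configuration details and the full example below):

model:
  # model architecture and hyperparameters
dataset:
  # dataset loading and processing
train:
  # training loop: optimizer, loss function, learning rate schedule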
Configuration details
The model section contains parameters for model initialization. It can either load a pre-trained
model or create a new one. To use a pre-trained model, use the following parameters:
- pretrained: The path to the pre-trained model weights.
- drop_force_head: Whether to drop the force head of the pre-trained model (if any). Default is False.
- dtype: The computation data type. Default is the same as the parameters' data type.
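For example, a model section that loads pre-trained weights might look like the following sketch; the weights path and the dtype value are illustrative only:

model:
  pretrained: ./checkpoints/pretrained.safetensors  # illustrative path to pre-trained weights
  drop_force_head: false                            # keep the force head if the checkpoint has one
  dtype: bfloat16                                   # optional; defaults to the parameters' data type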
To create a new model, use the following parameters:
- target: The name of the model class in dipm.models.KNOWN_MODELS to use (see the sketch after this list). The name is matched case-insensitively and underscores are ignored.
- seed: The random seed for model initialization.
- dtype: The computation data type. Default is the same as the parameters' data type.
- Model-specific parameters: Parameters specified in the Config class of the model.
  - For target MACE: See MaceConfig for details.
  - For target NequIP: See NequipConfig for details.
  - For target ViSNet: See VisnetConfig for details.
  - For target LiTEN: See LiTENConfig for details.
  - For target Equiformer V2: See EquiformerV2Config for details.
  - For target UMA: See UMAConfig for details.
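A minimal sketch for a new model follows; the seed and dtype values are arbitrary, and the model-specific parameters from MaceConfig are omitted:

model:
  target: mace      # matched against dipm.models.KNOWN_MODELS; case and underscores are ignored
  seed: 0           # random seed for model initialization
  dtype: float32    # optional; defaults to the parameters' data type
  # MACE-specific parameters from MaceConfig would go here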
The dataset section contains the following parameters:
- info_file: The path to the dataset information JSON file generated by GraphDatasetBuilder. It is typically in the checkpoints directory. If it exists, dataset information computation is skipped.
- Reader-specific parameters: See ChemicalDatasetsConfig for details.
- Dataset-builder-specific parameters: See GraphDatasetBuilderConfig for details.
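A dataset section might therefore look like the following sketch; the paths are placeholders, and the remaining keys come from ChemicalDatasetsConfig and GraphDatasetBuilderConfig:

dataset:
  train_dataset_paths: ./datasets/train        # a ChemicalDatasetsConfig parameter
  valid_dataset_paths: ./datasets/val          # a ChemicalDatasetsConfig parameter
  batch_size: 16                               # a GraphDatasetBuilderConfig parameter
  info_file: ./checkpoints/dataset_info.json   # skip dataset information computation if the file exists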
The train section contains the following parameters:
- loss: The name of the loss function class in dipm.loss.KNOWN_LOSSES to use. The name is matched case-insensitively and underscores are ignored; mse and mse_loss are both valid.
- extended_metrics: Whether to include an extended list of metrics such as mae_e_per_atom and learning_rate. Default is False.
- parallel: Whether to use parallel training. By default, it is determined by the number of available devices.
- dtype: The computation data type used in mixed precision training. Default is the same as the parameters' data type (float32).
- save_path: The path to save the final model weights. Must end with .safetensors.
- stage_splits: If provided, the training loop performs two stages of training, where the energy weight is small in the first stage and increased to a larger value in the second stage. This parameter specifies the portion of the dataset used in each stage (see the sketch after this list).
- energy_weights: The scheduled energy weights for the two training stages.
- forces_weights: The scheduled forces weights for the two training stages.
- tensorboard: The project name for TensorBoard. If not provided, TensorBoard logging is disabled. Logs are saved in ./logs/project_name/datetime/.
- wandb: The project name for Weights and Biases. If not provided, W&B logging is disabled. Logs are saved in ./logs/project_name/.
- Train-loop-specific parameters: See TrainingLoopConfig for details.
- Optimizer-specific parameters: See OptimizerConfig for details.
- IO-handler-specific parameters: See TrainingIOHandlerConfig for details.
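As an illustrative sketch of the two-stage schedule (the split and weight values are examples, not recommendations):

train:
  loss: mse                        # matched against dipm.loss.KNOWN_LOSSES; mse_loss is equivalent
  stage_splits: [0.7, 0.3]         # 70% of the data in stage 1, 30% in stage 2
  energy_weights: [1.0, 25.0]      # small energy weight in stage 1, larger in stage 2
  forces_weights: [25.0, 1.0]      # forces weight schedule for stages 1 and 2
  save_path: ./model.safetensors   # must end with .safetensors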
Example
Here is an example of a scripts/train.yaml file:
model:
  target: EquiformerV2                          # Model class name in dipm.models
  num_layers: 8                                 # Defined in EquiformerV2Config
  num_rbf: 512                                  # Defined in EquiformerV2Config
  lmax: 4                                       # Defined in EquiformerV2Config
  grid_resolution: 18                           # Defined in EquiformerV2Config
  avg_num_neighbors: null                       # Defined in EquiformerV2Config, default to Meta IS2RE value
  avg_num_nodes: null                           # Defined in EquiformerV2Config, default to Meta IS2RE value
  seed: 42                                      # Random seed for model initialization
dataset:
  train_dataset_paths: ./datasets/train         # Defined in ChemicalDatasetsConfig
  valid_dataset_paths: ./datasets/val           # Defined in ChemicalDatasetsConfig
  test_dataset_paths: ./datasets/test           # Defined in ChemicalDatasetsConfig
  graph_cutoff_angstrom: 5.0                    # Defined in GraphDatasetBuilderConfig
  batch_size: 16                                # Defined in GraphDatasetBuilderConfig
  info_file: ./checkpoints/dataset_info.json    # Skip info calculation if exists
train:
  num_epochs: 10                                # Defined in TrainingLoopConfig
  final_learning_rate: 1e-4                     # Defined in OptimizerConfig
  dtype: bfloat16                               # Computation dtype in mixed precision training
  loss: MSELoss                                 # Loss class name in dipm.loss
  local_model_output_dir: ./checkpoints         # Defined in TrainingIOHandlerConfig
  restore_checkpoint_if_exists: true            # Defined in TrainingIOHandlerConfig
  restore_optimizer_state: true                 # Defined in TrainingIOHandlerConfig
  save_path: ./equiformerv2.safetensors         # Final save file name
  stage_splits: [0.7, 0.3]                      # Two stage training
  energy_weights: [1.0, 25.0]                   # Schedule for energy weight in stage 1 and 2
  forces_weights: [25.0, 1.0]                   # Schedule for forces weight in stage 1 and 2
  tensorboard: ./runs/equiformerv2              # Log directory for TensorBoard
  wandb: equiformerv2_train                     # Project name for Weights and Biases
Advanced
See Advanced model training for training with custom code.