Model training

Run the following command to train a model:

python scripts/train.py scripts/train.yaml

The scripts/train.yaml file contains the configuration for the training process. It is divided into three sections:

  1. model: Configuration for the model architecture and hyperparameters.

  2. dataset: Configuration for the dataset loading and processing.

  3. train: Configuration for the training loop, including the optimizer, loss function, and learning rate schedule.

Configuration details

The model section contains parameters for model initialization. The model can either be loaded from pre-trained weights or created from scratch. To load a pre-trained model, use the following parameters; a sketch follows the list:

  • pretrained: The path to the pre-trained model weights.

  • drop_force_head: Whether to drop the force head of the pre-trained model (if any). Default is False.

  • dtype: The computation data type. Default is the same as the parameter data type.
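
For example, a model section that loads pre-trained weights might look like the following sketch (the weights path is hypothetical):

model:
  pretrained: ./checkpoints/pretrained.safetensors # Hypothetical path to pre-trained weights
  drop_force_head: false                           # Keep the force head if the model has one
  dtype: bfloat16                                  # Computation dtype; defaults to the parameter dtype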

To create a new model, use the following parameters (see the sketch after this list):

  • target: The name of the model class in dipm.models.KNOWN_MODELS to use. The lookup is case-insensitive and underscore-insensitive.

  • seed: The random seed for model initialization.

  • dtype: The computation data type. Default is the same as the parameter data type.

  • Model-specific parameters: Parameters specified in the Config class of the model.

    • For target MACE: See MaceConfig for details.

    • For target NequIP: See NequipConfig for details.

    • For target ViSNet: See VisnetConfig for details.

    • For target LiTEN: See LiTENConfig for details.

    • For target Equiformer V2: See EquiformerV2Config for details.

    • For target UMA: See UMAConfig for details.
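
As a minimal sketch, a model section that creates a new model from scratch could look like this (only the generic parameters are shown; the model-specific fields come from the corresponding Config class, here MaceConfig):

model:
  target: mace   # Case- and underscore-insensitive name in dipm.models.KNOWN_MODELS
  seed: 0        # Random seed for model initialization
  dtype: float32 # Computation dtype; defaults to the parameter dtype
  # Model-specific parameters from MaceConfig would follow here.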

The dataset section contains the following parameters; a minimal snippet follows the list:

  • info_file: The path to the dataset information JSON file generated by GraphDatasetBuilder. It is typically in the checkpoints directory. If the file exists, the dataset information computation is skipped.

  • Reader-specific parameters: See ChemicalDatasetsConfig for details.

  • Dataset-builder-specific parameters: See GraphDatasetBuilderConfig for details.
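
A minimal dataset section, reusing the reader and builder parameters that also appear in the full example below:

dataset:
  train_dataset_paths: ./datasets/train      # Defined in ChemicalDatasetsConfig
  valid_dataset_paths: ./datasets/val        # Defined in ChemicalDatasetsConfig
  graph_cutoff_angstrom: 5.0                 # Defined in GraphDatasetBuilderConfig
  batch_size: 16                             # Defined in GraphDatasetBuilderConfig
  info_file: ./checkpoints/dataset_info.json # Skips info computation if the file exists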

The train section contains the following parameters; a minimal example follows the list:

  • loss: The name of the loss function class in dipm.loss.KNOWN_LOSSES to use. The lookup is case-insensitive and underscore-insensitive; for example, mse and mse_loss are both valid.

  • extended_metrics: Whether to include an extended list of metrics such as mae_e_per_atom and learning_rate. Default is False.

  • parallel: Whether to use parallel training. By default, it is determined by the number of available devices.

  • dtype: The computation data type used in mixed-precision training. Default is the same as the parameter data type (float32).

  • save_path: The path to save the final model weights. Must end with .safetensors.

  • stage_splits: If provided, the training loop performs two stages of training, where the energy weight is small in the first stage and increased to a larger value in the second stage. This parameter specifies the fraction of the dataset used in each stage.

  • energy_weights: The scheduled energy weights in the two-stage training.

  • forces_weights: The scheduled forces weights in the two-stage training.

  • tensorboard: The project name for TensorBoard. If not provided, TensorBoard logging is disabled. Logs will be saved in ./logs/project_name/datetime/.

  • wandb: The project name for Weights and Biases. If not provided, W&B logging is disabled. Logs will be saved in ./logs/project_name/.

  • Training-loop-specific parameters: See TrainingLoopConfig for details.

  • Optimizer-specific parameters: See OptimizerConfig for details.

  • IO-handler-specific parameters: See TrainingIOHandlerConfig for details.
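
For a simple run, the two-stage and logging parameters can be omitted entirely; a minimal train section might look like the following sketch (assuming the optimizer and IO handler defaults suffice):

train:
  num_epochs: 5                  # Defined in TrainingLoopConfig
  loss: mse                      # Resolves to MSELoss in dipm.loss.KNOWN_LOSSES
  save_path: ./model.safetensors # Must end with .safetensors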

Example

Here is an example of a scripts/train.yaml file:

model:
  target: EquiformerV2    # Model class name in dipm.models
  num_layers: 8           # Defined in EquiformerV2Config
  num_rbf: 512            # Defined in EquiformerV2Config
  lmax: 4                 # Defined in EquiformerV2Config
  grid_resolution: 18     # Defined in EquiformerV2Config
  avg_num_neighbors: null # Defined in EquiformerV2Config; defaults to Meta IS2RE value
  avg_num_nodes: null     # Defined in EquiformerV2Config; defaults to Meta IS2RE value
  seed: 42                # Random seed for model initialization

dataset:
  train_dataset_paths: ./datasets/train      # Defined in ChemicalDatasetsConfig
  valid_dataset_paths: ./datasets/val        # Defined in ChemicalDatasetsConfig
  test_dataset_paths: ./datasets/test        # Defined in ChemicalDatasetsConfig
  graph_cutoff_angstrom: 5.0                 # Defined in GraphDatasetBuilderConfig
  batch_size: 16                             # Defined in GraphDatasetBuilderConfig
  info_file: ./checkpoints/dataset_info.json # Skip info calculation if exists

train:
  num_epochs: 10                        # Defined in TrainingLoopConfig
  final_learning_rate: 1e-4             # Defined in OptimizerConfig
  dtype: bfloat16                       # Computation dtype in mixed precision training
  loss: MSELoss                         # Loss class name in dipm.loss
  local_model_output_dir: ./checkpoints # Defined in TrainingIOHandlerConfig
  restore_checkpoint_if_exists: true    # Defined in TrainingIOHandlerConfig
  restore_optimizer_state: true         # Defined in TrainingIOHandlerConfig
  save_path: ./equiformerv2.safetensors # Final save file name
  stage_splits: [0.7, 0.3]              # Two stage training
  energy_weights: [1.0, 25.0]           # Schedule for energy weight in stage 1 and 2
  forces_weights: [25.0, 1.0]           # Schedule for forces weight in stage 1 and 2
  tensorboard: ./runs/equiformerv2      # Project name for TensorBoard
  wandb: equiformerv2_train             # Project name for Weights and Biases

Advanced

See Advanced model training for training with custom code.