.. _training:

Model training
==============

Run the following command to train a model:

.. code-block:: bash

   python scripts/train.py scripts/train.yaml

The `scripts/train.yaml` file contains the configuration for the training process. It is divided into three sections:

1. `model`: Configuration for the model architecture and hyperparameters.
2. `dataset`: Configuration for dataset loading and processing.
3. `train`: Configuration for the training loop, including the optimizer, loss function, and learning rate schedule.

Configuration details
---------------------

The `model` section contains parameters for model initialization. The model can be either a pre-trained model or a new model.

To use a pre-trained model, use the following parameters (a minimal sketch follows the list):

* `pretrained`: The path to the pre-trained model weights.
* `drop_force_head`: Whether to drop the force head of the pre-trained model (if any). Default is `False`.
* `dtype`: The computation data type. Default is the same as the parameter data type.
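For instance, a minimal `model` section for fine-tuning from an existing checkpoint might look like the following sketch; the weights path is a placeholder, not a file shipped with the project:

.. code-block:: yaml

   model:
     pretrained: ./checkpoints/pretrained.safetensors  # Placeholder path to pre-trained weights
     drop_force_head: false                            # Keep the pre-trained force head, if present
     dtype: float32                                    # Computation dtype; defaults to the parameter dtype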
To create a new model, use the following parameters:

* `target`: The name of the model class in :py:const:`dipm.models.KNOWN_MODELS` to use. Matching is case-insensitive and ignores underscores.
* `seed`: The random seed for model initialization.
* `dtype`: The computation data type. Default is the same as the parameter data type.
* Model-specific parameters: parameters defined in the `Config` class of the model.

  * For target `MACE`: See :py:class:`MaceConfig` for details.
  * For target `NequIP`: See :py:class:`NequipConfig` for details.
  * For target `ViSNet`: See :py:class:`VisnetConfig` for details.
  * For target `LiTEN`: See :py:class:`LiTENConfig` for details.
  * For target `Equiformer V2`: See :py:class:`EquiformerV2Config` for details.
  * For target `UMA`: See :py:class:`UMAConfig` for details.

The `dataset` section contains the following parameters:

* `info_file`: The path to the dataset information JSON file generated by :py:class:`GraphDatasetBuilder`. It is typically in the `checkpoints` directory. If the file exists, the dataset information computation is skipped.
* Reader-specific parameters: See :py:class:`ChemicalDatasetsConfig` for details.
* Dataset builder-specific parameters: See :py:class:`GraphDatasetBuilderConfig` for details.

The `train` section contains the following parameters:

* `loss`: The name of the loss function class in :py:const:`dipm.loss.KNOWN_LOSSES` to use. Matching is case-insensitive and ignores underscores: `mse` and `mse_loss` are both valid.
* `extended_metrics`: Whether to include an extended list of metrics such as `mae_e_per_atom` and `learning_rate`. Default is `False`.
* `parallel`: Whether to use parallel training. By default, this is determined by the number of available devices.
* `dtype`: The computation data type used in mixed-precision training. Default is the same as the parameter data type (`float32`).
* `save_path`: The path for saving the final model weights. Must end with `.safetensors`.
* `stage_splits`: If provided, the training loop performs two stages of training, where the energy weight is small in the first stage and increased to a larger value in the second stage. This parameter specifies the portion of the datasets used in each stage.
* `energy_weights`: The scheduled energy weights for the two training stages.
* `forces_weights`: The scheduled forces weights for the two training stages.
* `tensorboard`: The project name for TensorBoard. If not provided, TensorBoard logging is disabled. Logs are saved in `./logs/project_name/datetime/`.
* `wandb`: The project name for Weights and Biases. If not provided, W&B logging is disabled. Logs are saved in `./logs/project_name/`.
* Train loop-specific parameters: See :py:class:`TrainingLoopConfig` for details.
* Optimizer-specific parameters: See :py:class:`OptimizerConfig` for details.
* IO handler-specific parameters: See :py:class:`TrainingIOHandlerConfig` for details.

Example
-------

Here is an example of a `scripts/train.yaml` file:

.. code-block:: yaml

   model:
     target: EquiformerV2                          # Model class name in dipm.models
     num_layers: 8                                 # Defined in EquiformerV2Config
     num_rbf: 512                                  # Defined in EquiformerV2Config
     lmax: 4                                       # Defined in EquiformerV2Config
     grid_resolution: 18                           # Defined in EquiformerV2Config
     avg_num_neighbors: null                       # Defined in EquiformerV2Config; defaults to Meta IS2RE value
     avg_num_nodes: null                           # Defined in EquiformerV2Config; defaults to Meta IS2RE value
     seed: 42                                      # Random seed for model initialization

   dataset:
     train_dataset_paths: ./datasets/train         # Defined in ChemicalDatasetsConfig
     valid_dataset_paths: ./datasets/val           # Defined in ChemicalDatasetsConfig
     test_dataset_paths: ./datasets/test           # Defined in ChemicalDatasetsConfig
     graph_cutoff_angstrom: 5.0                    # Defined in GraphDatasetBuilderConfig
     batch_size: 16                                # Defined in GraphDatasetBuilderConfig
     info_file: ./checkpoints/dataset_info.json    # Skip info computation if the file exists

   train:
     num_epochs: 10                                # Defined in TrainingLoopConfig
     final_learning_rate: 1e-4                     # Defined in OptimizerConfig
     dtype: bfloat16                               # Computation dtype in mixed-precision training
     loss: MSELoss                                 # Loss class name in dipm.loss
     local_model_output_dir: ./checkpoints         # Defined in TrainingIOHandlerConfig
     restore_checkpoint_if_exists: true            # Defined in TrainingIOHandlerConfig
     restore_optimizer_state: true                 # Defined in TrainingIOHandlerConfig
     save_path: ./equiformerv2.safetensors         # Final save file name
     stage_splits: [0.7, 0.3]                      # Two-stage training
     energy_weights: [1.0, 25.0]                   # Energy weight for stages 1 and 2
     forces_weights: [25.0, 1.0]                   # Forces weight for stages 1 and 2
     tensorboard: ./runs/equiformerv2              # Log directory for TensorBoard
     wandb: equiformerv2_train                     # Project name for Weights and Biases

Advanced
--------

See :ref:`training_advanced` for training with custom code.