Training Loop¶
- class dipm.training.configs.TrainingLoopConfig(*, num_epochs: Annotated[int, Gt(gt=0)], num_gradient_accumulation_steps: Annotated[int, Gt(gt=0)] = 1, ema_decay: Annotated[float, Ge(ge=0.0), Le(le=1.0)] = 0.99, use_ema_params_for_eval: bool = True, run_eval_at_start: bool = True, log_interval: Annotated[int, Gt(gt=0)] | None = None)¶
Pydantic config holding all settings related to the TrainingLoop class.
Note: use valid_num_to_load instead of eval_num_graphs.
- num_epochs¶
Number of epochs to run.
- Type:
int
- num_gradient_accumulation_steps¶
Number of gradient steps to accumulate before taking an optimizer step. Default is 1.
- Type:
int
- ema_decay¶
The EMA decay rate, by default set to 0.99.
- Type:
float
- use_ema_params_for_eval¶
Whether to use the EMA parameters for evaluation, set to True by default.
- Type:
bool
- run_eval_at_start¶
Whether to run an evaluation on the validation set before we start the first epoch. By default, it is set to True.
- Type:
bool
- log_interval¶
Number of steps between metric logging. Default is None, which means logging once per epoch.
- Type:
int | None
- __init__(**data: Any) → None¶
Create a new model by parsing and validating input data from keyword arguments.
Raises ValidationError (pydantic_core.ValidationError) if the input data cannot be validated to form a valid model.
self is explicitly positional-only to allow self as a field name.
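For illustration, a minimal construction sketch; the field names come from the signature above, while the values themselves are arbitrary:

    from dipm.training.configs import TrainingLoopConfig

    config = TrainingLoopConfig(
        num_epochs=100,
        num_gradient_accumulation_steps=2,  # optimizer step every 2 gradient steps
        ema_decay=0.995,
        log_interval=50,  # log every 50 steps instead of once per epoch
    )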
- class dipm.training.training_loop.TrainingLoop(train_dataset: GraphDataset | PrefetchIterator, validation_dataset: GraphDataset | PrefetchIterator, force_field: ForceFieldPredictor, loss: Loss, optimizer: GradientTransformation, config: TrainingLoopConfig, io_handler: TrainingIOHandler | None = None, should_parallelize: bool = False)¶
Training loop class.
It implements only the loop itself and does not construct any auxiliary objects within it. For example, the model, dataset, and optimizer must be passed in from the outside.
- training_state¶
The training state.
- best_model¶
The current state of the force field model with the best parameters so far.
- __init__(train_dataset: GraphDataset | PrefetchIterator, validation_dataset: GraphDataset | PrefetchIterator, force_field: ForceFieldPredictor, loss: Loss, optimizer: GradientTransformation, config: TrainingLoopConfig, io_handler: TrainingIOHandler | None = None, should_parallelize: bool = False) → None¶
Constructor.
- Parameters:
train_dataset – The training dataset as either a GraphDataset or a PrefetchIterator.
validation_dataset – The validation dataset as either a GraphDataset or a PrefetchIterator.
force_field – The force field model holding at least the initial parameters and a dataset info object.
loss – The loss, derived from the Loss base class.
optimizer – The optimizer (based on optax).
config – The training loop pydantic config.
io_handler – The IO handler, which handles checkpointing and (specialized) logging. This is an optional argument. The default is None, which means that a default IO handler is set up that does no checkpointing but provides some very basic metrics logging.
should_parallelize – Whether to parallelize (using data parallelism) across multiple devices. The default is False.
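A hedged construction sketch: train_ds, valid_ds, force_field, and loss are placeholders assumed to be built elsewhere (their construction is not part of this reference), and optax.adam stands in for any optax GradientTransformation:

    import optax
    from dipm.training.training_loop import TrainingLoop

    # train_ds / valid_ds: GraphDataset or PrefetchIterator, built elsewhere.
    # force_field: ForceFieldPredictor; loss: a Loss subclass.
    loop = TrainingLoop(
        train_dataset=train_ds,
        validation_dataset=valid_ds,
        force_field=force_field,
        loss=loss,
        optimizer=optax.adam(1e-3),
        config=config,
    )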
- run(rngs: Rngs | None = None) → None¶
Runs the training loop.
The final training state can be accessed afterwards via the training_state attribute.
- Parameters:
rngs – The random number generators for training. Only used if the model contains dropout or other stochastic layers. Defaults to nnx.Rngs(42).
- test(test_dataset: GraphDataset | PrefetchIterator) → None¶
Run the evaluation on the test dataset with the best parameters seen so far.
- Parameters:
test_dataset – The test dataset as either a GraphDataset or a PrefetchIterator.
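Putting the two methods together, a usage sketch (test_ds is a placeholder dataset; nnx here is flax.nnx, matching the nnx.Rngs default noted above):

    from flax import nnx

    loop.run(rngs=nnx.Rngs(42))  # train; final state ends up in loop.training_state
    loop.test(test_ds)           # evaluate the best parameters seen so far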