Training Loop

class dipm.training.configs.TrainingLoopConfig(*, num_epochs: Annotated[int, Gt(gt=0)], num_gradient_accumulation_steps: Annotated[int, Gt(gt=0)] = 1, ema_decay: Annotated[float, Ge(ge=0.0), Le(le=1.0)] = 0.99, use_ema_params_for_eval: bool = True, run_eval_at_start: bool = True, log_interval: Annotated[int, Gt(gt=0)] | None = None)

Pydantic config holding all settings related to the TrainingLoop class.

Note: use valid_num_to_load instead of eval_num_graphs.

num_epochs

Number of epochs to run.

Type:

int

num_gradient_accumulation_steps

Number of gradient steps to accumulate before taking an optimizer step. Default is 1.

Type:

int
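
For example, with a per-step batch of 32 graphs and num_gradient_accumulation_steps=4, gradients are accumulated over 4 steps before each optimizer update, giving an effective batch size of 128 graphs.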

ema_decay

The decay rate for the exponential moving average (EMA) of the model parameters. Default is 0.99.

Type:

float
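
For reference, a standard exponential-moving-average update with this decay rate looks as follows; this is a sketch of the usual definition, and the exact update rule used internally is an assumption:

```python
# Standard EMA update (assumed form); a higher decay gives a smoother,
# slower-moving parameter average.
ema_params = ema_decay * ema_params + (1.0 - ema_decay) * params
```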

use_ema_params_for_eval

Whether to use the EMA parameters for evaluation. Default is True.

Type:

bool

run_eval_at_start

Whether to run an evaluation on the validation set before the first epoch starts. Default is True.

Type:

bool

log_interval

Interval, in steps, at which metrics are logged. Default is None, which means metrics are logged once per epoch.

Type:

int | None

__init__(**data: Any) → None

Create a new model by parsing and validating input data from keyword arguments.

Raises pydantic_core.ValidationError if the input data cannot be validated to form a valid model.

self is explicitly positional-only to allow self as a field name.
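
As a minimal sketch, a config can be constructed directly from keyword arguments; the field values below are arbitrary:

```python
from dipm.training.configs import TrainingLoopConfig

config = TrainingLoopConfig(
    num_epochs=100,
    num_gradient_accumulation_steps=2,
    ema_decay=0.995,
    log_interval=50,  # log metrics every 50 steps instead of once per epoch
)
```

Validation happens at construction time: for example, passing num_epochs=0 raises a ValidationError because the field is constrained to be strictly positive.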

class dipm.training.training_loop.TrainingLoop(train_dataset: GraphDataset | PrefetchIterator, validation_dataset: GraphDataset | PrefetchIterator, force_field: ForceFieldPredictor, loss: Loss, optimizer: GradientTransformation, config: TrainingLoopConfig, io_handler: TrainingIOHandler | None = None, should_parallelize: bool = False)

Training loop class.

It implements only the loop itself and does not construct any auxiliary objects. For example, the model, datasets, and optimizer must be constructed externally and passed in.

training_state

The training state.

best_model

The current state of the force field model with the best parameters so far.

Type:

ForceFieldPredictor

__init__(train_dataset: GraphDataset | PrefetchIterator, validation_dataset: GraphDataset | PrefetchIterator, force_field: ForceFieldPredictor, loss: Loss, optimizer: GradientTransformation, config: TrainingLoopConfig, io_handler: TrainingIOHandler | None = None, should_parallelize: bool = False) → None

Constructor.

Parameters:
  • train_dataset – The training dataset as either a GraphDataset or a PrefetchIterator.

  • validation_dataset – The validation dataset as either a GraphDataset or a PrefetchIterator.

  • force_field – The force field model holding at least the initial parameters and a dataset info object.

  • loss – The loss, derived from the Loss base class.

  • optimizer – The optimizer (based on optax).

  • config – The training loop pydantic config.

  • io_handler – The IO handler responsible for checkpointing and (specialized) logging. Optional; the default is None, in which case a default IO handler is set up that performs only basic metrics logging and no checkpointing.

  • should_parallelize – Whether to parallelize (using data parallelization) across multiple devices. The default is False.
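
Putting the constructor arguments together, a setup might look like the sketch below. The dataset, model, and loss builders (make_datasets, build_force_field, build_loss) are hypothetical placeholders for your own setup code; only TrainingLoop, TrainingLoopConfig, and the optax optimizer come from documented APIs:

```python
import optax

from dipm.training.configs import TrainingLoopConfig
from dipm.training.training_loop import TrainingLoop

# Hypothetical helpers -- replace with your actual data/model/loss construction.
train_data, valid_data = make_datasets()  # -> GraphDataset, GraphDataset (placeholders)
force_field = build_force_field()         # -> ForceFieldPredictor (placeholder)
loss = build_loss()                       # -> Loss (placeholder)

loop = TrainingLoop(
    train_dataset=train_data,
    validation_dataset=valid_data,
    force_field=force_field,
    loss=loss,
    optimizer=optax.adam(learning_rate=1e-3),  # any optax GradientTransformation
    config=TrainingLoopConfig(num_epochs=100),
)
```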

run(rngs: Rngs | None = None) → None

Runs the training loop.

The final training state can be accessed afterwards via the training_state attribute.

Parameters:

rngs – The random number generators for training. Only used if the model contains dropout or other stochastic layers. Defaults to nnx.Rngs(42).
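
For example, assuming the loop object constructed in the sketch above (nnx here refers to flax.nnx, which provides the Rngs container named in the signature):

```python
from flax import nnx

loop.run(rngs=nnx.Rngs(0))  # explicit seed; a deterministic model can simply call loop.run()
print(loop.training_state)  # final state is available on the instance afterwards
```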

test(test_dataset: GraphDataset | PrefetchIterator) → None

Run the evaluation on the test dataset with the best parameters seen so far.

Parameters:

test_dataset – The test dataset as either a GraphDataset or a PrefetchIterator.
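
Continuing the sketch, a held-out split can be evaluated after training; test_data below is a hypothetical GraphDataset placeholder:

```python
loop.test(test_data)  # evaluates with the best parameters seen during training
```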