Training step

class dipm.training.training_step.TrainingState(predictor: ForceFieldPredictor, optimizer: Optimizer, ema_tracker: EMATracker, num_steps: TrainingStateVar, acc_steps: TrainingStateVar)

Represents the state of training.

model

The force field predictor model being trained.

Type:

ForceFieldPredictor

optimizer

The Flax NNX optimizer that applies parameter updates.

Type:

flax.nnx.training.optimizer.Optimizer

ema_tracker

Exponential moving average (EMA) tracker for the model parameters.

Type:

dipm.training.optimizer.EMATracker

num_steps

The number of training steps taken.

Type:

dipm.training.training_step.TrainingStateVar

acc_steps

The number of gradient accumulation steps taken; resets to 0 after each optimizer step.

Type:

dipm.training.training_step.TrainingStateVar
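
The constructor signature above is enough to sketch how such a state might be assembled. In the sketch below the optax transform, the EMATracker arguments, and the TrainingStateVar initial values are illustrative assumptions, not dipm's actual defaults:

    import optax
    from flax import nnx

    from dipm.training.optimizer import EMATracker
    from dipm.training.training_step import TrainingState, TrainingStateVar

    def init_training_state(predictor) -> TrainingState:
        # Wrap the predictor in an NNX optimizer; the optax transform is illustrative.
        optimizer = nnx.Optimizer(predictor, optax.adamw(learning_rate=1e-3))
        # Hypothetical constructor argument; check dipm.training.optimizer.EMATracker.
        ema_tracker = EMATracker(decay=0.999)
        return TrainingState(
            predictor=predictor,
            optimizer=optimizer,
            ema_tracker=ema_tracker,
            num_steps=TrainingStateVar(0),  # assumed: counters start at zero
            acc_steps=TrainingStateVar(0),
        )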

dipm.training.training_step.make_train_step(loss_fun: Callable[[Prediction, GraphsTuple, int, bool], tuple[Array, dict[str, Array]]], avg_n_graphs_per_batch: float, num_gradient_accumulation_steps: int = 1, should_parallelize: bool = True) → Callable[[TrainingState, GraphsTuple, Rngs, int], dict]

Create a training step function that optimizes the model parameters using gradients.

Parameters:
  • loss_fun – A function that computes the loss from a prediction, a reference labelled graph, the epoch number, and a boolean flag, returning the scalar loss together with a dictionary of metric arrays.

  • avg_n_graphs_per_batch – Average number of graphs per batch, used to reweight the reported metrics.

  • num_gradient_accumulation_steps – The number of gradient accumulation steps before a parameter update is performed. Defaults to 1, i.e. parameters are updated after every batch.

  • should_parallelize – Whether to parallelize the step across devices with jax.pmap.

Returns:

A function that takes the current training state, a batch of graphs, RNGs, and the epoch number, updates the training state, and returns a dictionary of training metrics.
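
Putting the pieces together, usage might look like the following sketch. The loss body, the field names on Prediction and GraphsTuple, the meaning of loss_fun's boolean argument, and the batch iterable are assumptions; only the signatures documented above are taken as given:

    import jax.numpy as jnp
    from flax import nnx

    from dipm.training.training_step import make_train_step

    def loss_fun(prediction, ref_graph, epoch, flag):
        # Hypothetical force-matching loss; dipm ships its own loss functions,
        # and these attribute names are placeholders.
        loss = jnp.mean((prediction.forces - ref_graph.nodes.forces) ** 2)
        return loss, {"force_mse": loss}

    train_step = make_train_step(
        loss_fun,
        avg_n_graphs_per_batch=16.0,        # a dataset statistic, assumed here
        num_gradient_accumulation_steps=4,  # optimizer update every 4th batch
        should_parallelize=False,           # single-device example
    )

    state = init_training_state(predictor)  # see the TrainingState sketch above
    rngs = nnx.Rngs(0)  # assuming Rngs refers to flax.nnx.Rngs
    for epoch in range(10):
        for batch in train_batches:  # an iterable of labelled GraphsTuple batches
            metrics = train_step(state, batch, rngs, epoch)

With num_gradient_accumulation_steps=4, gradients from four consecutive batches are accumulated before the optimizer applies an update, after which acc_steps resets to 0.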