Training step¶
- class dipm.training.training_step.TrainingState(predictor: ForceFieldPredictor, optimizer: Optimizer, ema_tracker: EMATracker, num_steps: TrainingStateVar, acc_steps: TrainingStateVar)¶
Represents the state of training.
- predictor¶
The force field predictor model.
- optimizer¶
NNX optimizer.
- Type:
flax.nnx.training.optimizer.Optimizer
- ema_tracker¶
Exponential moving average (EMA) tracker.
- num_steps¶
The number of training steps taken.
- Type:
dipm.training.training_step.TrainingStateVar
- acc_steps¶
The number of gradient accumulation steps taken; resets to 0 after each optimizer step.
- Type:
dipm.training.training_step.TrainingStateVar
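A minimal sketch of assembling a TrainingState, assuming Optax and Flax NNX as implied by the signature above. The predictor and EMA tracker constructions are elided because their interfaces are documented elsewhere, and the TrainingStateVar initial values shown are placeholders:

```python
import optax
from flax import nnx

from dipm.training.training_step import TrainingState

predictor = ...    # a ForceFieldPredictor instance (construction elided)
ema_tracker = ...  # an EMATracker instance (construction elided)

# Wrap the predictor in an NNX optimizer with an Optax transformation.
optimizer = nnx.Optimizer(predictor, optax.adam(learning_rate=1e-3))

state = TrainingState(
    predictor=predictor,
    optimizer=optimizer,
    ema_tracker=ema_tracker,
    num_steps=...,  # TrainingStateVar counting optimizer steps taken
    acc_steps=...,  # TrainingStateVar counting accumulation steps
)
```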
- dipm.training.training_step.make_train_step(loss_fun: Callable[[Prediction, GraphsTuple, int, bool], tuple[Array, dict[str, Array]]], avg_n_graphs_per_batch: float, num_gradient_accumulation_steps: int = 1, should_parallelize: bool = True) Callable[[TrainingState, GraphsTuple, Rngs, int], dict]¶
Create a training step function that optimizes the model parameters using gradients.
- Parameters:
loss_fun – A function that computes the loss from a prediction, a labelled reference graph, the epoch number, and a boolean flag, and returns the scalar loss together with a dictionary of auxiliary metrics.
avg_n_graphs_per_batch – The average number of graphs per batch, used to reweight metrics.
num_gradient_accumulation_steps – The number of gradient accumulation steps to take before a parameter update is performed. Defaults to 1, meaning parameters are updated after every batch; values greater than 1 accumulate gradients over that many batches before applying a single update.
should_parallelize – Whether to parallelize the training step across devices via pmap.
- Returns:
A function that takes the current training state, a batch of graphs, random number generators, and the epoch number as input, updates the training state, and returns a dictionary of training metrics.
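A hedged usage sketch follows. The loss function body, the batches iterable, and the pre-built state are illustrative placeholders rather than part of this API; only the call signatures follow the documentation above:

```python
from flax import nnx

from dipm.training.training_step import make_train_step

def loss_fun(prediction, ref_graph, epoch, training):
    # Must return (scalar_loss, metrics_dict); the body here is schematic.
    loss = ...  # e.g. a weighted sum of energy and force errors
    return loss, {"loss": loss}

train_step = make_train_step(
    loss_fun=loss_fun,
    avg_n_graphs_per_batch=32.0,        # average graphs per batch in the dataset
    num_gradient_accumulation_steps=4,  # apply a parameter update every 4 batches
    should_parallelize=False,           # skip pmap, e.g. on a single device
)

rngs = nnx.Rngs(0)
for epoch in range(10):
    for batch in batches:  # an iterable of jraph.GraphsTuple batches
        metrics = train_step(state, batch, rngs, epoch)
```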