Model evaluation
- dipm.training.evaluation.make_evaluation_step(eval_loss_fun: Callable[[Prediction, GraphsTuple, int, bool], tuple[Array, dict[str, Array]]], avg_n_graphs_per_batch: float, should_parallelize: bool = True) → Callable[[ForceFieldPredictor, GraphsTuple, int], dict[str, ndarray]]
Creates the evaluation step function.
- Parameters:
eval_loss_fun – The loss function for the evaluation.
avg_n_graphs_per_batch – Average number of graphs per batch, used to reweight the metrics.
should_parallelize – Whether to apply data parallelization across multiple devices.
- Returns:
The evaluation step function.
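The docstring does not say exactly how avg_n_graphs_per_batch is applied. One plausible reading, which is an assumption and not taken from dipm's source, is that each batch's mean metrics are scaled by n_graphs_in_batch / avg_n_graphs_per_batch. With that scaling, a plain average over batches equals the per-graph average even when the final batch is smaller. The helper reweight_batch_metrics and the batch numbers below are hypothetical.

```python
def reweight_batch_metrics(
    metrics: dict[str, float],
    n_graphs_in_batch: int,
    avg_n_graphs_per_batch: float,
) -> dict[str, float]:
    """Scale per-batch mean metrics by this batch's share of graphs.

    Hypothetical helper: an assumed interpretation of "reweighting",
    not dipm's implementation.
    """
    scale = n_graphs_in_batch / avg_n_graphs_per_batch
    return {name: value * scale for name, value in metrics.items()}


# Three hypothetical batches; the last one is a smaller, ragged batch.
batch_sizes = [4, 4, 2]
batch_mae = [1.0, 2.0, 4.0]  # per-batch mean of some error metric
avg_n_graphs = sum(batch_sizes) / len(batch_sizes)  # 10 / 3

reweighted = [
    reweight_batch_metrics({"mae": mae}, n, avg_n_graphs)["mae"]
    for mae, n in zip(batch_mae, batch_sizes)
]
mean_over_batches = sum(reweighted) / len(reweighted)

# Reference values: the true per-graph average and the unweighted batch average.
per_graph_mean = sum(m * n for m, n in zip(batch_mae, batch_sizes)) / sum(batch_sizes)
naive_mean = sum(batch_mae) / len(batch_mae)
```

The equality is exact when the average is the true one: with B batches of n_i graphs and per-batch means m_i, averaging m_i * n_i / avg over the B batches, where avg = sum(n_i) / B, gives sum(n_i * m_i) / sum(n_i). Here mean_over_batches and per_graph_mean are both 2.0 (up to floating-point rounding), while the unweighted batch average is 7/3 and over-counts the small final batch.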
- dipm.training.evaluation.run_evaluation(predictor: ForceFieldPredictor, evaluation_step: Callable[[ForceFieldPredictor, GraphsTuple, int], dict[str, ndarray]], eval_dataset: GraphDataset | PrefetchIterator, epoch_number: int, io_handler: TrainingIOHandler, devices: list[Device] | None = None, is_test_set: bool = False, extended_to_log: dict[str, Any] | None = None) → float
Runs a model evaluation on a given dataset.
- Parameters:
predictor – The predictor to use.
evaluation_step – The evaluation step function.
eval_dataset – The dataset on which to evaluate the model.
epoch_number – The current epoch number.
io_handler – The IO handler that logs the results.
devices – The JAX devices to evaluate on. Can be None (the default) when evaluation is not run in parallel.
is_test_set – Whether the evaluation is performed on the test set, i.e., outside a training run. Defaults to False.
extended_to_log – Additional metrics to log.
- Returns:
The mean loss over the evaluation dataset.
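The control flow that run_evaluation implies can be sketched as follows: call the evaluation step once per batch, average the metrics, log them together with any extras, and return the mean loss. This is a minimal sketch under stated assumptions, not dipm's implementation. run_evaluation_sketch and toy_step are hypothetical. A plain log callable stands in for TrainingIOHandler. devices and is_test_set are omitted. The step's integer argument is assumed to be the epoch number. The step is assumed to report a "loss" entry and to have already reweighted its metrics, so a plain mean over batches is used.

```python
from __future__ import annotations

from collections import defaultdict
from typing import Any, Callable, Iterable

import numpy as np


def run_evaluation_sketch(
    predictor: Any,
    evaluation_step: Callable[[Any, Any, int], dict[str, np.ndarray]],
    eval_batches: Iterable[Any],
    epoch_number: int,
    log: Callable[[dict[str, Any]], None],
    extended_to_log: dict[str, Any] | None = None,
) -> float:
    """Hypothetical stand-in for run_evaluation; not dipm's implementation."""
    totals: dict[str, float] = defaultdict(float)
    n_batches = 0
    for batch in eval_batches:
        metrics = evaluation_step(predictor, batch, epoch_number)
        for name, value in metrics.items():
            # np.mean collapses a per-device array (parallel case) to a scalar;
            # for a scalar value it leaves the number unchanged.
            totals[name] += float(np.mean(value))
        n_batches += 1
    if n_batches == 0:
        raise ValueError("the evaluation dataset yielded no batches")

    # Plain mean over batches; assumes the step already reweighted its metrics.
    means: dict[str, Any] = {name: t / n_batches for name, t in totals.items()}
    mean_loss = means["loss"]  # assumes the step reports a "loss" entry
    if extended_to_log:
        means.update(extended_to_log)
    log(means)
    return mean_loss


# Usage with a toy step: each "batch" is an array and its mean is the loss.
def toy_step(predictor: Any, batch: np.ndarray, epoch: int) -> dict[str, np.ndarray]:
    return {"loss": np.asarray(batch.mean()), "mae": np.asarray(np.abs(batch).mean())}


logged: list[dict[str, Any]] = []
mean_loss = run_evaluation_sketch(
    predictor=None,
    evaluation_step=toy_step,
    eval_batches=[np.array([1.0, 3.0]), np.array([5.0])],
    epoch_number=0,
    log=logged.append,
    extended_to_log={"epoch": 0},
)
```

The batch losses here are 2.0 and 5.0, so the returned mean loss is 3.5. The mean loss is read before extended_to_log is merged in, so an extra entry that happens to be named "loss" cannot change the return value.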