Base class

class dipm.models.force_model.ForceModel(*args: Any, **kwargs: Any)

Base class for GNN node-wise energy models.

Energy models deriving from this class compute node-wise contributions to the total energy. Their inputs are the edge vectors of a graph, the atomic species of its nodes, and the edges themselves, given as sender and receiver indices.
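The sketch below illustrates this input/output contract. It uses NumPy instead of JAX and makes three assumptions. First, edge_vectors has shape (n_edges, 3). Second, every bond appears twice, once per direction. Third, each node collects energy from the edges it receives. The name toy_node_energies, the harmonic pair term, and species_offsets are hypothetical and are not part of dipm.

```python
import numpy as np


def toy_node_energies(edge_vectors, node_species, senders, receivers, species_offsets):
    """Hypothetical node-wise energy: a per-species offset plus half of each incoming bond energy."""
    # Edge lengths |r_ij|, shape (n_edges,).
    lengths = np.linalg.norm(edge_vectors, axis=-1)
    # Harmonic pair term with a hypothetical rest length of 1.0.
    pair_energy = (lengths - 1.0) ** 2
    # One summand per node, starting from its species offset (fancy indexing copies).
    node_energy = species_offsets[node_species].astype(float)
    # Scatter-add half of each edge energy onto its receiver. Because every bond
    # is listed in both directions, each bond is counted exactly once in the total.
    # `senders` is not needed for this purely radial term.
    np.add.at(node_energy, receivers, 0.5 * pair_energy)
    return node_energy
```

Summing the returned array gives the total energy. For two atoms 2.0 apart with offsets -1.0 and -2.0, that total is -3.0 + 1.0 = -2.0.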

Our MLIP models are validated with Pydantic and hold a reference to their .Config class, which describes the set of hyperparameters.
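This minimal sketch shows a model class that keeps a reference to a Pydantic config class. ToyModel, ToyModelConfig and their fields (num_layers, cutoff) are hypothetical. The sketch does not show how the real ForceModel consumes its Config.

```python
from pydantic import BaseModel, Field, ValidationError


class ToyModelConfig(BaseModel):
    """Hypothetical hyperparameter set; field names are illustrative only."""

    num_layers: int = Field(default=2, gt=0)
    cutoff: float = Field(default=5.0, gt=0.0)  # unit assumed, e.g. Angstrom


class ToyModel:
    # The model class holds a reference to its Config class.
    Config = ToyModelConfig

    def __init__(self, config: ToyModelConfig):
        self.config = config
```

Invalid hyperparameters are rejected when the config is built, before any model is constructed. For example, ToyModel.Config(num_layers=0) raises a ValidationError.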

All subclasses of ForceModel must call super().__init__() and pass an additional nnx.Rngs argument during initialization.

Define jax.Array constants in __call__ rather than in __init__. Otherwise jax.jit treats them as runtime buffers instead of compile-time literals. When JIT is disabled, such a constant would be recreated on every call; to avoid this, cache it with functools.cache.

To support direct force prediction, specify force_head_prefix (str) among the class’s constants. The prefix should be the state-dict key prefix shared by the forces-head parameters.
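This sketch shows what a key prefix selects in a flat state dict. It assumes dot-separated string keys. The class constant value "force_head.", the parameter keys and the helper split_force_head are hypothetical. How dipm itself uses the prefix is not described here.

```python
class ToyDirectForceModel:
    # Hypothetical class constant: every forces-head parameter key starts with this.
    force_head_prefix = "force_head."


def split_force_head(state_dict: dict, prefix: str) -> tuple[dict, dict]:
    """Separate forces-head parameters from the rest of a flat state dict."""
    head = {k: v for k, v in state_dict.items() if k.startswith(prefix)}
    rest = {k: v for k, v in state_dict.items() if not k.startswith(prefix)}
    return head, rest
```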

To support dropping unseen elements while loading the model, specify the embedding_layer_regexp (re.Pattern) attribute among the class’s constants. The pattern is used to match the embedding-layer parameters of shape (num_species, …), whose rows corresponding to the dropped elements are then removed.
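The row-dropping step can be sketched on a flat NumPy state dict. The sketch assumes that matched parameters have species along axis 0 and that the indices of the elements to keep are already known. ToyEmbeddingModel, the key names and drop_unseen_species are hypothetical.

```python
import re

import numpy as np


class ToyEmbeddingModel:
    # Hypothetical class constant matching the species-embedding parameter key.
    embedding_layer_regexp = re.compile(r"^embedding\.weight$")


def drop_unseen_species(state_dict: dict, pattern: re.Pattern, keep_indices) -> dict:
    """Keep only the embedding rows of the species in keep_indices (order preserved)."""
    out = {}
    for key, value in state_dict.items():
        if pattern.search(key):
            # Matched parameters have shape (num_species, ...): select rows.
            out[key] = value[np.asarray(keep_indices)]
        else:
            out[key] = value
    return out
```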

The number of elements (atomic species descriptors) allowed will always be inferred from the atomic energies map in the dataset info.
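This sketch shows how the species count could follow from such a map. It assumes the map is keyed by atomic number and that species indices follow ascending atomic number; that ordering is an assumption, not documented behaviour. The energy values are placeholders, not real reference energies, and species_table is hypothetical.

```python
# Hypothetical atomic energies map: atomic number -> reference energy (placeholder values).
atomic_energies_map = {8: -3.0, 1: -1.0, 6: -2.0}


def species_table(energies: dict) -> tuple[int, dict]:
    """Infer the number of species and an atomic-number -> species-index lookup."""
    atomic_numbers = sorted(energies)
    num_species = len(atomic_numbers)
    z_to_index = {z: i for i, z in enumerate(atomic_numbers)}
    return num_species, z_to_index
```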

__call__(edge_vectors: Array, node_species: Array, senders: Array, receivers: Array, n_node: Array) → Array

Compute node-wise energy summands. Every subclass of ForceModel must override this method.
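Because __call__ returns one summand per node, per-graph totals are presumably obtained by summing within each graph of the batch. This documentation does not describe how dipm does it. The sketch below assumes a jraph-style layout: nodes are stored graph by graph, and n_node[g] gives the node count of graph g. NumPy stands in for JAX, and graph_energies is hypothetical.

```python
import numpy as np


def graph_energies(node_energies: np.ndarray, n_node: np.ndarray) -> np.ndarray:
    """Sum node-wise energy summands into one total energy per graph."""
    n_graphs = n_node.shape[0]
    # Graph index of every node, shape (n_nodes,).
    graph_index = np.repeat(np.arange(n_graphs), n_node)
    totals = np.zeros(n_graphs, dtype=node_energies.dtype)
    np.add.at(totals, graph_index, node_energies)
    return totals
```

For a batch of two graphs with n_node = [2, 3] and node energies [1, 2, 3, 4, 5], this yields totals [3, 12].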