Base class
- class dipm.models.force_model.ForceModel(*args: Any, **kwargs: Any)
Base class for GNN node-wise energy models.
Energy models deriving from this class return node-wise contributions to the total energy, computed from the edge vectors of a graph, the atomic species of the nodes, and the edges themselves, passed as senders and receivers indices.
Our MLIP models are validated with Pydantic and hold a reference to their .Config class, which describes the set of hyperparameters.
All subclasses of ForceModel must call super().__init__() and pass in an additional nnx.Rngs parameter during initialization.
Any jax.Array constant should be defined in __call__ rather than in __init__; otherwise jax.jit will treat it as a runtime buffer rather than a compile-time literal. If this leads to the constant being created repeatedly when JIT is disabled, you can use functools.cache to cache it.
To support direct forces prediction, specify force_head_prefix (str) in the class's constants. The prefix should be the state dict key prefix for the forces head parameters.
To support dropping unseen elements while loading the model, specify the embedding_layer_regexp (re.Pattern) attribute in the class's constants. The pattern is used to match the embedding layer parameters of shape (num_species, …) that should be modified, and the rows corresponding to unseen elements are dropped.
The number of elements (atomic species descriptors) allowed is always inferred from the atomic energies map in the dataset info.
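The row-dropping behaviour described for embedding_layer_regexp can be illustrated with a minimal, standard-library-only sketch. This is not dipm's loading code: the pattern, the state dict keys, the helper drop_unseen_rows, and the use of nested lists in place of arrays are all hypothetical and chosen only to show the idea.

```python
import re

# Hypothetical pattern; a real model would define its own as a class constant.
EMBEDDING_LAYER_REGEXP = re.compile(r".*species_embedding\.(kernel|embedding)$")


def drop_unseen_rows(state, pattern, keep_indices):
    """Return a copy of `state` in which entries whose key matches `pattern`
    keep only the rows listed in `keep_indices` (one row per species)."""
    pruned = {}
    for key, value in state.items():
        if pattern.fullmatch(key):
            # Leading axis is assumed to be num_species: select the kept rows.
            pruned[key] = [value[i] for i in keep_indices]
        else:
            # Parameters that do not match the pattern are left untouched.
            pruned[key] = value
    return pruned


# Toy state dict with made-up keys: three species in the embedding table.
state = {
    "encoder.species_embedding.embedding": [[0.1, 0.2], [0.3, 0.4], [0.5, 0.6]],
    "readout.linear.kernel": [[1.0], [2.0]],
}

# Suppose only species 0 and 2 appear in the dataset's atomic energies map.
pruned = drop_unseen_rows(state, EMBEDDING_LAYER_REGEXP, keep_indices=[0, 2])
```

In this sketch only the embedding table shrinks from three rows to two; every other parameter is passed through unchanged, which is why the pattern has to match the species-indexed parameters and nothing else.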
- __call__(edge_vectors: Array, node_species: Array, senders: Array, receivers: Array, n_node: Array) -> Array
Compute node-wise energy summands. This method must be overridden by every concrete implementation of ForceModel.
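To make the idea of node-wise energy summands concrete, here is a toy sketch written as a standalone JAX function with the same argument names as __call__. It does not subclass ForceModel and is not a dipm model: the function name, the exponential pair energy, and the species-dependent scale are hypothetical, and n_node is assumed to hold per-graph node counts (it is unused here).

```python
import jax
import jax.numpy as jnp


def toy_node_energies(edge_vectors, node_species, senders, receivers, n_node):
    """Toy node-wise energies: each edge contributes to its receiver node."""
    del n_node  # assumed to be per-graph node counts; not needed in this toy
    num_nodes = node_species.shape[0]

    # Edge lengths from the edge vectors, shape (num_edges,).
    lengths = jnp.linalg.norm(edge_vectors, axis=-1)

    # Hypothetical pair energy, scaled by the species of the sending node.
    scale = 1.0 + 0.1 * node_species[senders]
    pair_energy = 0.5 * scale * jnp.exp(-lengths)

    # Sum edge contributions onto receivers to get one value per node.
    return jax.ops.segment_sum(pair_energy, receivers, num_segments=num_nodes)


# Two nodes connected by one edge in each direction, at unit distance.
edge_vectors = jnp.array([[1.0, 0.0, 0.0], [-1.0, 0.0, 0.0]])
node_species = jnp.array([0, 1])
senders = jnp.array([0, 1])
receivers = jnp.array([1, 0])
n_node = jnp.array([2])

node_energies = toy_node_energies(
    edge_vectors, node_species, senders, receivers, n_node
)
```

The result has one entry per node, and summing it gives the total energy of the toy system. A real implementation would replace the hand-written pair energy with learned layers while keeping the same per-node output shape.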