Base class
- class dipm.models.force_model.PrecallInterface
Interface for pre-call functions.
- precall(**kwargs) → dict[str, Any]
Pre-call function to be called before the forward pass. To reduce possible errors, always use keyword arguments and include a **kwargs parameter to catch extra arguments.
- cache(**_kwargs) → dict[str, Any] | None
Return a dict of values to be cached for the forward pass.
- static context_handler(forward_fn: Callable) → Callable
Wraps __call__() to handle precall context.
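The keyword-argument convention for precall and cache can be illustrated with a small stand-alone class. This is a hypothetical sketch, not dipm code: ToyPrecall and its cutoff parameter are made up for illustration, and the class neither imports nor subclasses the real PrecallInterface.

```python
from __future__ import annotations

from typing import Any


class ToyPrecall:
    """Hypothetical stand-in following the PrecallInterface conventions (not dipm code)."""

    def precall(self, *, cutoff: float = 5.0, **kwargs: Any) -> dict[str, Any]:
        # Name only the arguments this model needs, as keywords. Anything else
        # the caller passes lands in **kwargs and is ignored instead of
        # raising a TypeError.
        return {"cutoff": cutoff}

    def cache(self, **_kwargs: Any) -> dict[str, Any] | None:
        # This toy model has nothing to cache for the forward pass.
        return None


model = ToyPrecall()
print(model.precall(cutoff=4.0, batch_size=32, training=True))  # {'cutoff': 4.0}
print(model.cache(step=0))  # None
```

The leading * makes cutoff keyword-only, so a positional call such as model.precall(4.0) raises a TypeError instead of silently binding the value to the wrong parameter, while the **kwargs catch-all lets a caller pass the same argument set to models that need different subsets of it.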
- class dipm.models.force_model.ForceModel(*args: Any, **kwargs: Any)
Base class for GNN node-wise energy models.
Energy models deriving from this class return node-wise contributions to the total energy, computed from the edge vectors of a graph, the atomic species of the nodes, and the edges themselves, passed as senders and receivers indices.
Our MLIP models are validated with Pydantic and hold a reference to their .Config class, which describes the set of hyperparameters.
All subclasses of ForceModel must call super().__init__() and pass in an additional nnx.Rngs parameter during initialization.
NOTE: Don’t store jax.Array values that are needed in the forward pass directly as class attributes, as they won’t be replicated when the model is parallelized. Instead, wrap them in nnx.Cache.
To support direct force prediction, specify force_head_prefix (str) as a class constant. The prefix should be the state-dict key prefix of the force head parameters.
To support dropping unseen elements while loading the model, specify embedding_layer_regexp (re.Pattern) as a class constant. The pattern is used to match the embedding-layer parameters of shape (num_species, …) that should be modified, and their rows corresponding to the dropped elements are removed.
The number of elements (atomic species descriptors) allowed is always inferred from the atomic energies map in the dataset info.
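The two class constants can be pictured as operations on a flattened parameter dict. In this sketch the key names, the prefix value, the regular expression, the atomic energies map, and the assumption that embedding rows follow sorted atomic numbers are all hypothetical; it only shows the idea of selecting force-head parameters by prefix and dropping embedding rows for elements absent from a new dataset, not dipm's actual loading code.

```python
import re

import numpy as np

# Hypothetical flattened parameters; real dipm key names will differ.
params = {
    "embedding.species_embed": np.arange(12.0).reshape(4, 3),  # (num_species, 3)
    "interaction.0.linear": np.ones((3, 3)),
    "force_head.linear": np.ones((3, 3)),
}

# Hypothetical values for the two class constants.
force_head_prefix = "force_head"
embedding_layer_regexp = re.compile(r"^embedding\.")

# Hypothetical atomic energies map (atomic number -> energy); its size sets num_species.
atomic_energies = {1: -1.0, 6: -2.0, 7: -3.0, 8: -4.0}
species_order = sorted(atomic_energies)  # assumed row order: [1, 6, 7, 8]
num_species = len(species_order)

# Select the direct-force head parameters by key prefix.
force_head_keys = [k for k in params if k.startswith(force_head_prefix)]

# Suppose nitrogen (Z=7) is absent from the new dataset: keep the other rows.
seen = {1, 6, 8}
keep = np.array([i for i, z in enumerate(species_order) if z in seen])
pruned = {
    k: v[keep] if embedding_layer_regexp.search(k) else v
    for k, v in params.items()
}

print(num_species, force_head_keys)  # 4 ['force_head.linear']
print(pruned["embedding.species_embed"].shape)  # (3, 3)
```

Only keys matched by the pattern are sliced along their first axis; every other parameter passes through unchanged, which is why the pattern should match only arrays whose leading dimension is num_species.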
- __call__(edge_vectors: Array, node_species: Array, senders: Array, receivers: Array, n_node: Array) → Array
Compute node-wise energy summands. This method must be overridden by every implementation of ForceModel.
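Because __call__ returns one energy summand per node, a total energy per graph in a batch is obtained by summing the summands within each graph. The sketch below shows one common way to do this with NumPy using n_node; the node energies are made-up values, it assumes nodes are stored contiguously graph by graph, and it is not taken from dipm's own code.

```python
import numpy as np

# Hypothetical output of __call__ for a batch of two graphs:
# one energy summand per node, 3 nodes in graph 0 and 2 in graph 1.
node_energies = np.array([-1.0, -2.0, -0.5, -3.0, -1.5])
n_node = np.array([3, 2])

# Map every node to the index of the graph it belongs to.
graph_index = np.repeat(np.arange(len(n_node)), n_node)  # [0, 0, 0, 1, 1]

# Segment sum: add each node's summand into its graph's total.
graph_energies = np.zeros(len(n_node))
np.add.at(graph_energies, graph_index, node_energies)

print(graph_energies)  # [-3.5 -4.5]
```

In JAX the same reduction is usually written as jax.ops.segment_sum(node_energies, graph_index, num_segments=len(n_node)); passing num_segments explicitly keeps the output shape static under jit.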