Commonly used layers

class dipm.layers.activations.Activation(*values)

Supported activation functions. Options are: TANH = "tanh", SILU = "silu", RELU = "relu", ELU = "elu", BETA_SWISH = "beta_swish", SIGMOID = "sigmoid", and NONE = "none".

dipm.layers.activations.get_activation_fn(act: Activation | str, features: int = -1, *, dtype: str | type[Any] | numpy.dtype | SupportsDType | Any | None = None, param_dtype: str | type[Any] | numpy.dtype | SupportsDType | Any = jax.numpy.float32, rngs: flax.nnx.Rngs | None = None) → Callable[[Array], Array]

Parse an activation function from the available options.

See Activation.

Parameters:
  • act – Activation type.

  • features (optional) – Number of features; used only by BetaSwish.

  • dtype (optional) – Dtype used for BetaSwish computations; used only by BetaSwish.

  • param_dtype (optional) – Dtype of the BetaSwish parameters; used only by BetaSwish.

  • rngs (optional) – Flax NNX random number generator streams; used only by BetaSwish.
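The parsing step can be sketched in plain Python. This is a minimal, hypothetical illustration, not the dipm implementation: get_activation_fn_sketch and make_beta_swish are invented names, the functions act on Python lists of floats instead of JAX arrays, the dtype, param_dtype, and rngs arguments are omitted, and BetaSwish is assumed to be x * sigmoid(beta * x) with one trainable beta per feature initialized to 1.0.

```python
import math
from enum import Enum
from typing import Callable, List, Sequence, Union

Vector = List[float]


class Activation(Enum):
    # Same option names and string values as documented above.
    TANH = "tanh"
    SILU = "silu"
    RELU = "relu"
    ELU = "elu"
    BETA_SWISH = "beta_swish"
    SIGMOID = "sigmoid"
    NONE = "none"


def _sigmoid(x: float) -> float:
    # Numerically stable logistic function (avoids exp overflow for large |x|).
    if x >= 0.0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


# Element-wise activations that need no parameters.
_ELEMENTWISE = {
    Activation.TANH: math.tanh,
    Activation.SILU: lambda x: x * _sigmoid(x),
    Activation.RELU: lambda x: max(x, 0.0),
    Activation.ELU: lambda x: x if x > 0.0 else math.expm1(x),
    Activation.SIGMOID: _sigmoid,
    Activation.NONE: lambda x: x,
}


def make_beta_swish(features: int) -> Callable[[Sequence[float]], Vector]:
    # Hypothetical BetaSwish: x * sigmoid(beta * x) with one beta per feature.
    # beta starts at 1.0 here; the library's real initialization is an assumption.
    if features <= 0:
        raise ValueError("BetaSwish requires a positive `features`.")
    beta = [1.0] * features

    def beta_swish(xs: Sequence[float]) -> Vector:
        if len(xs) != features:
            raise ValueError(f"expected {features} features, got {len(xs)}")
        return [x * _sigmoid(b * x) for x, b in zip(xs, beta)]

    return beta_swish


def get_activation_fn_sketch(
    act: Union[Activation, str], features: int = -1
) -> Callable[[Sequence[float]], Vector]:
    # Activation(...) accepts either a member or its string value and
    # raises ValueError for unknown names.
    act = Activation(act)
    if act is Activation.BETA_SWISH:
        return make_beta_swish(features)
    fn = _ELEMENTWISE[act]
    return lambda xs: [fn(x) for x in xs]
```

For example, get_activation_fn_sketch("relu") and get_activation_fn_sketch(Activation.RELU) return equivalent functions. Normalizing the argument through the enum constructor is one simple way to support an `Activation | str` parameter while rejecting misspelled names early.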

class dipm.layers.normalizations.VecNormType(*values)

Options for the VecLayerNorm of the ViSNet model.

dipm.layers.normalizations.get_veclayernorm_fn(norm_type: VecNormType | str, eps: float = 1e-12)

class dipm.layers.radial_basis.RadialBasis(*values)

Radial basis options. Currently, only BESSEL = "bessel" is available.

dipm.layers.radial_basis.get_radial_basis_fn(basis: RadialBasis | str) → Callable

Parse a RadialBasis parameter from the available options.

See RadialBasis.
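As an illustration of what a BESSEL option typically computes, the sketch below evaluates the radial Bessel basis used in DimeNet and MACE, e_n(r) = sqrt(2/r_max) * sin(n*pi*r/r_max) / r for n = 1, ..., num_basis. It is an assumption that dipm's BESSEL basis has this exact form; bessel_basis_sketch, r_max, and num_basis are hypothetical names, and plain Python floats stand in for JAX arrays.

```python
import math
from typing import List


def bessel_basis_sketch(r: float, r_max: float, num_basis: int) -> List[float]:
    # Hypothetical helper (not the dipm API):
    # e_n(r) = sqrt(2 / r_max) * sin(n * pi * r / r_max) / r, for n = 1..num_basis.
    if r <= 0.0:
        raise ValueError("r must be positive; the basis has a 1/r factor.")
    prefactor = math.sqrt(2.0 / r_max)
    return [
        prefactor * math.sin(n * math.pi * r / r_max) / r
        for n in range(1, num_basis + 1)
    ]
```

Every basis function vanishes at r = r_max. The 1/r factor is singular at r = 0 (the limit is finite), so this sketch simply rejects r <= 0 instead of special-casing it.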

class dipm.layers.radial_embeddings.RadialEnvelope(*values)

Radial envelope options. Currently, POLYNOMIAL = "polynomial_envelope" and SOFT = "soft_envelope" are available.

dipm.layers.radial_embeddings.get_radial_envelope_cls(envelope: RadialEnvelope | str) → Callable

Parse a RadialEnvelope parameter from the available options.

See RadialEnvelope.
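Since get_radial_envelope_cls returns a class, that class is presumably instantiated and then applied to interatomic distances. The sketch below assumes POLYNOMIAL follows the DimeNet polynomial envelope u(d) = 1 - (p+1)(p+2)/2 * d^p + p(p+2) * d^(p+1) - p(p+1)/2 * d^(p+2) with d = r / r_max, which equals 1 at d = 0 and reaches 0 at d = 1. PolynomialEnvelopeSketch, r_max, and p are hypothetical names; SOFT is not sketched because its formula is not given here.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class PolynomialEnvelopeSketch:
    # Hypothetical envelope class (not the dipm API):
    # r_max is the cutoff radius, p the polynomial exponent.
    r_max: float
    p: int = 6

    def __call__(self, r: float) -> float:
        d = r / self.r_max
        if d >= 1.0:
            return 0.0  # fully switched off at and beyond the cutoff
        p = self.p
        return (
            1.0
            - (p + 1) * (p + 2) / 2 * d**p
            + p * (p + 2) * d ** (p + 1)
            - p * (p + 1) / 2 * d ** (p + 2)
        )
```

With this choice of coefficients the value and the first two derivatives of u vanish at d = 1, so multiplying a radial basis by the envelope makes the features go smoothly to zero at the cutoff.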