Commonly used layers

class dipm.layers.activations.Activation(*values)

Supported activation functions (member = value):
  • TANH = "tanh"
  • SILU = "silu"
  • RELU = "relu"
  • ELU = "elu"
  • BETA_SWISH = "beta_swish"
  • SIGMOID = "sigmoid"
  • NONE = "none"
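The `(*values)` signature and the `Activation | str` annotation on `get_activation_fn` suggest that `Activation` is a string-valued `Enum`. Assuming that, the sketch below shows how such an enum maps configuration strings to members. `ActivationSketch` is a hypothetical stand-in, not the dipm class.

```python
from enum import Enum


class ActivationSketch(str, Enum):
    """Hypothetical stand-in mirroring the documented Activation options."""

    TANH = "tanh"
    SILU = "silu"
    RELU = "relu"
    ELU = "elu"
    BETA_SWISH = "beta_swish"
    SIGMOID = "sigmoid"
    NONE = "none"


# Look up a member from its string value, as a config file might supply it.
act = ActivationSketch("beta_swish")
print(act is ActivationSketch.BETA_SWISH)  # True
print(act.value)  # beta_swish

# Because of the str mixin, members also compare equal to their plain values.
print(ActivationSketch.RELU == "relu")  # True
```

An unknown string such as `ActivationSketch("gelu")` raises `ValueError`, which is how a str-valued enum rejects unsupported options.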

dipm.layers.activations.get_activation_fn(act: Activation | str, features: int = -1, *, dtype: str | type[Any] | dtype | SupportsDType | Any | None = None, param_dtype: str | type[Any] | dtype | SupportsDType | Any = jax.numpy.float32, rngs: Rngs | None = None) -> Callable[[Array], Array]

Return the activation function corresponding to act, chosen from the available options.

See Activation.

Parameters:
  • act – Activation type, given as an Activation member or its string value.

  • features (optional) – Number of features; used only by BetaSwish.

  • dtype (optional) – Dtype used for the BetaSwish computation.

  • param_dtype (optional) – Dtype of the BetaSwish parameters.

  • rngs (optional) – Random-number generator streams; used only by BetaSwish.
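A minimal sketch of the kind of name-to-function dispatch that get_activation_fn performs, built on `jax.nn` and `jax.numpy`. This is not dipm's implementation: `make_activation` is a hypothetical name, and BetaSwish is assumed to compute x · sigmoid(β · x) with one β per feature. Here β is a fixed array of ones rather than a trainable parameter initialized from `rngs`, and `dtype`/`param_dtype` are omitted.

```python
from typing import Any, Callable

import jax
import jax.numpy as jnp


def make_activation(act: Any, features: int = -1) -> Callable[[jax.Array], jax.Array]:
    """Hypothetical dispatcher: map an activation name to a JAX function."""
    # Accept either an enum member (use its .value) or a plain string.
    name = str(getattr(act, "value", act))
    table = {
        "tanh": jnp.tanh,
        "silu": jax.nn.silu,
        "relu": jax.nn.relu,
        "elu": jax.nn.elu,
        "sigmoid": jax.nn.sigmoid,
        "none": lambda x: x,  # identity
    }
    if name == "beta_swish":
        # Assumed form: x * sigmoid(beta * x) with one beta per feature.
        # beta is fixed at ones here; a real layer would make it trainable.
        if features <= 0:
            raise ValueError("beta_swish requires a positive `features`")
        beta = jnp.ones((features,), dtype=jnp.float32)
        return lambda x: x * jax.nn.sigmoid(beta * x)
    try:
        return table[name]
    except KeyError:
        raise ValueError(f"unknown activation: {act!r}") from None


x = jnp.array([-1.0, 0.0, 2.0])
print(make_activation("relu")(x))  # [0. 0. 2.]
```

Because β starts at one in this sketch, "beta_swish" initially matches SiLU; the point of a learnable β is that each feature can move away from that shape during training. With the library itself, the signature indicates that `get_activation_fn("silu")` and `get_activation_fn(Activation.SILU)` are interchangeable, and that `features` (and optionally `rngs`) matter only for `"beta_swish"`.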

class dipm.layers.normalizations.VecNormType(*values)

Options for the VecLayerNorm of the ViSNet model.

dipm.layers.normalizations.get_veclayernorm_fn(norm_type: VecNormType | str, eps: float = 1e-12)

class dipm.layers.radial_basis.RadialBasis(*values)

Radial basis options.