onsagernet._layers

Custom layers as equinox modules

 1"""Custom layers as equinox modules"""
 2
 3import jax
 4import equinox as eqx
 5from math import sqrt
 6from ._utils import default_floating_dtype
 7from jax import Array
 8from jax.random import PRNGKey
 9from typing import Optional
10from jax.typing import DTypeLike
11
12
class ConstantLayer(eqx.Module):
    """A trainable constant vector, analogous to a stand-alone bias term."""

    weight: Array
    dim: int

    def __init__(
        self, dim: int, key: PRNGKey, dtype: Optional[DTypeLike] = None
    ) -> None:
        """Create a trainable constant vector of size ``dim``.

        The entries are initialised uniformly in ``[-1/sqrt(dim), 1/sqrt(dim)]``,
        the same fan-in scaling used for a linear layer's bias.

        Args:
            dim (int): dimension of the input space
            key (PRNGKey): random key
            dtype (Optional[DTypeLike], optional): data type. Defaults to None.
        """
        self.dim = dim
        if dtype is None:
            dtype = default_floating_dtype()
        bound = 1.0 / sqrt(dim)
        self.weight = jax.random.uniform(
            key, (dim,), minval=-bound, maxval=bound, dtype=dtype
        )

    def __call__(self, *, key: Optional[PRNGKey] = None) -> Array:
        """Return the constant trainable vector; ``key`` is accepted but unused."""
        return self.weight
class ConstantLayer(eqx.Module):
14class ConstantLayer(eqx.Module):
15    """Constant layer."""
16
17    weight: Array
18    dim: int
19
20    def __init__(
21        self, dim: int, key: PRNGKey, dtype: Optional[DTypeLike] = None
22    ) -> None:
23        """Constant layer.
24
25        Returns a constant, trainable vector,
26        similar to the bias term in a neural network,
27        that is the same size as the input.
28
29        Args:
30            dim (int): dimension of the input space
31            key (PRNGKey): random key
32            dtype (Optional[DTypeLike], optional): data type. Defaults to None.
33        """
34        dtype = default_floating_dtype() if dtype is None else dtype
35        lim = 1 / sqrt(dim)
36        shape = (dim,)
37        self.dim = dim
38        self.weight = jax.random.uniform(
39            key, shape, minval=-lim, maxval=lim, dtype=dtype
40        )
41
42    def __call__(self, *, key: Optional[PRNGKey] = None) -> Array:
43        return self.weight

Constant layer.

ConstantLayer(dim: int, key: PRNGKey, dtype: Optional[DTypeLike] = None)
20    def __init__(
21        self, dim: int, key: PRNGKey, dtype: Optional[DTypeLike] = None
22    ) -> None:
23        """Constant layer.
24
25        Returns a constant, trainable vector,
26        similar to the bias term in a neural network,
27        that is the same size as the input.
28
29        Args:
30            dim (int): dimension of the input space
31            key (PRNGKey): random key
32            dtype (Optional[DTypeLike], optional): data type. Defaults to None.
33        """
34        dtype = default_floating_dtype() if dtype is None else dtype
35        lim = 1 / sqrt(dim)
36        shape = (dim,)
37        self.dim = dim
38        self.weight = jax.random.uniform(
39            key, shape, minval=-lim, maxval=lim, dtype=dtype
40        )

Constant layer.

Returns a constant, trainable vector, similar to the bias term in a neural network, that is the same size as the input.

Arguments:
  • dim (int): dimension of the input space
  • key (PRNGKey): random key
  • dtype (Optional[DTypeLike], optional): data type. Defaults to None.
weight: jax.Array
dim: int