onsagernet._layers
Custom layers implemented as Equinox modules.
"""Custom layers implemented as Equinox modules."""

import jax
import equinox as eqx
from math import sqrt
from ._utils import default_floating_dtype
from jax import Array
from jax.random import PRNGKey
from typing import Optional
from jax.typing import DTypeLike


class ConstantLayer(eqx.Module):
    """Constant layer.

    Holds a trainable vector of length ``dim`` that is returned unchanged
    every time the layer is called, similar to the bias term of a neural
    network.
    """

    weight: Array  # trainable vector of shape (dim,)
    dim: int  # length of ``weight``

    def __init__(
        self, dim: int, key: Array, dtype: Optional[DTypeLike] = None
    ) -> None:
        """Constant layer.

        Initialises a constant, trainable vector of length ``dim``,
        similar to the bias term in a neural network. Entries are drawn
        uniformly from ``[-1/sqrt(dim), 1/sqrt(dim))``.

        Args:
            dim (int): dimension of the input space, which is also the
                length of the weight vector; must be positive.
            key (Array): JAX random key used to initialise the weight
                (e.g. one produced by ``jax.random.PRNGKey``).
            dtype (Optional[DTypeLike], optional): data type of the weight.
                Defaults to None, in which case ``default_floating_dtype()``
                is used.

        Raises:
            ValueError: if ``dim`` is not positive.
        """
        if dim <= 0:
            # 1 / sqrt(dim) below would otherwise fail with an opaque
            # ZeroDivisionError (dim == 0) or math-domain error (dim < 0).
            raise ValueError(f"dim must be a positive integer, got {dim}")
        dtype = default_floating_dtype() if dtype is None else dtype
        lim = 1 / sqrt(dim)
        self.dim = dim
        self.weight = jax.random.uniform(
            key, (dim,), minval=-lim, maxval=lim, dtype=dtype
        )

    def __call__(self, *, key: Optional[Array] = None) -> Array:
        """Return the trainable constant vector.

        Args:
            key (Optional[Array], optional): accepted but not used.
                Defaults to None.

        Returns:
            Array: the ``weight`` vector of shape ``(dim,)``.
        """
        return self.weight
class
ConstantLayer(equinox._module.Module):
class ConstantLayer(eqx.Module):
    """Constant layer.

    Holds a trainable vector of length ``dim`` that is returned unchanged
    every time the layer is called, similar to the bias term of a neural
    network.
    """

    weight: Array  # trainable vector of shape (dim,)
    dim: int  # length of ``weight``

    def __init__(
        self, dim: int, key: Array, dtype: Optional[DTypeLike] = None
    ) -> None:
        """Constant layer.

        Initialises a constant, trainable vector of length ``dim``,
        similar to the bias term in a neural network. Entries are drawn
        uniformly from ``[-1/sqrt(dim), 1/sqrt(dim))``.

        Args:
            dim (int): dimension of the input space, which is also the
                length of the weight vector; must be positive.
            key (Array): JAX random key used to initialise the weight
                (e.g. one produced by ``jax.random.PRNGKey``).
            dtype (Optional[DTypeLike], optional): data type of the weight.
                Defaults to None, in which case ``default_floating_dtype()``
                is used.

        Raises:
            ValueError: if ``dim`` is not positive.
        """
        if dim <= 0:
            # 1 / sqrt(dim) below would otherwise fail with an opaque
            # ZeroDivisionError (dim == 0) or math-domain error (dim < 0).
            raise ValueError(f"dim must be a positive integer, got {dim}")
        dtype = default_floating_dtype() if dtype is None else dtype
        lim = 1 / sqrt(dim)
        self.dim = dim
        self.weight = jax.random.uniform(
            key, (dim,), minval=-lim, maxval=lim, dtype=dtype
        )

    def __call__(self, *, key: Optional[Array] = None) -> Array:
        """Return the trainable constant vector.

        Args:
            key (Optional[Array], optional): accepted but not used.
                Defaults to None.

        Returns:
            Array: the ``weight`` vector of shape ``(dim,)``.
        """
        return self.weight
Constant layer: a trainable vector that is returned unchanged on every call.
ConstantLayer( dim: int, key: <function PRNGKey>, dtype: Union[str, type[Any], numpy.dtype, jax._src.typing.SupportsDType, NoneType] = None)
def __init__(
    self, dim: int, key: Array, dtype: Optional[DTypeLike] = None
) -> None:
    """Constant layer.

    Initialises a constant, trainable vector of length ``dim``,
    similar to the bias term in a neural network. Entries are drawn
    uniformly from ``[-1/sqrt(dim), 1/sqrt(dim))``.

    Args:
        dim (int): dimension of the input space, which is also the
            length of the weight vector; must be positive.
        key (Array): JAX random key used to initialise the weight
            (e.g. one produced by ``jax.random.PRNGKey``).
        dtype (Optional[DTypeLike], optional): data type of the weight.
            Defaults to None, in which case ``default_floating_dtype()``
            is used.

    Raises:
        ValueError: if ``dim`` is not positive.
    """
    if dim <= 0:
        # 1 / sqrt(dim) below would otherwise fail with an opaque
        # ZeroDivisionError (dim == 0) or math-domain error (dim < 0).
        raise ValueError(f"dim must be a positive integer, got {dim}")
    dtype = default_floating_dtype() if dtype is None else dtype
    lim = 1 / sqrt(dim)
    self.dim = dim
    self.weight = jax.random.uniform(
        key, (dim,), minval=-lim, maxval=lim, dtype=dtype
    )
Constant layer.
Holds a constant, trainable vector of length dim, similar to the bias term in a neural network; calling the layer returns this vector.
Arguments:
- dim (int): dimension of the input space, which is also the length of the returned vector; must be positive.
- key (PRNGKey): JAX random key used to initialise the weight.
- dtype (Optional[DTypeLike], optional): data type of the weight. Defaults to None, in which case default_floating_dtype() is used.