onsagernet._activations
Custom activations
This module contains custom activation functions.
"""
# Custom activations

This module contains custom activation functions.
"""

import jax
import jax.numpy as jnp

# ------------------------- Typing imports ------------------------- #

from jax import Array
from jax.typing import ArrayLike
from typing import Callable


# ------------------------------------------------------------------ #
#                     Custom activation functions                    #
# ------------------------------------------------------------------ #


@jax.jit
def recu(x: ArrayLike) -> Array:
    r"""Rectified cubic unit activation function.

    Implements the activation function
    $$
        x \mapsto
        \begin{cases}
            0 & \text{if } x \in (-\infty, 0), \\
            x^3 / 3 & \text{if } x \in [0, 1), \\
            x - 2/3 & \text{if } x \in [1, \infty).
        \end{cases}
    $$

    Args:
        x (ArrayLike): inputs

    Returns:
        Array: activated inputs
    """
    # Both candidate branches are evaluated element-wise; `jnp.where` then
    # picks the one matching the interval each element falls in.
    cubic_part = x**3 / 3
    linear_part = x - 2 / 3
    return jnp.where(x < 0, 0, jnp.where(x < 1, cubic_part, linear_part))


@jax.jit
def srequ(x: ArrayLike) -> Array:
    r"""Shifted rectified quadratic unit activation function.

    Implements the activation function
    $$
        x \mapsto
        \max(0, x)^2 - \max(0, x - 0.5)^2.
    $$

    Args:
        x (ArrayLike): inputs

    Returns:
        Array: activated inputs
    """
    return jnp.maximum(0, x) ** 2 - jnp.maximum(0, x - 0.5) ** 2


# ------------------------------------------------------------------ #
#                               Helpers                              #
# ------------------------------------------------------------------ #

CUSTOM_ACTIVATIONS = {
    "recu": recu,
    "srequ": srequ,
}


def get_activation(name: str) -> Callable[[ArrayLike], Array]:
    """Get the activation function by name.

    First checks if the activation function is a custom activation, then tries to get it from `jax.nn`.

    Args:
        name (str): name of the activation function

    Raises:
        ValueError: If the activation function is not found in custom activations or `jax.nn`
        TypeError: If the activation function is not callable

    Returns:
        Callable[[ArrayLike], Array]: activation function
    """

    # Check custom activations first
    if name in CUSTOM_ACTIVATIONS:
        activation_function = CUSTOM_ACTIVATIONS[name]
    else:
        # Try getting the activation function from jax.nn
        try:
            activation_function = getattr(jax.nn, name)
        except AttributeError as err:
            # Chain the original lookup failure so the traceback shows its cause
            raise ValueError(
                f"Activation function '{name}' not found in custom activations or jax.nn"
            ) from err

    # Check if the result is callable (i.e., a function); `jax.nn` also
    # exposes non-callable attributes such as submodules and dunder strings
    if not callable(activation_function):
        raise TypeError(f"The activation function '{name}' is not callable")

    return activation_function
@jax.jit
def
recu( x: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]) -> jax.Array:
@jax.jit
def recu(x: ArrayLike) -> Array:
    r"""Rectified cubic unit activation function.

    Implements the activation function
    $$
        x \mapsto
        \begin{cases}
            0 & \text{if } x \in (-\infty, 0), \\
            x^3 / 3 & \text{if } x \in [0, 1), \\
            x - 2/3 & \text{if } x \in [1, \infty).
        \end{cases}
    $$

    Args:
        x (ArrayLike): inputs

    Returns:
        Array: activated inputs
    """
    # Both candidate branches are evaluated element-wise; `jnp.where` then
    # picks the one matching the interval each element falls in.
    cubic_part = x**3 / 3
    linear_part = x - 2 / 3
    return jnp.where(x < 0, 0, jnp.where(x < 1, cubic_part, linear_part))
Rectified cubic unit activation function.
Implements the activation function $$ x \mapsto \begin{cases} 0 & \text{if } x \in (-\infty, 0), \\ x^3 / 3 & \text{if } x \in [0, 1), \\ x - 2/3 & \text{if } x \in [1, \infty). \end{cases} $$
Arguments:
- x (ArrayLike): inputs
Returns:
Array: activated inputs
@jax.jit
def
srequ( x: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]) -> jax.Array:
@jax.jit
def srequ(x: ArrayLike) -> Array:
    r"""Shifted rectified quadratic unit activation function.

    Computes the difference of two squared rectifiers, the second one
    shifted to the right by 0.5:
    $$
        x \mapsto
        \max(0, x)^2 - \max(0, x - 0.5)^2.
    $$

    Args:
        x (ArrayLike): inputs

    Returns:
        Array: activated inputs
    """
    rectified = jnp.maximum(0, x)
    shifted_rectified = jnp.maximum(0, x - 0.5)
    return jnp.square(rectified) - jnp.square(shifted_rectified)
Shifted rectified quadratic unit activation function.
Implements the activation function $$ x \mapsto \max(0, x)^2 - \max(0, x - 0.5)^2. $$
Arguments:
- x (ArrayLike): inputs
Returns:
Array: activated inputs
CUSTOM_ACTIVATIONS =
{'recu': <PjitFunction of <function recu>>, 'srequ': <PjitFunction of <function srequ>>}
def
get_activation( name: str) -> Callable[[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]], jax.Array]:
def get_activation(name: str) -> Callable[[ArrayLike], Array]:
    """Get the activation function by name.

    First checks if the activation function is a custom activation, then tries to get it from `jax.nn`.

    Args:
        name (str): name of the activation function

    Raises:
        ValueError: If the activation function is not found in custom activations or `jax.nn`
        TypeError: If the activation function is not callable

    Returns:
        Callable[[ArrayLike], Array]: activation function
    """

    # Check custom activations first
    if name in CUSTOM_ACTIVATIONS:
        activation_function = CUSTOM_ACTIVATIONS[name]
    else:
        # Try getting the activation function from jax.nn
        try:
            activation_function = getattr(jax.nn, name)
        except AttributeError as err:
            # Chain the original lookup failure so the traceback shows its cause
            raise ValueError(
                f"Activation function '{name}' not found in custom activations or jax.nn"
            ) from err

    # Check if the result is callable (i.e., a function); `jax.nn` also
    # exposes non-callable attributes such as submodules and dunder strings
    if not callable(activation_function):
        raise TypeError(f"The activation function '{name}' is not callable")

    return activation_function
Get the activation function by name.
First checks if the activation function is a custom activation, then tries to get it from `jax.nn`.
Arguments:
- name (str): name of the activation function
Raises:
- ValueError: If the activation function is not found in custom activations or `jax.nn`
- TypeError: If the activation function is not callable
Returns:
Callable[[ArrayLike], Array]: activation function