onsagernet._activations

Custom activations

This module contains custom activation functions.

  1"""
  2# Custom activations
  3
  4This module contains custom activation functions.
  5"""
  6
  7import jax
  8import jax.numpy as jnp
  9
 10# ------------------------- Typing imports ------------------------- #
 11
 12from jax import Array
 13from jax.typing import ArrayLike
 14from typing import Callable
 15
 16
 17# ------------------------------------------------------------------ #
 18#                     Custom activation functions                    #
 19# ------------------------------------------------------------------ #
 20
 21
 22@jax.jit
 23def recu(x: ArrayLike) -> Array:
 24    r"""Rectified cubic unit activation function.
 25
 26    Implements the activation function
 27    $$
 28        x \mapsto
 29        \begin{cases}
 30            x^3 / 3 & \text{if } x \in [0, 1), \qquad \\
 31            x - 2/3 & \text{if } x \in [1, \infty).
 32        \end{cases}
 33    $$
 34
 35    Args:
 36        x (ArrayLike): inputs
 37
 38    Returns:
 39        Array: activated inputs
 40    """
 41    cubic_part = x**3 / 3
 42    linear_part = x - 2 / 3
 43    return jnp.where(x < 0, 0, jnp.where(x < 1, cubic_part, linear_part))
 44
 45
 46@jax.jit
 47def srequ(x: ArrayLike) -> Array:
 48    r"""Shifted rectified quadratic unit activation function.
 49
 50    Implements the activation function
 51    $$
 52        x \mapsto
 53        \max(0, x)^2 - \max(0, x - 0.5)^2.
 54    $$
 55
 56    Args:
 57        x (ArrayLike): inputs
 58
 59    Returns:
 60        Array: activated inputs
 61    """
 62    return jnp.maximum(0, x) ** 2 - jnp.maximum(0, x - 0.5) ** 2
 63
 64
 65# ------------------------------------------------------------------ #
 66#                               Helpers                              #
 67# ------------------------------------------------------------------ #
 68
 69CUSTOM_ACTIVATIONS = {
 70    "recu": recu,
 71    "srequ": srequ,
 72}
 73
 74
 75def get_activation(name: str) -> Callable[[ArrayLike], Array]:
 76    """Get the activation function by name.
 77    First checks if the activation function is a custom activation, then tries to get it from `jax.nn`.
 78
 79    Args:
 80        name (str): name of the activation function
 81
 82    Raises:
 83        ValueError: If the activation function is not found in custom activations or `jax.nn`
 84        TypeError: If the activation function is not callable
 85
 86    Returns:
 87        _type_: activation function
 88    """
 89
 90    # Check custom activations first
 91    if name in CUSTOM_ACTIVATIONS:
 92        activation_function = CUSTOM_ACTIVATIONS[name]
 93    else:
 94        # Try getting the activation function from jax.nn
 95        try:
 96            activation_function = getattr(jax.nn, name)
 97        except AttributeError:
 98            raise ValueError(
 99                f"Activation function '{name}' not found in custom activations or jax.nn"
100            )
101
102    # Check if the result is callable (i.e., a function)
103    if not callable(activation_function):
104        raise TypeError(f"The activation function '{name}' is not callable")
105
106    return activation_function
@jax.jit
def recu( x: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]) -> jax.Array:
@jax.jit
def recu(x: ArrayLike) -> Array:
    r"""Rectified cubic unit activation function.

    Implements the activation function
    $$
        x \mapsto
        \begin{cases}
            0 & \text{if } x \in (-\infty, 0), \\
            x^3 / 3 & \text{if } x \in [0, 1), \qquad \\
            x - 2/3 & \text{if } x \in [1, \infty).
        \end{cases}
    $$
    The pieces agree in value and in slope at $x = 0$ and $x = 1$.

    Args:
        x (ArrayLike): inputs

    Returns:
        Array: activated inputs
    """
    # Both pieces are computed for every element; `jnp.where` then selects
    # the one matching the interval each element falls in.
    cubic_part = x**3 / 3
    linear_part = x - 2 / 3
    return jnp.where(x < 0, 0, jnp.where(x < 1, cubic_part, linear_part))

Rectified cubic unit activation function.

Implements the activation function $$ x \mapsto \begin{cases} 0 & \text{if } x \in (-\infty, 0), \\ x^3 / 3 & \text{if } x \in [0, 1), \qquad \\ x - 2/3 & \text{if } x \in [1, \infty). \end{cases} $$

Arguments:
  • x (ArrayLike): inputs
Returns:

Array: activated inputs

@jax.jit
def srequ( x: Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]) -> jax.Array:
@jax.jit
def srequ(x: ArrayLike) -> Array:
    r"""Shifted rectified quadratic unit activation function.

    Implements the activation function
    $$
        x \mapsto
        \max(0, x)^2 - \max(0, x - 0.5)^2.
    $$
    Equivalently, the output is $0$ for $x \le 0$, $x^2$ for
    $0 \le x \le 0.5$, and $x - 0.25$ for $x \ge 0.5$.

    Args:
        x (ArrayLike): inputs

    Returns:
        Array: activated inputs
    """
    return jnp.maximum(0, x) ** 2 - jnp.maximum(0, x - 0.5) ** 2

Shifted rectified quadratic unit activation function.

Implements the activation function $$ x \mapsto \max(0, x)^2 - \max(0, x - 0.5)^2. $$

Arguments:
  • x (ArrayLike): inputs
Returns:

Array: activated inputs

CUSTOM_ACTIVATIONS = {'recu': <PjitFunction of <function recu>>, 'srequ': <PjitFunction of <function srequ>>}
def get_activation( name: str) -> Callable[[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]], jax.Array]:
 76def get_activation(name: str) -> Callable[[ArrayLike], Array]:
 77    """Get the activation function by name.
 78    First checks if the activation function is a custom activation, then tries to get it from `jax.nn`.
 79
 80    Args:
 81        name (str): name of the activation function
 82
 83    Raises:
 84        ValueError: If the activation function is not found in custom activations or `jax.nn`
 85        TypeError: If the activation function is not callable
 86
 87    Returns:
 88        _type_: activation function
 89    """
 90
 91    # Check custom activations first
 92    if name in CUSTOM_ACTIVATIONS:
 93        activation_function = CUSTOM_ACTIVATIONS[name]
 94    else:
 95        # Try getting the activation function from jax.nn
 96        try:
 97            activation_function = getattr(jax.nn, name)
 98        except AttributeError:
 99            raise ValueError(
100                f"Activation function '{name}' not found in custom activations or jax.nn"
101            )
102
103    # Check if the result is callable (i.e., a function)
104    if not callable(activation_function):
105        raise TypeError(f"The activation function '{name}' is not callable")
106
107    return activation_function

Get the activation function by name. First checks if the activation function is a custom activation, then tries to get it from jax.nn.

Arguments:
  • name (str): name of the activation function
Raises:
  • ValueError: If the activation function is not found in custom activations or jax.nn
  • TypeError: If the activation function is not callable
Returns:

Callable[[ArrayLike], Array]: activation function