onsagernet.models

Custom equinox modules for OnsagerNet components

This module contains custom equinox modules for the components of the OnsagerNet model, which are used in various examples provided in the repository.

For new applications, it is suggested to try the simple models here first and then build upon them, adapting if necessary to the specific problem at hand.

  1"""
  2# Custom equinox modules for OnsagerNet components
  3
  4This module contains custom equinox modules for the components of the OnsagerNet model,
  5which are used in various examples provided in the repository.
  6
  7For new applications, it is suggested to try the simple models here first
  8and then build upon them, adapting if necessary to the specific problem at hand.
  9
 10
 11
 12
 13
 14"""

import jax
import jax.numpy as jnp
import equinox as eqx

from ._activations import get_activation
from ._layers import ConstantLayer

# ------------------------- Typing imports ------------------------- #

from jax import Array
from jax.typing import ArrayLike
from typing import Callable
from jax.random import PRNGKey


# ------------------------------------------------------------------ #
#                           Template models                          #
# ------------------------------------------------------------------ #


class MLP(eqx.Module):
    """Multi-layer perceptron."""

    layers: list[eqx.nn.Linear]
    activation: Callable[[ArrayLike], Array]

    def __init__(
        self, key: PRNGKey, dim: int, units: list[int], activation: str
    ) -> None:
        r"""Multi-layer perceptron.

        Example:
        `mlp = MLP(key=jax.random.PRNGKey(0), dim=2, units=[32, 32, 1], activation='tanh')`
        gives a 2-hidden-layer MLP

        $$2 \to 32 \to 32 \to 1$$

        with tanh activation.

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            units (list[int]): layer sizes
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
        """
        num_layers = len(units)
        units = [dim] + units
        keys = jax.random.split(key, num_layers)
        self.layers = [
            eqx.nn.Linear(units[i], units[i + 1], key=key) for i, key in enumerate(keys)
        ]
        self.activation = get_activation(activation)

    def __call__(self, x: ArrayLike) -> Array:
        h = x
        for layer in self.layers[:-1]:
            h = self.activation(layer(h))
        output = self.layers[-1](h)
        return output

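
# Example (illustrative sketch): constructing and applying an MLP.  The key,
# layer sizes, activation and inputs below are arbitrary choices for
# demonstration and are not part of the library.
def _example_mlp_usage() -> None:
    key = jax.random.PRNGKey(0)
    mlp = MLP(key=key, dim=2, units=[32, 32, 1], activation="tanh")
    y_single = mlp(jnp.ones(2))  # single input of dimension 2 -> output of shape (1,)
    y_batch = jax.vmap(mlp)(jnp.ones((8, 2)))  # batched inputs via vmap -> shape (8, 1)
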
# ------------------------------------------------------------------ #
#                         Potential networks                         #
# ------------------------------------------------------------------ #


class PotentialMLP(MLP):
    """Potential network based on a multi-layer perceptron."""

    alpha: float
    dim: int
    param_dim: int

    def __init__(
        self,
        key: PRNGKey,
        dim: int,
        units: list[int],
        activation: str,
        alpha: float,
        param_dim: int = 0,
    ) -> None:
        r"""Potential network based on a multi-layer perceptron.

        This implements the potential function
        $$
            V(x, args) = \alpha \|(x, args)\|^2 + \text{MLP}(x, args)
        $$
        where $x$ is the input and $args$ are additional parameters.
        The constant $\alpha \geq 0$ is a regularisation term
        that adds quadratic growth to ensure that the potential is integrable.
        We tacitly assume that the MLP has sub-quadratic growth,
        so only choose activation functions with this property
        (most activation functions are either bounded or of linear growth).

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            units (list[int]): layer sizes
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            alpha (float): regulariser
            param_dim (int, optional): dimension of the parameters. Defaults to 0.
        """
        self.dim = dim
        units = units + [1]
        self.param_dim = param_dim
        super().__init__(key, dim + param_dim, units, activation)
        self.alpha = alpha

    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
        if self.param_dim > 0:
            # append the parameter entries args[1:] to the state
            x = jnp.concatenate([x, args[1:]], axis=0)
        output = super().__call__(x) + self.alpha * (x @ x)
        return jnp.squeeze(output)


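
# Example (illustrative sketch): evaluating a potential and its gradient with
# `jax.grad`, which is the quantity typically used to build the OnsagerNet drift.
# The key, sizes and inputs are arbitrary; with `param_dim=0` the `args` argument
# is unused but still required by the call signature.
def _example_potential_usage() -> None:
    key = jax.random.PRNGKey(0)
    potential = PotentialMLP(key, dim=3, units=[32, 32], activation="tanh", alpha=0.1)
    x, args = jnp.ones(3), jnp.zeros(1)
    value = potential(x, args)  # scalar V(x)
    grad_v = jax.grad(potential)(x, args)  # dV/dx, shape (3,)
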
class PotentialResMLP(MLP):
    r"""Potential network with a residual connection."""

    alpha: float
    gamma_layer: eqx.nn.Linear
    dim: int
    param_dim: int

    def __init__(
        self,
        key: PRNGKey,
        dim: int,
        units: list[int],
        activation: str,
        n_pot: int,
        alpha: float,
        param_dim: int = 0,
    ) -> None:
        r"""Potential network with a residual connection.

        This implements the modified potential function
        $$
            V(x, args) = \alpha \|(x, args)\|^2
            + \frac{1}{2}
            \| \text{MLP}(x, args) + \Gamma (x, args) \|^2
        $$
        where

        - $\text{MLP}$ maps inputs of dimension `dim + param_dim` to outputs of dimension `n_pot`
        - $\Gamma$ is a matrix of size `[n_pot, dim + param_dim]`
        - $\alpha > 0$ is a scalar regulariser

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            units (list[int]): layer sizes
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            n_pot (int): size of the MLP part of the potential
            alpha (float): regulariser
            param_dim (int, optional): dimension of the parameters. Defaults to 0.
        """
        self.dim = dim
        self.param_dim = param_dim
        units = units + [n_pot]
        mlp_key, gamma_key = jax.random.split(key)
        super().__init__(mlp_key, dim + param_dim, units, activation)
        self.alpha = alpha
        self.gamma_layer = eqx.nn.Linear(
            dim + param_dim, n_pot, key=gamma_key, use_bias=False
        )

    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
        if self.param_dim > 0:
            x = jnp.concatenate([x, args[1:]], axis=0)
        output_phi = super().__call__(x)
        output_gamma = self.gamma_layer(x)
        output_combined = (output_phi + output_gamma) @ (output_phi + output_gamma)
        regularisation = self.alpha * (x @ x)
        return 0.5 * output_combined + regularisation


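
# Example (illustrative sketch): by construction the residual potential is a sum
# of squares plus a quadratic term, so it is non-negative whenever alpha >= 0.
# The key and sizes below are arbitrary.
def _example_residual_potential() -> None:
    key = jax.random.PRNGKey(0)
    potential = PotentialResMLP(
        key, dim=3, units=[32], activation="tanh", n_pot=8, alpha=0.1
    )
    value = potential(jnp.ones(3), jnp.zeros(1))  # scalar, >= 0
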
# ------------------------------------------------------------------ #
#                        Dissipation networks                        #
# ------------------------------------------------------------------ #


class DissipationMatrixMLP(MLP):
    """Dissipation matrix network based on a multi-layer perceptron."""

    alpha: float
    is_bounded: bool
    dim: int

    def __init__(
        self,
        key: PRNGKey,
        dim: int,
        units: list[int],
        activation: str,
        alpha: float,
        is_bounded: bool = True,
    ) -> None:
        r"""Dissipation matrix network based on a multi-layer perceptron.

        The MLP maps $x$ of dimension `dim` to a vector of size `dim * dim`,
        which is reshaped into a `dim` x `dim` matrix $L(x)$.
        The output matrix is then given by
        $$
            M(x) = \alpha I + L(x) L(x)^\top.
        $$
        If `is_bounded` is set to `True`, the output is element-wise bounded
        by applying a `jax.nn.tanh` activation to the matrix $L$.

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            units (list[int]): layer sizes
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            alpha (float): regulariser
            is_bounded (bool, optional): whether to give an element-wise bounded output. Defaults to True.
        """
        self.dim = dim
        units = units + [dim * dim]
        super().__init__(key, dim, units, activation)
        self.alpha = alpha
        self.is_bounded = is_bounded

    def __call__(self, x: ArrayLike) -> Array:
        L = super().__call__(x).reshape(self.dim, self.dim)
        if self.is_bounded:
            L = jax.nn.tanh(L)
        return self.alpha * jnp.eye(self.dim) + L @ L.T


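
# Example (illustrative sketch): the dissipation matrix is symmetric with
# eigenvalues bounded below by alpha, since M = alpha*I + L @ L.T.  The key and
# sizes below are arbitrary.
def _example_dissipation_matrix() -> None:
    key = jax.random.PRNGKey(0)
    m_net = DissipationMatrixMLP(key, dim=3, units=[16, 16], activation="tanh", alpha=0.01)
    M = m_net(jnp.ones(3))  # shape (3, 3)
    symmetry_error = jnp.max(jnp.abs(M - M.T))  # ~0 by construction
    eigenvalues = jnp.linalg.eigvalsh(M)  # all >= alpha
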
# ------------------------------------------------------------------ #
#                        Conservation networks                       #
# ------------------------------------------------------------------ #


class ConservationMatrixMLP(MLP):
    """Conservation matrix network based on a multi-layer perceptron."""

    is_bounded: bool
    dim: int

    def __init__(
        self,
        key: PRNGKey,
        dim: int,
        activation: str,
        units: list[int],
        is_bounded: bool = True,
    ) -> None:
        r"""Conservation matrix network based on a multi-layer perceptron.

        The MLP maps $x$ of dimension `dim` to a vector of size `dim * dim`,
        which is reshaped into a `dim` x `dim` matrix $L(x)$.
        The output matrix is then given by
        $$
            W(x) = L(x) - L(x)^\top.
        $$
        If `is_bounded` is set to `True`, the output is element-wise bounded
        by applying a `jax.nn.tanh` activation to the matrix $L$.

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            units (list[int]): layer sizes
            is_bounded (bool, optional): whether to give an element-wise bounded output. Defaults to True.
        """
        self.dim = dim
        units = units + [dim * dim]
        super().__init__(key, dim, units, activation)
        self.is_bounded = is_bounded

    def __call__(self, x: ArrayLike) -> Array:
        L = super().__call__(x).reshape(self.dim, self.dim)
        if self.is_bounded:
            L = jax.nn.tanh(L)
        return L - L.T


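
# Example (illustrative sketch): the conservation matrix is antisymmetric by
# construction, so x.T @ W(x) @ x vanishes for every state x.  Note that this
# constructor takes `activation` before `units`, unlike the other networks.
def _example_conservation_matrix() -> None:
    key = jax.random.PRNGKey(0)
    w_net = ConservationMatrixMLP(key, dim=3, activation="tanh", units=[16, 16])
    x = jnp.ones(3)
    W = w_net(x)  # shape (3, 3)
    antisymmetry_error = jnp.max(jnp.abs(W + W.T))  # ~0 by construction
    quadratic_form = x @ W @ x  # ~0 for antisymmetric W
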
# ------------------------------------------------------------------ #
#                         Diffusion networks                         #
# ------------------------------------------------------------------ #


class DiffusionMLP(MLP):
    """Diffusion matrix network based on a multi-layer perceptron."""

    alpha: float
    dim: int
    param_dim: int

    def __init__(
        self,
        key: PRNGKey,
        dim: int,
        units: list[int],
        activation: str,
        alpha: float,
        param_dim: int = 0,
    ) -> None:
        r"""Diffusion matrix network based on a multi-layer perceptron.

        This implements the diffusion matrix function
        $$
            \sigma(x, args) = \text{Chol}\left(\alpha I + A(x, args) A(x, args)^\top\right)
        $$
        where $\text{Chol}$ denotes the Cholesky factor and $A(x, args)$ is the output of
        an MLP mapping $(x, args)$ of dimension `dim` + `param_dim` to a matrix of size
        `dim` x `dim`. The product $A A^\top$ together with the $\alpha I$ shift ensures
        that the argument of the Cholesky decomposition is symmetric positive definite
        for any $\alpha > 0$.

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            units (list[int]): layer sizes
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            alpha (float): regulariser
            param_dim (int, optional): dimension of the parameters. Defaults to 0.
        """
        self.dim = dim
        self.param_dim = param_dim
        units = units + [dim * dim]
        super().__init__(key, dim + param_dim, units, activation)
        self.alpha = alpha

    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
        if self.param_dim > 0:
            x = jnp.concatenate([x, args[1:]], axis=0)
        sigma = super().__call__(x).reshape(self.dim, self.dim)
        sigma_squared_regularised = self.alpha * jnp.eye(self.dim) + sigma @ sigma.T
        return jnp.linalg.cholesky(sigma_squared_regularised)


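
# Example (illustrative sketch): the returned matrix is the lower-triangular
# Cholesky factor, so sigma @ sigma.T recovers the regularised covariance.
# The key, sizes and inputs below are arbitrary.
def _example_diffusion_matrix() -> None:
    key = jax.random.PRNGKey(0)
    sigma_net = DiffusionMLP(key, dim=3, units=[16, 16], activation="tanh", alpha=0.05)
    sigma = sigma_net(jnp.ones(3), jnp.zeros(1))  # lower-triangular, shape (3, 3)
    covariance = sigma @ sigma.T  # symmetric positive definite
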
class DiffusionDiagonalMLP(MLP):
    """Diagonal diffusion matrix network based on a multi-layer perceptron."""

    alpha: float
    dim: int
    param_dim: int

    def __init__(
        self,
        key: PRNGKey,
        dim: int,
        units: list[int],
        activation: str,
        alpha: float,
        param_dim: int = 0,
    ) -> None:
        r"""Diagonal diffusion matrix network based on a multi-layer perceptron.

        This implements the diffusion matrix function
        $$
            \sigma(x, args) = \text{diag}(\alpha + \text{MLP}(x, args)^2)^{\frac{1}{2}}.
        $$
        Here, MLP maps $(x, args)$ of dimension `dim` + `param_dim` to a vector of size `dim`.

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            units (list[int]): layer sizes
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            alpha (float): regulariser
            param_dim (int, optional): dimension of the parameters. Defaults to 0.
        """
        self.dim = dim
        self.param_dim = param_dim
        units = units + [dim]
        super().__init__(key, dim + param_dim, units, activation)
        self.alpha = alpha

    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
        if self.param_dim > 0:
            x = jnp.concatenate([x, args[1:]], axis=0)
        sigma_diag = super().__call__(x)
        sigma_diag_regularised = jnp.sqrt(self.alpha + sigma_diag**2)
        return jnp.diag(sigma_diag_regularised)


class DiffusionDiagonalConstant(eqx.Module):
    """Diagonal diffusion matrix network based on a constant layer."""

    alpha: float
    constant_layer: ConstantLayer
    dim: int

    def __init__(self, key: PRNGKey, dim: int, alpha: float) -> None:
        r"""Diagonal diffusion matrix network based on a constant layer.

        This implements a constant diffusion matrix function
        $$
            \sigma(x, args) = \text{diag}(\alpha + \text{Constant}^2)^{\frac{1}{2}}
        $$
        where $\text{Constant}$ is a vector of size `dim`.
        Note that by constant we mean that it does not depend on the input $x$
        or the parameters $args$, but it can be trained.

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            alpha (float): regulariser
        """
        self.dim = dim
        self.alpha = alpha
        self.constant_layer = ConstantLayer(dim, key)

    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
        sigma_diag = self.constant_layer()
        sigma_diag_regularised = jnp.sqrt(self.alpha + sigma_diag**2)
        return jnp.diag(sigma_diag_regularised)


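
# Example (illustrative sketch): both diagonal variants return a (dim, dim)
# diagonal matrix; the constant variant ignores x and args but remains trainable.
# The key and inputs below are arbitrary.
def _example_diagonal_diffusion() -> None:
    key = jax.random.PRNGKey(0)
    sigma_mlp = DiffusionDiagonalMLP(key, dim=3, units=[16], activation="tanh", alpha=0.05)
    sigma_const = DiffusionDiagonalConstant(key, dim=3, alpha=0.05)
    x, args = jnp.ones(3), jnp.zeros(1)
    d1 = sigma_mlp(x, args)  # diagonal (3, 3) matrix, state-dependent
    d2 = sigma_const(x, args)  # diagonal (3, 3) matrix, state-independent
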
# ------------------------------------------------------------------ #
#                      Dimensionality transforms                     #
# ------------------------------------------------------------------ #


class PCATransform(eqx.Module):
    """PCA transform."""

    mean: ArrayLike
    components: ArrayLike
    scaling: ArrayLike
    centre: bool

    def __init__(
        self,
        mean: ArrayLike,
        components: ArrayLike,
        scaling: ArrayLike,
        centre: bool = False,
    ) -> None:
        r"""PCA transform.

        Transforms the input vector $x$ via
        $$
            x \mapsto \text{components} ( x ) / \sqrt{\text{scaling}}.
        $$
        If `centre` is set to `True`, then the input is first centered by subtracting the `mean`.

        Args:
            mean (ArrayLike): mean of the data used to fit the PCA
            components (ArrayLike): PCA components
            scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance)
            centre (bool, optional): whether to center the data using `mean`. Defaults to False.
        """
        self.mean = jnp.array(mean)
        self.components = jnp.array(components)
        self.scaling = jnp.array(scaling)
        self.centre = centre

    def __call__(self, x: ArrayLike) -> Array:
        if self.centre:
            x = x - self.mean
        x_reduced = self.components @ x
        return x_reduced / jnp.sqrt(self.scaling)


class InversePCATransform(PCATransform):
    """Inverse PCA transform."""

    def __init__(
        self, mean: ArrayLike, components: ArrayLike, scaling: ArrayLike
    ) -> None:
        r"""Inverse PCA transform.

        Transforms the input vector $z$ via
        $$
            z \mapsto \text{components}^\top ( z \sqrt{\text{scaling}} ).
        $$
        If `centre` is set to `True`, `mean` is added to the output;
        note that this constructor does not expose `centre`,
        so it keeps the parent default of `False`.

        Args:
            mean (ArrayLike): mean of the data used to fit the PCA
            components (ArrayLike): PCA components
            scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance)
        """
        super().__init__(mean, components, scaling)

    def __call__(self, z: ArrayLike) -> Array:
        x_reconstructed = self.components.T @ (jnp.sqrt(self.scaling) * z)
        if self.centre:
            x_reconstructed = x_reconstructed + self.mean
        return x_reconstructed


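
# Example (illustrative sketch): encoding and decoding with the PCA pair.  The
# `components` and `scaling` arrays below are placeholders with the right shapes;
# in practice they would come from a PCA fitted to data (rows of `components`
# are the principal directions).
def _example_pca_round_trip() -> None:
    full_dim, reduced_dim = 5, 2
    mean = jnp.zeros(full_dim)
    components = jnp.eye(reduced_dim, full_dim)  # (reduced_dim, full_dim)
    scaling = jnp.ones(reduced_dim)
    encode = PCATransform(mean, components, scaling)
    decode = InversePCATransform(mean, components, scaling)
    x = jnp.arange(5.0)
    z = encode(x)  # reduced coordinates, shape (2,)
    x_rec = decode(z)  # reconstruction in the original space, shape (5,)
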
class PCAResNetTransform(PCATransform):
    """PCA-ResNet transform."""

    mlp: MLP
    mlp_scale: float
    mlp_input_scale: float

    def __init__(
        self,
        mean: ArrayLike,
        components: ArrayLike,
        scaling: ArrayLike,
        key: PRNGKey,
        units: list[int],
        activation: str,
        mlp_scale: float,
        mlp_input_scale: float,
    ) -> None:
        r"""PCA-ResNet transform.

        This combines the PCA transform with an MLP to give a ResNet-like architecture.
        $$
            x \mapsto \text{PCA}(x) + \text{mlp\_scale} \times \text{MLP}(\text{mlp\_input\_scale} \times x).
        $$
        The input scale adjusts the input (often not order 1) so that the MLP can learn the correct scale.

        Args:
            mean (ArrayLike): mean of the data used to fit the PCA
            components (ArrayLike): PCA components
            scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance)
            key (PRNGKey): random key
            units (list[int]): layer sizes of the MLP
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            mlp_scale (float): scale of the MLP output
            mlp_input_scale (float): scale of the input to the MLP
        """
        super().__init__(mean, components, scaling)
        units = units + [components.shape[0]]
        dim = components.shape[1]
        self.mlp = MLP(key, dim, units, activation)
        self.mlp_scale = mlp_scale
        self.mlp_input_scale = mlp_input_scale

    def __call__(self, x: ArrayLike) -> Array:
        pca_features = super().__call__(x)
        mlp_features = self.mlp(self.mlp_input_scale * x)
        return pca_features + self.mlp_scale * mlp_features

    def pca_transform(self, x: ArrayLike) -> Array:
        """Perform the PCA transform.

        Args:
            x (ArrayLike): state

        Returns:
            Array: pca features
        """
        return super().__call__(x)


class InversePCAResNetTransform(InversePCATransform):
    """Inverse PCA-ResNet transform."""

    mlp: MLP
    mlp_scale: float

    def __init__(
        self,
        mean: ArrayLike,
        components: ArrayLike,
        scaling: ArrayLike,
        key: PRNGKey,
        units: list[int],
        activation: str,
        mlp_scale: float,
    ) -> None:
        r"""Inverse PCA-ResNet transform.

        This combines the inverse PCA transform with an MLP to give a ResNet-like architecture.
        $$
            z \mapsto \text{PCA}^{-1}(z) + \text{mlp\_scale} \times \text{MLP}(z).
        $$

        Args:
            mean (ArrayLike): mean of the data used to fit the PCA
            components (ArrayLike): PCA components
            scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance)
            key (PRNGKey): random key
            units (list[int]): layer sizes of the MLP
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            mlp_scale (float): scale of the MLP output
        """
        super().__init__(mean, components, scaling)
        units = units + [components.shape[1]]
        dim = components.shape[0]
        self.mlp = MLP(key, dim, units, activation)
        self.mlp_scale = mlp_scale

    def __call__(self, z: ArrayLike) -> Array:
        inverse_pca_recon = super().__call__(z)
        mlp_recon = self.mlp(z)
        return inverse_pca_recon + self.mlp_scale * mlp_recon

    def inverse_pca_transform(self, z: ArrayLike) -> Array:
        """Inverse PCA transform.

        Args:
            z (ArrayLike): reduced state

        Returns:
            Array: reconstructed state
        """
        return super().__call__(z)
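
# Example (illustrative sketch): pairing the PCA-ResNet encoder and decoder.  As
# above, the `components` and `scaling` arrays are placeholders with the right
# shapes, and the keys, layer sizes and scales are arbitrary.
def _example_pca_resnet_round_trip() -> None:
    full_dim, reduced_dim = 5, 2
    mean = jnp.zeros(full_dim)
    components = jnp.eye(reduced_dim, full_dim)
    scaling = jnp.ones(reduced_dim)
    enc_key, dec_key = jax.random.split(jax.random.PRNGKey(0))
    encoder = PCAResNetTransform(
        mean, components, scaling, enc_key, units=[16], activation="tanh",
        mlp_scale=0.1, mlp_input_scale=1.0,
    )
    decoder = InversePCAResNetTransform(
        mean, components, scaling, dec_key, units=[16], activation="tanh", mlp_scale=0.1
    )
    x = jnp.arange(5.0)
    z = encoder(x)  # learned reduced coordinates, shape (2,)
    x_rec = decoder(z)  # learned reconstruction, shape (5,)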
class MLP(equinox._module.Module):
37class MLP(eqx.Module):
38    """Multi-layer perceptron."""
39
40    layers: list[eqx.nn.Linear]
41    activation: Callable[[ArrayLike], Array]
42
43    def __init__(
44        self, key: PRNGKey, dim: int, units: list[int], activation: str
45    ) -> None:
46        r"""Multi-layer perceptron.
47
48        Example:
49        `mlp = MLP(key=jax.random.PRNGKey(0), dim=2, units=[32, 32, 1], activation='tanh')`
50        gives a 2-hidden-layer MLP
51
52        $$2 \to 32 \to 32 \to 1$$
53
54        with tanh activation.
55
56        Args:
57            key (PRNGKey): random key
58            dim (int): dimension of the input
59            units (list[int]): layer sizes
60            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
61        """
62        num_layers = len(units)
63        units = [dim] + units
64        keys = jax.random.split(key, num_layers)
65        self.layers = [
66            eqx.nn.Linear(units[i], units[i + 1], key=key) for i, key in enumerate(keys)
67        ]
68        self.activation = get_activation(activation)
69
70    def __call__(self, x: ArrayLike) -> Array:
71        h = x
72        for layer in self.layers[:-1]:
73            h = self.activation(layer(h))
74        output = self.layers[-1](h)
75        return output

Multi-layer perceptron.

MLP(key: <function PRNGKey>, dim: int, units: list[int], activation: str)
43    def __init__(
44        self, key: PRNGKey, dim: int, units: list[int], activation: str
45    ) -> None:
46        r"""Multi-layer perceptron.
47
48        Example:
49        `mlp = MLP(key=jax.random.PRNGKey(0), dim=2, units=[32, 32, 1], activation='tanh')`
50        gives a 2-hidden-layer MLP
51
52        $$2 \to 32 \to 32 \to 1$$
53
54        with tanh activation.
55
56        Args:
57            key (PRNGKey): random key
58            dim (int): dimension of the input
59            units (list[int]): layer sizes
60            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
61        """
62        num_layers = len(units)
63        units = [dim] + units
64        keys = jax.random.split(key, num_layers)
65        self.layers = [
66            eqx.nn.Linear(units[i], units[i + 1], key=key) for i, key in enumerate(keys)
67        ]
68        self.activation = get_activation(activation)

Multi-layer perceptron.

Example: mlp = MLP(key=jax.random.PRNGKey(0), dim=2, units=[32, 32, 1], activation='tanh') gives a 2-hidden-layer MLP

$$2 \to 32 \to 32 \to 1$$

with tanh activation.

Arguments:
  • key (PRNGKey): random key
  • dim (int): dimension of the input
  • units (list[int]): layer sizes
  • activation (str): activation function (can be any in jax.nn or custom ones defined in onsagernet._activations)
layers: list[equinox.nn._linear.Linear]
activation: Callable[[Union[jax.Array, numpy.ndarray, numpy.bool, numpy.number, bool, int, float, complex]], jax.Array]
class PotentialMLP(MLP):
 83class PotentialMLP(MLP):
 84    """Potential network based on a multi-layer perceptron."""
 85
 86    alpha: float
 87    dim: int
 88    param_dim: int
 89
 90    def __init__(
 91        self,
 92        key: PRNGKey,
 93        dim: int,
 94        units: list[int],
 95        activation: str,
 96        alpha: float,
 97        param_dim: int = 0,
 98    ) -> None:
 99        r"""Potential network based on a multi-layer perceptron.
100
101        This implements the potential function
102        $$
103            V(x, args) = \alpha \|(x, args)\|^2 + \text{MLP}(x, args)
104        $$
105        where $x$ is the input and $u$ are additional parameters.
106        The constant $\alpha \geq 0$ is a regularisation term,
107        which gives a quadratic growth to ensure that the potential is integrable.
108        We are tacitly assuming that MLP is of sub-quadratic growth,
109        so only choose activation functions that have this property
110        (most activation functions are either bounded or of linear growth).
111
112        Args:
113            key (PRNGKey): random key
114            dim (int): dimension of the input
115            units (list[int]): layer sizes
116            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
117            alpha (float): regulariser
118            param_dim (int, optional): dimensions of the parameters. Defaults to 0.
119        """
120        self.dim = dim
121        units = units + [1]
122        self.param_dim = param_dim
123        super().__init__(key, dim + param_dim, units, activation)
124        self.alpha = alpha
125
126    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
127        if self.param_dim > 0:
128            x = jnp.concatenate([x, args[1:]], axis=0)
129        output = super().__call__(x) + self.alpha * (x @ x)
130        return jnp.squeeze(output)

Potential network based on a multi-layer perceptron.

PotentialMLP( key: <function PRNGKey>, dim: int, units: list[int], activation: str, alpha: float, param_dim: int = 0)
 90    def __init__(
 91        self,
 92        key: PRNGKey,
 93        dim: int,
 94        units: list[int],
 95        activation: str,
 96        alpha: float,
 97        param_dim: int = 0,
 98    ) -> None:
 99        r"""Potential network based on a multi-layer perceptron.
100
101        This implements the potential function
102        $$
103            V(x, args) = \alpha \|(x, args)\|^2 + \text{MLP}(x, args)
104        $$
105        where $x$ is the input and $u$ are additional parameters.
106        The constant $\alpha \geq 0$ is a regularisation term,
107        which gives a quadratic growth to ensure that the potential is integrable.
108        We are tacitly assuming that MLP is of sub-quadratic growth,
109        so only choose activation functions that have this property
110        (most activation functions are either bounded or of linear growth).
111
112        Args:
113            key (PRNGKey): random key
114            dim (int): dimension of the input
115            units (list[int]): layer sizes
116            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
117            alpha (float): regulariser
118            param_dim (int, optional): dimensions of the parameters. Defaults to 0.
119        """
120        self.dim = dim
121        units = units + [1]
122        self.param_dim = param_dim
123        super().__init__(key, dim + param_dim, units, activation)
124        self.alpha = alpha

Potential network based on a multi-layer perceptron.

This implements the potential function $$ V(x, args) = \alpha \|(x, args)\|^2 + \text{MLP}(x, args) $$ where $x$ is the input and $u$ are additional parameters. The constant $\alpha \geq 0$ is a regularisation term, which gives a quadratic growth to ensure that the potential is integrable. We are tacitly assuming that MLP is of sub-quadratic growth, so only choose activation functions that have this property (most activation functions are either bounded or of linear growth).

Arguments:
  • key (PRNGKey): random key
  • dim (int): dimension of the input
  • units (list[int]): layer sizes
  • activation (str): activation function (can be any in jax.nn or custom ones defined in onsagernet._activations)
  • alpha (float): regulariser
  • param_dim (int, optional): dimensions of the parameters. Defaults to 0.
alpha: float
dim: int
param_dim: int
Inherited Members
MLP
layers
activation
class PotentialResMLP(MLP):
133class PotentialResMLP(MLP):
134    r"""Potential network with a residual connection."""
135
136    alpha: float
137    gamma_layer: eqx.nn.Linear
138    dim: int
139    param_dim: int
140
141    def __init__(
142        self,
143        key: PRNGKey,
144        dim: int,
145        units: list[int],
146        activation: str,
147        n_pot: int,
148        alpha: float,
149        param_dim: int = 0,
150    ) -> None:
151        r"""Potential network with a residual connection.
152
153        This implements the modified potential function
154        $$
155            V(x, args) = \alpha \|(x, args)\|^2
156            + \frac{1}{2}
157            \| \text{MLP}(x, args)+ \Gamma (x, args) \|^2
158        $$
159        where
160
161        - $\phi$ is a MLP of dim + param_dim -> n_pot
162        - $\Gamma$ ia matrix of size [n_pot, dim + para_dim]
163        - $\alpha > 0$ is a scalar regulariser
164
165        Args:
166            key (PRNGKey): random key
167            dim (int): dimension of the input
168            units (list[int]): layer sizes
169            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
170            n_pot (int): size of the MLP part of the potential
171            alpha (float): regulariser
172            param_dim (int, optional): dimension of the parameters. Defaults to 0.
173        """
174        self.dim = dim
175        self.param_dim = param_dim
176        units = units + [n_pot]
177        mlp_key, gamma_key = jax.random.split(key)
178        super().__init__(mlp_key, dim + param_dim, units, activation)
179        self.alpha = alpha
180        self.gamma_layer = eqx.nn.Linear(
181            dim + param_dim, n_pot, key=gamma_key, use_bias=False
182        )
183
184    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
185        if self.param_dim > 0:
186            x = jnp.concatenate([x, args[1:]], axis=0)
187        output_phi = super().__call__(x)
188        output_gamma = self.gamma_layer(x)
189        output_combined = (output_phi + output_gamma) @ (output_phi + output_gamma)
190        regularisation = self.alpha * (x @ x)
191        return 0.5 * output_combined + regularisation

Potential network with a residual connection.

PotentialResMLP( key: <function PRNGKey>, dim: int, units: list[int], activation: str, n_pot: int, alpha: float, param_dim: int = 0)
141    def __init__(
142        self,
143        key: PRNGKey,
144        dim: int,
145        units: list[int],
146        activation: str,
147        n_pot: int,
148        alpha: float,
149        param_dim: int = 0,
150    ) -> None:
151        r"""Potential network with a residual connection.
152
153        This implements the modified potential function
154        $$
155            V(x, args) = \alpha \|(x, args)\|^2
156            + \frac{1}{2}
157            \| \text{MLP}(x, args)+ \Gamma (x, args) \|^2
158        $$
159        where
160
161        - $\phi$ is a MLP of dim + param_dim -> n_pot
162        - $\Gamma$ ia matrix of size [n_pot, dim + para_dim]
163        - $\alpha > 0$ is a scalar regulariser
164
165        Args:
166            key (PRNGKey): random key
167            dim (int): dimension of the input
168            units (list[int]): layer sizes
169            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
170            n_pot (int): size of the MLP part of the potential
171            alpha (float): regulariser
172            param_dim (int, optional): dimension of the parameters. Defaults to 0.
173        """
174        self.dim = dim
175        self.param_dim = param_dim
176        units = units + [n_pot]
177        mlp_key, gamma_key = jax.random.split(key)
178        super().__init__(mlp_key, dim + param_dim, units, activation)
179        self.alpha = alpha
180        self.gamma_layer = eqx.nn.Linear(
181            dim + param_dim, n_pot, key=gamma_key, use_bias=False
182        )

Potential network with a residual connection.

This implements the modified potential function $$ V(x, args) = \alpha \|(x, args)\|^2 + \frac{1}{2} \| \text{MLP}(x, args)+ \Gamma (x, args) \|^2 $$ where

  • $\phi$ is a MLP of dim + param_dim -> n_pot
  • $\Gamma$ ia matrix of size [n_pot, dim + para_dim]
  • $\alpha > 0$ is a scalar regulariser
Arguments:
  • key (PRNGKey): random key
  • dim (int): dimension of the input
  • units (list[int]): layer sizes
  • activation (str): activation function (can be any in jax.nn or custom ones defined in onsagernet._activations)
  • n_pot (int): size of the MLP part of the potential
  • alpha (float): regulariser
  • param_dim (int, optional): dimension of the parameters. Defaults to 0.
alpha: float
gamma_layer: equinox.nn._linear.Linear
dim: int
param_dim: int
Inherited Members
MLP
layers
activation
class DissipationMatrixMLP(MLP):
199class DissipationMatrixMLP(MLP):
200    """Dissipation matrix network based on a multi-layer perceptron."""
201
202    alpha: float
203    is_bounded: bool
204    dim: int
205
206    def __init__(
207        self,
208        key: PRNGKey,
209        dim: int,
210        units: list[int],
211        activation: str,
212        alpha: float,
213        is_bounded: bool = True,
214    ) -> None:
215        r"""Dissipation matrix network based on a multi-layer perceptron.
216
217        The MLP maps $x$ of dimension `dim` to a matrix $L(x)$ of size `dim` x `dim`,
218        and then reshapes it to a `dim` x `dim` matrix.
219        Then, the output matrix is given by
220        $$
221            M(x) = \alpha I + L(x) L(x)^\top.
222        $$
223        If `is_bounded` is set to `True`, then the output is element-wise bounded
224        by applying a `jax.nn.tanh` activation to the output matrix $L$.
225
226        Args:
227            key (PRNGKey): random key
228            dim (int): dimension of the input
229            units (list[int]): layer sizes
230            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
231            alpha (float): regulariser
232            is_bounded (bool, optional): whether to give a element-wise bounded output. Defaults to True.
233        """
234        self.dim = dim
235        units = units + [dim * dim]
236        super().__init__(key, dim, units, activation)
237        self.alpha = alpha
238        self.is_bounded = is_bounded
239
240    def __call__(self, x: ArrayLike) -> Array:
241        L = super().__call__(x).reshape(self.dim, self.dim)
242        if self.is_bounded:
243            L = jax.nn.tanh(L)
244        return self.alpha * jnp.eye(self.dim) + L @ L.T

Dissipation matrix network based on a multi-layer perceptron.

DissipationMatrixMLP( key: <function PRNGKey>, dim: int, units: list[int], activation: str, alpha: float, is_bounded: bool = True)
206    def __init__(
207        self,
208        key: PRNGKey,
209        dim: int,
210        units: list[int],
211        activation: str,
212        alpha: float,
213        is_bounded: bool = True,
214    ) -> None:
215        r"""Dissipation matrix network based on a multi-layer perceptron.
216
217        The MLP maps $x$ of dimension `dim` to a matrix $L(x)$ of size `dim` x `dim`,
218        and then reshapes it to a `dim` x `dim` matrix.
219        Then, the output matrix is given by
220        $$
221            M(x) = \alpha I + L(x) L(x)^\top.
222        $$
223        If `is_bounded` is set to `True`, then the output is element-wise bounded
224        by applying a `jax.nn.tanh` activation to the output matrix $L$.
225
226        Args:
227            key (PRNGKey): random key
228            dim (int): dimension of the input
229            units (list[int]): layer sizes
230            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
231            alpha (float): regulariser
232            is_bounded (bool, optional): whether to give a element-wise bounded output. Defaults to True.
233        """
234        self.dim = dim
235        units = units + [dim * dim]
236        super().__init__(key, dim, units, activation)
237        self.alpha = alpha
238        self.is_bounded = is_bounded

Dissipation matrix network based on a multi-layer perceptron.

The MLP maps $x$ of dimension dim to a matrix $L(x)$ of size dim x dim, and then reshapes it to a dim x dim matrix. Then, the output matrix is given by $$ M(x) = \alpha I + L(x) L(x)^\top. $$ If is_bounded is set to True, then the output is element-wise bounded by applying a jax.nn.tanh activation to the output matrix $L$.

Arguments:
  • key (PRNGKey): random key
  • dim (int): dimension of the input
  • units (list[int]): layer sizes
  • activation (str): activation function (can be any in jax.nn or custom ones defined in onsagernet._activations)
  • alpha (float): regulariser
  • is_bounded (bool, optional): whether to give a element-wise bounded output. Defaults to True.
alpha: float
is_bounded: bool
dim: int
Inherited Members
MLP
layers
activation
class ConservationMatrixMLP(MLP):
252class ConservationMatrixMLP(MLP):
253    """Conservation matrix network based on a multi-layer perceptron."""
254
255    is_bounded: bool
256    dim: int
257
258    def __init__(
259        self,
260        key: PRNGKey,
261        dim: int,
262        activation: str,
263        units: list[int],
264        is_bounded: bool = True,
265    ) -> None:
266        r"""Conservation matrix network based on a multi-layer perceptron.
267
268        The MLP maps $x$ of dimension `dim` to a matrix $L(x)$ of size `dim` x `dim`,
269        and then reshapes it to a `dim` x `dim` matrix.
270        Then, the output matrix is given by
271        $$
272            W(x) = L(x) - L(x)^\top.
273        $$
274        If `is_bounded` is set to `True`, then the output is element-wise bounded
275        by applying a `jax.nn.tanh` activation to the output matrix $L$.
276
277        Args:
278            key (PRNGKey): random key
279            dim (int): dimension of the input
280            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
281            units (list[int]): layer sizes
282            is_bounded (bool, optional): whether to give a element-wise bounded output. Defaults to True.
283        """
284        self.dim = dim
285        units = units + [dim * dim]
286        super().__init__(key, dim, units, activation)
287        self.is_bounded = is_bounded
288
289    def __call__(self, x: ArrayLike) -> Array:
290        L = super().__call__(x).reshape(self.dim, self.dim)
291        if self.is_bounded:
292            L = jax.nn.tanh(L)
293        return L - L.T

Conservation matrix network based on a multi-layer perceptron.

ConservationMatrixMLP( key: <function PRNGKey>, dim: int, activation: str, units: list[int], is_bounded: bool = True)
258    def __init__(
259        self,
260        key: PRNGKey,
261        dim: int,
262        activation: str,
263        units: list[int],
264        is_bounded: bool = True,
265    ) -> None:
266        r"""Conservation matrix network based on a multi-layer perceptron.
267
268        The MLP maps $x$ of dimension `dim` to a matrix $L(x)$ of size `dim` x `dim`,
269        and then reshapes it to a `dim` x `dim` matrix.
270        Then, the output matrix is given by
271        $$
272            W(x) = L(x) - L(x)^\top.
273        $$
274        If `is_bounded` is set to `True`, then the output is element-wise bounded
275        by applying a `jax.nn.tanh` activation to the output matrix $L$.
276
277        Args:
278            key (PRNGKey): random key
279            dim (int): dimension of the input
280            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
281            units (list[int]): layer sizes
282            is_bounded (bool, optional): whether to give a element-wise bounded output. Defaults to True.
283        """
284        self.dim = dim
285        units = units + [dim * dim]
286        super().__init__(key, dim, units, activation)
287        self.is_bounded = is_bounded

Conservation matrix network based on a multi-layer perceptron.

The MLP maps $x$ of dimension dim to a matrix $L(x)$ of size dim x dim, and then reshapes it to a dim x dim matrix. Then, the output matrix is given by $$ W(x) = L(x) - L(x)^\top. $$ If is_bounded is set to True, then the output is element-wise bounded by applying a jax.nn.tanh activation to the output matrix $L$.

Arguments:
  • key (PRNGKey): random key
  • dim (int): dimension of the input
  • activation (str): activation function (can be any in jax.nn or custom ones defined in onsagernet._activations)
  • units (list[int]): layer sizes
  • is_bounded (bool, optional): whether to give a element-wise bounded output. Defaults to True.
is_bounded: bool
dim: int
Inherited Members
MLP
layers
activation
class DiffusionMLP(MLP):
301class DiffusionMLP(MLP):
302    """Diffusion matrix network based on a multi-layer perceptron."""
303
304    alpha: float
305    dim: int
306    param_dim: int
307
308    def __init__(
309        self,
310        key: PRNGKey,
311        dim: int,
312        units: list[int],
313        activation: str,
314        alpha: float,
315        param_dim: int = 0,
316    ) -> None:
317        r"""Diffusion matrix network based on a multi-layer perceptron.
318
319        This implements the diffusion matrix function
320        $$
321            \sigma(x, args) = \text{Chol}(\alpha I + \text{MLP}(x, args))
322        $$
323        where $\text{Chol}$ is the Cholesky decomposition.
324        Here, MLP maps $(x, args)$ of dimension `dim` + `param_dim` to a matrix of size `dim` x `dim`,
325
326        Args:
327            key (PRNGKey): random key
328            dim (int): dimension of the input
329            units (list[int]): layer sizes
330            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
331            alpha (float): regulariser
332            param_dim (int, optional): dimension of the parameters. Defaults to 0.
333        """
334        self.dim = dim
335        self.param_dim = param_dim
336        units = units + [dim * dim]
337        super().__init__(key, dim + param_dim, units, activation)
338        self.alpha = alpha
339
340    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
341        if self.param_dim > 0:
342            x = jnp.concatenate([x, args[1:]], axis=0)
343        sigma = super().__call__(x).reshape(self.dim, self.dim)
344        sigma_squared_regularised = self.alpha * jnp.eye(self.dim) + sigma @ sigma.T
345        return jnp.linalg.cholesky(sigma_squared_regularised)

Diffusion matrix network based on a multi-layer perceptron.

DiffusionMLP( key: <function PRNGKey>, dim: int, units: list[int], activation: str, alpha: float, param_dim: int = 0)
308    def __init__(
309        self,
310        key: PRNGKey,
311        dim: int,
312        units: list[int],
313        activation: str,
314        alpha: float,
315        param_dim: int = 0,
316    ) -> None:
317        r"""Diffusion matrix network based on a multi-layer perceptron.
318
319        This implements the diffusion matrix function
320        $$
321            \sigma(x, args) = \text{Chol}(\alpha I + \text{MLP}(x, args))
322        $$
323        where $\text{Chol}$ is the Cholesky decomposition.
324        Here, MLP maps $(x, args)$ of dimension `dim` + `param_dim` to a matrix of size `dim` x `dim`,
325
326        Args:
327            key (PRNGKey): random key
328            dim (int): dimension of the input
329            units (list[int]): layer sizes
330            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
331            alpha (float): regulariser
332            param_dim (int, optional): dimension of the parameters. Defaults to 0.
333        """
334        self.dim = dim
335        self.param_dim = param_dim
336        units = units + [dim * dim]
337        super().__init__(key, dim + param_dim, units, activation)
338        self.alpha = alpha

Diffusion matrix network based on a multi-layer perceptron.

This implements the diffusion matrix function $$ \sigma(x, args) = \text{Chol}(\alpha I + \text{MLP}(x, args)) $$ where $\text{Chol}$ is the Cholesky decomposition. Here, MLP maps $(x, args)$ of dimension dim + param_dim to a matrix of size dim x dim,

Arguments:
  • key (PRNGKey): random key
  • dim (int): dimension of the input
  • units (list[int]): layer sizes
  • activation (str): activation function (can be any in jax.nn or custom ones defined in onsagernet._activations)
  • alpha (float): regulariser
  • param_dim (int, optional): dimension of the parameters. Defaults to 0.
alpha: float
dim: int
param_dim: int
Inherited Members
MLP
layers
activation
class DiffusionDiagonalMLP(MLP):
348class DiffusionDiagonalMLP(MLP):
349    """Diagonal diffusion matrix network based on a multi-layer perceptron."""
350
351    alpha: float
352    dim: int
353    param_dim: int
354
355    def __init__(
356        self,
357        key: PRNGKey,
358        dim: int,
359        units: list[int],
360        activation: str,
361        alpha: float,
362        param_dim: int = 0,
363    ) -> None:
364        r"""Diagonal diffusion matrix network based on a multi-layer perceptron.
365
366        This implements the diffusion matrix function
367        $$
368            \sigma(x, args) = \text{diag}(\alpha + \text{MLP}(x, args)^2)^{\frac{1}{2}}.
369        $$
370        Here, MLP maps $(x, args)$ of dimension `dim` + `param_dim` to a vector of size `dim`.
371
372        Args:
373            key (PRNGKey): random key
374            dim (int): dimension of the input
375            units (list[int]): layer sizes
376            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
377            alpha (float): regulariser
378            param_dim (int, optional): dimension of the parameters. Defaults to 0.
379        """
380        self.dim = dim
381        self.param_dim = param_dim
382        units = units + [dim]
383        super().__init__(key, dim + param_dim, units, activation)
384        self.alpha = alpha
385
386    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
387        if self.param_dim > 0:
388            x = jnp.concatenate([x, args[1:]], axis=0)
389        sigma_diag = super().__call__(x)
390        sigma_diag_regularised = jnp.sqrt(self.alpha + sigma_diag**2)
391        return jnp.diag(sigma_diag_regularised)

Diagonal diffusion matrix network based on a multi-layer perceptron.

DiffusionDiagonalMLP( key: <function PRNGKey>, dim: int, units: list[int], activation: str, alpha: float, param_dim: int = 0)
355    def __init__(
356        self,
357        key: PRNGKey,
358        dim: int,
359        units: list[int],
360        activation: str,
361        alpha: float,
362        param_dim: int = 0,
363    ) -> None:
364        r"""Diagonal diffusion matrix network based on a multi-layer perceptron.
365
366        This implements the diffusion matrix function
367        $$
368            \sigma(x, args) = \text{diag}(\alpha + \text{MLP}(x, args)^2)^{\frac{1}{2}}.
369        $$
370        Here, MLP maps $(x, args)$ of dimension `dim` + `param_dim` to a vector of size `dim`.
371
372        Args:
373            key (PRNGKey): random key
374            dim (int): dimension of the input
375            units (list[int]): layer sizes
376            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
377            alpha (float): regulariser
378            param_dim (int, optional): dimension of the parameters. Defaults to 0.
379        """
380        self.dim = dim
381        self.param_dim = param_dim
382        units = units + [dim]
383        super().__init__(key, dim + param_dim, units, activation)
384        self.alpha = alpha

Diagonal diffusion matrix network based on a multi-layer perceptron.

This implements the diffusion matrix function $$ \sigma(x, args) = \text{diag}(\alpha + \text{MLP}(x, args)^2)^{\frac{1}{2}}. $$ Here, MLP maps $(x, args)$ of dimension dim + param_dim to a vector of size dim.

Arguments:
  • key (PRNGKey): random key
  • dim (int): dimension of the input
  • units (list[int]): layer sizes
  • activation (str): activation function (can be any in jax.nn or custom ones defined in onsagernet._activations)
  • alpha (float): regulariser
  • param_dim (int, optional): dimension of the parameters. Defaults to 0.
alpha: float
dim: int
param_dim: int
Inherited Members
MLP
layers
activation
class DiffusionDiagonalConstant(equinox._module.Module):
394class DiffusionDiagonalConstant(eqx.Module):
395    """Diagonal diffusion matrix network based on a constant layer."""
396
397    alpha: float
398    constant_layer: ConstantLayer
399    dim: int
400
401    def __init__(self, key: PRNGKey, dim: int, alpha: float) -> None:
402        r"""Diagonal diffusion matrix network based on a constant layer.
403
404        This implements the diffusion matrix function that is constant
405        $$
406            \sigma(x, args) = \text{diag}(\alpha + \text{Constant}^2)^{\frac{1}{2}}.
407        $$
408        where $\text{Constant}$ is a vector of size `dim`.
409        Note that by constant we mean that it does not depend on the input $x$ or the parameters $args$,
410        but it can be trained.
411
412        Args:
413            key (PRNGKey): random key
414            dim (int): dimension of the input
415            alpha (float): regulariser
416        """
417        self.dim = dim
418        self.alpha = alpha
419        self.constant_layer = ConstantLayer(dim, key)
420
421    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
422        sigma_diag = self.constant_layer()
423        sigma_squared_regularised = jnp.sqrt(self.alpha + sigma_diag**2)
424        return jnp.diag(sigma_squared_regularised)

Diagonal diffusion matrix network based on a constant layer.

DiffusionDiagonalConstant(key: <function PRNGKey>, dim: int, alpha: float)
401    def __init__(self, key: PRNGKey, dim: int, alpha: float) -> None:
402        r"""Diagonal diffusion matrix network based on a constant layer.
403
404        This implements the diffusion matrix function that is constant
405        $$
406            \sigma(x, args) = \text{diag}(\alpha + \text{Constant}^2)^{\frac{1}{2}}.
407        $$
408        where $\text{Constant}$ is a vector of size `dim`.
409        Note that by constant we mean that it does not depend on the input $x$ or the parameters $args$,
410        but it can be trained.
411
412        Args:
413            key (PRNGKey): random key
414            dim (int): dimension of the input
415            alpha (float): regulariser
416        """
417        self.dim = dim
418        self.alpha = alpha
419        self.constant_layer = ConstantLayer(dim, key)

Diagonal diffusion matrix network based on a constant layer.

This implements the diffusion matrix function that is constant $$ \sigma(x, args) = \text{diag}(\alpha + \text{Constant}^2)^{\frac{1}{2}}. $$ where $\text{Constant}$ is a vector of size dim. Note that by constant we mean that it does not depend on the input $x$ or the parameters $args$, but it can be trained.

Arguments:
  • key (PRNGKey): random key
  • dim (int): dimension of the input
  • alpha (float): regulariser
alpha: float
dim: int
class PCATransform(equinox._module.Module):
432class PCATransform(eqx.Module):
433    """PCA transform."""
434
435    mean: ArrayLike
436    components: ArrayLike
437    scaling: ArrayLike
438    centre: bool
439
440    def __init__(
441        self,
442        mean: ArrayLike,
443        components: ArrayLike,
444        scaling: ArrayLike,
445        centre: bool = False,
446    ) -> None:
447        r"""PCA transform.
448
449        Transforms the input vector $x$ via
450        $$
451            x \mapsto \text{components} ( x ) / \sqrt{\text{scaling}}.
452        $$
453        If `centre` is set to `True`, then the input is first centered by subtracting the `mean`.
454
455        Args:
456            mean (ArrayLike): mean of the data used to fit the PCA
457            components (ArrayLike): PCA components
458            scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance)
459            centre (bool, optional): whether to center the data using `mean`. Defaults to False.
460        """
461        self.mean = jnp.array(mean)
462        self.components = jnp.array(components)
463        self.scaling = jnp.array(scaling)
464        self.centre = centre
465
466    def __call__(self, x: ArrayLike) -> Array:
467        if self.centre:
468            x = x - self.mean
469        x_reduced = self.components @ x
470        return x_reduced / jnp.sqrt(self.scaling)

473class InversePCATransform(PCATransform):
474    """Inverse PCA transform."""
475
476    def __init__(
477        self, mean: ArrayLike, components: ArrayLike, scaling: ArrayLike
478    ) -> None:
479        r"""Inverse PCA transform.
480
481        Transforms the input vector $z$ via
482        $$
483            z \mapsto \text{components}^\top ( z \sqrt{\text{scaling}} )
484        $$
485        If `centre` is set to `True`, `mean` is added to the output.
486
487        Args:
488            mean (ArrayLike): mean of the data used to fit the PCA
489            components (ArrayLike): PCA components
490            scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance)
491        """
492        super().__init__(mean, components, scaling)
493
494    def __call__(self, z: ArrayLike) -> Array:
495        x_reconstructed = self.components.T @ (jnp.sqrt(self.scaling) * z)
496        if self.centre:
497            x_reconstructed = x_reconstructed + self.mean
498        return x_reconstructed
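
Paired with `PCATransform`, this gives a simple encode/decode round trip. The arrays below are hypothetical; with orthonormal component rows, the reconstruction is the projection of $x$ onto their span.

```python
import jax.numpy as jnp
from onsagernet.models import PCATransform, InversePCATransform

# Hypothetical orthonormal components (rows) for a 2-component reduction of 5-dim data
mean = jnp.zeros(5)
components = jnp.eye(2, 5)
scaling = jnp.array([3.0, 1.5])

encode = PCATransform(mean, components, scaling)        # centre defaults to False
decode = InversePCATransform(mean, components, scaling)

x = jnp.arange(5.0)
x_recon = decode(encode(x))   # projection of x onto the span of the component rows
```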

501class PCAResNetTransform(PCATransform):
502    mlp: MLP
503    mlp_scale: float
504    mlp_input_scale: float
505
506    def __init__(
507        self,
508        mean: ArrayLike,
509        components: ArrayLike,
510        scaling: ArrayLike,
511        key: PRNGKey,
512        units: list[int],
513        activation: str,
514        mlp_scale: float,
515        mlp_input_scale: float,
516    ) -> None:
517        r"""PCA-ResNet transform.
518
519        This combines the PCA transform with an MLP to give a ResNet-like architecture.
520        $$
521            x \mapsto \text{PCA}(x) + \text{mlp\_scale} \times \text{MLP}(\text{mlp\_input\_scale} \times x).
522        $$
523        The input scale rescales the input (which is often not of order 1) so that the MLP can learn at an appropriate scale.
524
525        Args:
526            mean (ArrayLike): mean of the data used to fit the PCA
527            components (ArrayLike): PCA components
528            scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance)
529            key (PRNGKey): random key
530            units (list[int]): layer sizes of the MLP
531            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
532            mlp_scale (float): scale of the MLP output
533            mlp_input_scale (float): scale of the input to the MLP
534        """
535        super().__init__(mean, components, scaling)
536        units = units + [components.shape[0]]
537        dim = components.shape[1]
538        self.mlp = MLP(key, dim, units, activation)
539        self.mlp_scale = mlp_scale
540        self.mlp_input_scale = mlp_input_scale
541
542    def __call__(self, x: ArrayLike) -> Array:
543        pca_features = super().__call__(x)
544        mlp_features = self.mlp(self.mlp_input_scale * x)
545        return pca_features + self.mlp_scale * mlp_features
546
547    def pca_transform(self, x: ArrayLike) -> Array:
548        """Perform the PCA transform.
549
550        Args:
551            x (ArrayLike): state
552
553        Returns:
554            Array: pca features
555        """
556        return super().__call__(x)
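
A construction sketch for `PCAResNetTransform`; the PCA arrays, layer sizes, and scale values are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp
from onsagernet.models import PCAResNetTransform

# Hypothetical PCA quantities: reduce 5-dimensional states to 2 components
mean = jnp.zeros(5)
components = jnp.eye(2, 5)
scaling = jnp.array([3.0, 1.5])

transform = PCAResNetTransform(
    mean, components, scaling,
    key=jax.random.PRNGKey(0),
    units=[32, 32],        # hidden layers; the output layer (size 2) is appended internally
    activation="tanh",
    mlp_scale=0.1,         # example value: keeps the learned correction small
    mlp_input_scale=1.0,   # example value: rescales the MLP input
)

z = transform(jnp.arange(5.0))   # PCA features plus a scaled MLP correction, shape (2,)
```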
559class InversePCAResNetTransform(InversePCATransform):
560    """Inverse PCA-ResNet transform."""
561
562    mlp: MLP
563    mlp_scale: float
564
565    def __init__(
566        self,
567        mean: ArrayLike,
568        components: ArrayLike,
569        scaling: ArrayLike,
570        key: PRNGKey,
571        units: list[int],
572        activation: str,
573        mlp_scale: float,
574    ) -> None:
575        r"""Inverse PCA-ResNet transform.
576
577        This combines the inverse PCA transform with an MLP to give a ResNet-like architecture.
578        $$
579            z \mapsto \text{PCA}^{-1}(z) + \text{mlp\_scale} \times \text{MLP}(z).
580        $$
581
582        Args:
583            mean (ArrayLike): mean of the data used to fit the PCA
584            components (ArrayLike): PCA components
585            scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance)
586            key (PRNGKey): random key
587            units (list[int]): layer sizes of the MLP
588            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
589            mlp_scale (float): scale of the MLP output
590        """
591        super().__init__(mean, components, scaling)
592        units = units + [components.shape[1]]
593        dim = components.shape[0]
594        self.mlp = MLP(key, dim, units, activation)
595        self.mlp_scale = mlp_scale
596
597    def __call__(self, z: ArrayLike) -> Array:
598        inverse_pca_recon = super().__call__(z)
599        mlp_recon = self.mlp(z)
600        return inverse_pca_recon + self.mlp_scale * mlp_recon
601
602    def inverse_pca_transform(self, z: ArrayLike) -> Array:
603        """Inverse PCA transform.
604
605        Args:
606            z (ArrayLike): reduced state
607
608        Returns:
609            Array: reconstructed state
610        """
611        return super().__call__(z)
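
A matching sketch for the decoder side; again the PCA arrays and hyperparameters are hypothetical placeholders.

```python
import jax
import jax.numpy as jnp
from onsagernet.models import InversePCAResNetTransform

# Hypothetical PCA quantities: reconstruct 5-dimensional states from 2 components
mean = jnp.zeros(5)
components = jnp.eye(2, 5)
scaling = jnp.array([3.0, 1.5])

decoder = InversePCAResNetTransform(
    mean, components, scaling,
    key=jax.random.PRNGKey(1),
    units=[32, 32],        # hidden layers; the output layer (size 5) is appended internally
    activation="tanh",
    mlp_scale=0.1,         # example value for the residual correction
)

x_recon = decoder(jnp.array([0.5, -1.0]))   # inverse PCA reconstruction plus MLP correction, shape (5,)
```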
