onsagernet.transformations

Dimensionality transformations

This module contains classes for dimensionality transformations, such as encoders and decoders. They are intended for settings where reduced dynamics are sought but only microscopic data is available.

The main classes are Encoder and Decoder, which are abstract base classes to be implemented by the user. The ClosureEncoder and ClosureDecoder classes are concrete implementations of Encoder and Decoder, respectively. They encode the data into a reduced space and decode it back, after which the reduced dynamics can be trained. In this case the reduced space consists of both known macroscopic coordinates and learned closure coordinates.

Encoder and Decoder objects are used, for example, in onsagernet.dynamics.ReducedSDE for closure modelling.

from onsagernet.transformations import ClosureEncoder, ClosureDecoder
from onsagernet.dynamics import ReducedSDE, OnsagerNet

encoder = ClosureEncoder(...)
decoder = ClosureDecoder(...)
sde = OnsagerNet(...)

reduced_sde = ReducedSDE(encoder, decoder, sde)  # can be used to train, predict, etc

As in onsagernet.dynamics, only the model assembly logic is provided here. Some example implementations of the model architecture are provided in the onsagernet.models module.
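ClosureEncoder concatenates its two outputs along axis 0 and ClosureDecoder slices the reduced state along axis 0, so both appear to be written for a single, unbatched microscopic state. Assuming that is the intended contract, a whole trajectory of snapshots can be encoded with jax.vmap. The sketch below uses a hypothetical stand-in function encode and a hypothetical matrix W rather than the onsagernet classes, so that it runs on its own.

```python
import jax
import jax.numpy as jnp

# Hypothetical closure projection: 2 closure coordinates from a 4-dim state.
W = jnp.array([[1.0, -1.0, 0.0, 0.0],
               [0.0, 0.0, 1.0, -1.0]])


def encode(x: jax.Array) -> jax.Array:
    """Per-sample stand-in mirroring ClosureEncoder: z = [mean(x), W @ x]."""
    macroscopic = jnp.atleast_1d(jnp.mean(x))  # known macroscopic coordinate
    closure = W @ x                            # closure coordinates
    return jnp.concatenate([macroscopic, closure])


# Three snapshots of a 4-dimensional microscopic state, shape (time, n).
trajectory = jnp.arange(12.0).reshape(3, 4)

# vmap maps the per-sample encoder over the time axis: shape (time, 1 + 2).
z_trajectory = jax.vmap(encode)(trajectory)
```

Calling encode(trajectory) directly would instead average over all 12 entries and break the concatenation, which is why the batching is left to vmap.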

import equinox as eqx
import jax.numpy as jnp
from abc import abstractmethod

# ------------------------- Typing imports ------------------------- #
from jax import Array
from jax.typing import ArrayLike

# ------------------------------------------------------------------ #
#                        Encoders and decoders                       #
# ------------------------------------------------------------------ #


class Encoder(eqx.Module):
    """The base class for encoders."""

    @abstractmethod
    def __call__(self, x: ArrayLike) -> Array:
        pass


class Decoder(eqx.Module):
    """The base class for decoders."""

    @abstractmethod
    def __call__(self, z: ArrayLike) -> Array:
        pass


class EncoderfromFunc(Encoder):
    """Encoder constructed from a given closure transform."""

    closure_transform: eqx.Module

    def __init__(self, closure_transform: eqx.Module) -> None:
        r"""Encoder constructed from a given closure transform.

        Takes a given closure transformation $x\mapsto z$
        to define the encoder.

        Args:
            closure_transform (eqx.Module): a given transformation $x\mapsto z$
        """
        self.closure_transform = closure_transform

    def __call__(self, x: ArrayLike) -> Array:
        return self.closure_transform(x)


class DecoderfromFunc(Decoder):
    """Decoder constructed from a given inverse closure transform."""

    inverse_closure_transform: eqx.Module

    def __init__(self, inverse_closure_transform: eqx.Module) -> None:
        r"""Decoder constructed from a given inverse closure transform.

        Takes a given inverse closure transformation $z\mapsto x$
        to define the decoder.

        Args:
            inverse_closure_transform (eqx.Module): a given transformation $z\mapsto x$
        """
        self.inverse_closure_transform = inverse_closure_transform

    def __call__(self, z: ArrayLike) -> Array:
        return self.inverse_closure_transform(z)


class ClosureEncoder(EncoderfromFunc):
    """Closure encoder which combines known macroscopic coordinates
    with learned (or PCA) closure coordinates.
    """

    macroscopic_transform: eqx.Module

    def __init__(
        self, macroscopic_transform: eqx.Module, closure_transform: eqx.Module
    ) -> None:
        r"""Closure encoder which combines known macroscopic coordinates
        with learned (or PCA) closure coordinates.

        $$
            x \mapsto z = [\varphi^*(x), \hat\varphi(x)]
        $$

        where $\varphi^*$ is the known macroscopic transformation and
        $\hat\varphi$ is the learned closure transformation.

        Args:
            macroscopic_transform (eqx.Module): the known macroscopic transformation
            closure_transform (eqx.Module): the learned closure transformation
        """
        self.macroscopic_transform = macroscopic_transform
        self.closure_transform = closure_transform

    def __call__(self, x: ArrayLike) -> Array:
        """Combines the macroscopic and closure coordinates.

        Args:
            x (ArrayLike): microscopic state

        Returns:
            Array: reduced state
        """
        macroscopic_coords = self.macroscopic_transform(x)
        closure_coords = self.closure_transform(x)
        reduced_coords = jnp.concatenate([macroscopic_coords, closure_coords])
        return reduced_coords


class ClosureDecoder(DecoderfromFunc):
    """Decodes from a closure encoder model output."""

    macroscopic_dim: int

    def __init__(
        self, inverse_closure_transform: eqx.Module, macroscopic_dim: int
    ) -> None:
        r"""Closure decoder which extracts the closure coordinates from the reduced state
        and then applies the inverse closure transformation to reconstruct the microscopic state.

        $$
            z[\texttt{macroscopic\_dim}:] \mapsto x
        $$

        It is assumed that the first `macroscopic_dim` coordinates are the known
        macroscopic coordinates and the rest are the learned closure coordinates.

        Args:
            inverse_closure_transform (eqx.Module): transformation from closure coordinates to microscopic state
            macroscopic_dim (int): the dimension of the known macroscopic state
        """
        self.inverse_closure_transform = inverse_closure_transform
        self.macroscopic_dim = macroscopic_dim

    def __call__(self, z: ArrayLike) -> Array:
        """Extracts the closure coordinates and applies the inverse closure transformation.

        Args:
            z (ArrayLike): reduced coordinates

        Returns:
            Array: reconstructed microscopic state
        """
        z_closure = z[self.macroscopic_dim :]
        return self.inverse_closure_transform(z_closure)
class Encoder(equinox._module.Module):

The base class for encoders.

class Decoder(equinox._module.Module):

The base class for decoders.

class EncoderfromFunc(Encoder):

Encoder constructed from a given closure transform.

EncoderfromFunc(closure_transform: equinox._module.Module)

Encoder constructed from a given closure transform.

Takes a given closure transformation $x\mapsto z$ to define the encoder.

Arguments:
  • closure_transform (eqx.Module): a given transformation $x\mapsto z$
closure_transform: equinox._module.Module
class DecoderfromFunc(Decoder):

Decoder constructed from a given inverse closure transform.

DecoderfromFunc(inverse_closure_transform: equinox._module.Module)

Decoder constructed from a given inverse closure transform.

Takes a given inverse closure transformation $z\mapsto x$ to define the decoder.

Arguments:
  • inverse_closure_transform (eqx.Module): a given transformation $z\mapsto x$
inverse_closure_transform: equinox._module.Module
class ClosureEncoder(EncoderfromFunc):

Closure encoder which combines known macroscopic coordinates with learned (or PCA) closure coordinates.

ClosureEncoder(macroscopic_transform: equinox._module.Module, closure_transform: equinox._module.Module)

Closure encoder which combines known macroscopic coordinates with learned (or PCA) closure coordinates.

$$ x \mapsto z = [\varphi^*(x), \hat\varphi(x)] $$

where $\varphi^*$ is the known macroscopic transformation and $\hat\varphi$ is the learned closure transformation.

Arguments:
  • macroscopic_transform (eqx.Module): the known macroscopic transformation
  • closure_transform (eqx.Module): the learned closure transformation
macroscopic_transform: equinox._module.Module
closure_transform
class ClosureDecoder(DecoderfromFunc):

Decodes from a closure encoder model output.

ClosureDecoder(inverse_closure_transform: equinox._module.Module, macroscopic_dim: int)

Closure decoder which extracts the closure coordinates from the reduced state and then applies the inverse closure transformation to reconstruct the microscopic state.

$$ z[\texttt{macroscopic\_dim}:] \mapsto x $$

It is assumed that the first macroscopic_dim coordinates are the known macroscopic coordinates and the rest are the learned closure coordinates.

Arguments:
  • inverse_closure_transform (eqx.Module): transformation from closure coordinates to microscopic state
  • macroscopic_dim (int): the dimension of the known macroscopic state
macroscopic_dim: int
inverse_closure_transform