onsagernet.dynamics

Dynamical models

This module defines the dynamical models used in OnsagerNet and related architectures. The architectures of the individual model components are not defined here; this module only handles the logic that assembles those components into a model.

The following example shows how to use the OnsagerNet model to create the stochastic OnsagerNet dynamics

$$ dX(t) = - \left[ M(X(t)) + W(X(t)) \right] \nabla V(X(t), u(t)) dt + \sqrt{\epsilon} \sigma(X(t), u(t)) dW(t) \qquad X(t) \in \mathbb{R}^d, \quad u(t) \in \mathbb{R}^m. $$

import equinox as eqx
from onsagernet.dynamics import OnsagerNet

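class MyPotential(eqx.Module):
    """Implement your V function here

    Called as V(x, u) with x of shape (d,) and u of shape (m,);
    must return a scalar, since its gradient is taken with jax.grad
    """

    def __call__(self, x, u):
        raise NotImplementedError

class MyDissipation(eqx.Module):
    """Implement your M function here

    Called as M(x) with x of shape (d,); returns a symmetric
    positive semi-definite matrix of shape (d, d)
    """

    def __call__(self, x):
        raise NotImplementedError

class MyConservation(eqx.Module):
    """Implement your W function here

    Called as W(x) with x of shape (d,); returns an
    anti-symmetric matrix of shape (d, d)
    """

    def __call__(self, x):
        raise NotImplementedError

class MyDiffusion(eqx.Module):
    """Implement your sigma function here

    Called as sigma(x, u) with x of shape (d,) and u of shape (m,);
    returns a matrix of shape (d, d)
    """

    def __call__(self, x, u):
        raise NotImplementedError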


potential = MyPotential()
dissipation = MyDissipation()
conservation = MyConservation()
diffusion = MyDiffusion()

sde = OnsagerNet(
    potential=potential,
    dissipation=dissipation,
    conservation=conservation,
    diffusion=diffusion,
)

The sde instance can then be used to simulate the dynamics of the system or perform training.
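To make the assembly concrete, here is a minimal sketch in plain JAX (without `onsagernet` or equinox) of the drift and diffusion that `OnsagerNet` builds from its four components. The toy components (a quadratic potential, identity dissipation, a constant rotation generator as the conservation matrix, and identity diffusion) are hypothetical choices for illustration; the two formulas mirror `OnsagerNet.drift` and `OnsagerNet.diffusion` in this module.

```python
import jax
import jax.numpy as jnp


def potential(x, u):
    # Hypothetical V(x, u) = 0.5 * |x|^2; must return a scalar for jax.grad
    return 0.5 * jnp.sum(x ** 2)


def dissipation(x):
    # Hypothetical constant M: identity (symmetric positive definite)
    return jnp.eye(x.shape[0])


def conservation(x):
    # Hypothetical constant W: a 2D rotation generator (anti-symmetric)
    return jnp.array([[0.0, 1.0], [-1.0, 0.0]])


def diffusion_func(x, u):
    # Hypothetical square diffusion matrix sigma(x, u) = I
    return jnp.eye(x.shape[0])


def onsager_drift(t, x, args):
    # Mirrors OnsagerNet.drift: -(M(x) + W(x)) grad_x V(x, args)
    dvdx = jax.grad(potential, argnums=0)(x, args)
    return -(dissipation(x) + conservation(x)) @ dvdx


def onsager_diffusion(t, x, args):
    # Mirrors OnsagerNet.diffusion: sqrt(args[0]) * sigma(x, args)
    temperature = args[0]
    return jnp.sqrt(temperature) * diffusion_func(x, args)


x = jnp.array([1.0, 2.0])
args = jnp.array([0.04])  # args[0] is the temperature epsilon
print(onsager_drift(0.0, x, args))      # expected values: [-3, -1]
print(onsager_diffusion(0.0, x, args))  # expected values: 0.2 * identity
```

With $M = I$ and $W$ the rotation generator, $-(M + W)\,x = -(3, 1)$ for $x = (1, 2)$, and $\sqrt{0.04} = 0.2$ scales the identity diffusion.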

r'''
# Dynamical models

This module defines the dynamical models used in OnsagerNet and related architectures.
The architectures of the individual model components are not defined here;
this module only handles the logic that assembles those components into a model.

The following example shows how to use the `OnsagerNet` model
to create the stochastic OnsagerNet dynamics

$$
    dX(t) = -
    \left[
        M(X(t)) + W(X(t))
    \right] \nabla V(X(t), u(t)) dt
    + \sqrt{\epsilon} \sigma(X(t), u(t)) dW(t)
    \qquad
    X(t) \in \mathbb{R}^d, \quad u(t) \in \mathbb{R}^m.
$$

```python
import equinox as eqx
from onsagernet.dynamics import OnsagerNet

class MyPotential(eqx.Module):
    """Implement your V function here

    Called as V(x, u) with x of shape (d,) and u of shape (m,);
    must return a scalar, since its gradient is taken with jax.grad
    """

    def __call__(self, x, u):
        raise NotImplementedError

class MyDissipation(eqx.Module):
    """Implement your M function here

    Called as M(x) with x of shape (d,); returns a symmetric
    positive semi-definite matrix of shape (d, d)
    """

    def __call__(self, x):
        raise NotImplementedError

class MyConservation(eqx.Module):
    """Implement your W function here

    Called as W(x) with x of shape (d,); returns an
    anti-symmetric matrix of shape (d, d)
    """

    def __call__(self, x):
        raise NotImplementedError

class MyDiffusion(eqx.Module):
    """Implement your sigma function here

    Called as sigma(x, u) with x of shape (d,) and u of shape (m,);
    returns a matrix of shape (d, d)
    """

    def __call__(self, x, u):
        raise NotImplementedError


potential = MyPotential()
dissipation = MyDissipation()
conservation = MyConservation()
diffusion = MyDiffusion()

sde = OnsagerNet(
    potential=potential,
    dissipation=dissipation,
    conservation=conservation,
    diffusion=diffusion,
)
```

The `sde` instance can then be used to simulate the dynamics
of the system or perform training.

- Simple definitions of the potential, dissipation, conservation, and diffusion functions are
  provided in `onsagernet.models`
- The `ReducedSDE` class combines an `SDE` component with a dimensionality reduction component
  consisting of an `onsagernet.transformations.Encoder` and an `onsagernet.transformations.Decoder`
- Standard training routines are provided in `onsagernet.trainers`
'''

import jax
import jax.numpy as jnp
import equinox as eqx
from abc import abstractmethod

# ------------------------- Typing imports ------------------------- #

from typing import Callable
from jax.typing import ArrayLike
from jax import Array
from .transformations import Encoder, Decoder

DynamicCallable = Callable[[ArrayLike, ArrayLike, ArrayLike], Array]

# ------------------------------------------------------------------ #
#                                 SDE                                #
# ------------------------------------------------------------------ #


class SDE(eqx.Module):
    """Base class for stochastic differential equation models."""

    @abstractmethod
    def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Drift function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters

        Returns:
            Array: drift vector field
        """
        pass

    @abstractmethod
    def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Diffusion function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters

        Returns:
            Array: diffusion matrix of size (state_dim, bm_dim)
        """
        pass


class SDEfromFunc(SDE):
    """SDE model defined by providing drift and diffusion functions."""

    _drift_func: DynamicCallable
    _diffusion_func: DynamicCallable

    def __init__(
        self, drift_func: DynamicCallable, diffusion_func: DynamicCallable
    ) -> None:
        r"""SDE model defined by providing drift and diffusion functions.

        This implements the dynamics
        $$
            dX(t) = f(t, X(t), u(t)) dt + g(t, X(t), u(t)) dW(t)
        $$
        where $f$ is the drift function and $g$ is the diffusion function.
        The `args` argument (which represents $u(t)$)
        is used to pass additional parameters to the drift and diffusion functions.

        Args:
            drift_func (Callable[[ArrayLike, ArrayLike, ArrayLike], Array]): provided drift function
            diffusion_func (Callable[[ArrayLike, ArrayLike, ArrayLike], Array]): provided diffusion function
        """

        self._drift_func = drift_func
        self._diffusion_func = diffusion_func

    def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Drift function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters

        Returns:
            Array: drift vector field
        """
        return self._drift_func(t, x, args)

    def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Diffusion function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters

        Returns:
            Array: diffusion matrix
        """
        return self._diffusion_func(t, x, args)


class ReducedSDE(eqx.Module):
    """SDE model with an encoder and a decoder, for dimensionality reduction or closure modelling."""

    encoder: Encoder
    decoder: Decoder
    sde: SDE

    def __init__(self, encoder: Encoder, decoder: Decoder, sde: SDE) -> None:
        """SDE model with an encoder and a decoder, for dimensionality reduction or closure modelling.

        The `sde` attribute can be any instance of the [SDE](#SDE) class or its subclasses.
        The `encoder` must be an instance of `onsagernet.transformations.Encoder` (or a subclass) and
        the `decoder` must be an instance of `onsagernet.transformations.Decoder` (or a subclass).

        Args:
            encoder (Encoder): The encoder function mapping the microscopic state to the reduced state
            decoder (Decoder): The decoder function mapping the reduced state to the microscopic state
            sde (SDE): The stochastic dynamics for the reduced state
        """
        self.encoder = encoder
        self.decoder = decoder
        self.sde = sde


# ------------------------------------------------------------------ #
#                             OnsagerNet                             #
# ------------------------------------------------------------------ #


class OnsagerNet(SDE):
    potential: eqx.Module
    dissipation: eqx.Module
    conservation: eqx.Module
    diffusion_func: eqx.Module

    def __init__(
        self,
        potential: eqx.Module,
        dissipation: eqx.Module,
        conservation: eqx.Module,
        diffusion: eqx.Module,
    ) -> None:
        r"""Stochastic OnsagerNet model.

        Let $X(t) \in \mathbb{R}^d$. The Stochastic OnsagerNet model is defined by the SDE
        $$
            dX(t) = -
            \left[
                M(X(t)) + W(X(t))
            \right] \nabla V(X(t), u(t)) dt
            + \sqrt{\epsilon} \sigma(X(t), u(t)) dW(t)
        $$
        where

        - $M : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the dissipation matrix, which is symmetric positive semi-definite for all $x$
        - $W : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the conservation matrix, which is anti-symmetric for all $x$
        - $V : \mathbb{R}^{d} \times \mathbb{R}^{m} \to \mathbb{R}$ is the potential function, and $\nabla V$ denotes its gradient with respect to the state
        - $\sigma: \mathbb{R}^{d} \times \mathbb{R}^{m} \to \mathbb{R}^{d\times d}$ is the (square) diffusion matrix
        - $u(t) \in \mathbb{R}^{m}$ are the additional parameters for the potential and diffusion functions; note that **the first component of $u(t)$ is the temperature $\epsilon$**

        Args:
            potential (eqx.Module): potential function $V$
            dissipation (eqx.Module): dissipation matrix $M$
            conservation (eqx.Module): conservation matrix $W$
            diffusion (eqx.Module): diffusion matrix $\sigma$
        """
        self.potential = potential
        self.dissipation = dissipation
        self.conservation = conservation
        self.diffusion_func = diffusion

    def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Drift function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters, the first element is the temperature

        Returns:
            Array: drift vector field
        """
        dvdx = jax.grad(self.potential, argnums=0)(x, args)
        return -(self.dissipation(x) + self.conservation(x)) @ dvdx

    def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Diffusion function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters, the first element is the temperature

        Returns:
            Array: diffusion matrix
        """
        temperature = args[0]
        return jnp.sqrt(temperature) * self.diffusion_func(x, args)


# ------------------------------------------------------------------ #
#        OnsagerNet satisfying fluctuation-dissipation relation      #
# ------------------------------------------------------------------ #


class OnsagerNetFD(OnsagerNet):
    shared: eqx.nn.Shared

    def __init__(
        self, potential: eqx.Module, dissipation: eqx.Module, conservation: eqx.Module
    ) -> None:
        r"""Stochastic OnsagerNet model satisfying a fluctuation-dissipation relation.

        This is a modified version of the Stochastic OnsagerNet model.
        Let $X(t) \in \mathbb{R}^d$. This model is defined by the SDE
        $$
            dX(t) = \left\{ -
            \left[
                M(X(t)) + W(X(t))
            \right] \nabla V(X(t), u(t))
            + \epsilon \nabla \cdot M(X(t))
            \right\} dt
            + \sqrt{2 \epsilon} [M(X(t))]^\frac{1}{2} dW(t)
        $$
        where

        - $M : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the dissipation matrix,
          which is symmetric positive semi-definite for all $x$
        - $W : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the conservation matrix,
          which is anti-symmetric for all $x$
        - $V : \mathbb{R}^{d} \times \mathbb{R}^{m} \to \mathbb{R}$ is the potential function
        - $\nabla \cdot M$ is the matrix divergence,
          $[\nabla \cdot M(x)]_i = \sum_j \partial M_{ij} / \partial x_j$
        - $u(t) \in \mathbb{R}^{m}$ are the additional parameters,
          and note that **the first component of $u(t)$ is the temperature $\epsilon$**

        The main differences from `OnsagerNet` are that the diffusion matrix is now
        given by a square root of the dissipation matrix, and that the drift contains
        the additional term $\epsilon \nabla \cdot M$. Here $[M(x)]^\frac{1}{2}$ denotes the
        lower-triangular Cholesky factor $L(x)$, which satisfies $L(x) L(x)^\top = M(x)$;
        computing it requires $M(x)$ to be positive definite.

        Args:
            potential (eqx.Module): potential function $V$
            dissipation (eqx.Module): dissipation matrix $M$
            conservation (eqx.Module): conservation matrix $W$
        """
        self.potential = potential
        self.conservation = conservation
        self.diffusion_func = None

        # Share the dissipation module
        dissipation_drift = dissipation
        dissipation_diffusion = dissipation
        where = lambda shared_layers: shared_layers[0]
        get = lambda shared_layers: shared_layers[1]
        self.shared = eqx.nn.Shared(
            (dissipation_drift, dissipation_diffusion), where, get
        )

    @property
    def dissipation(self) -> eqx.Module:
        """Dissipation matrix wrapper

        Returns:
            eqx.Module: dissipation matrix module
        """
        return self.shared()[0]

    def _matrix_div(self, M: eqx.Module, x: ArrayLike) -> Array:
        r"""Computes the matrix divergence of a matrix function $M(x)$.

        This is defined in component form as
        $$
            [\nabla \cdot M(x)]_i = \sum_j \frac{\partial M_{ij}}{\partial x_j}.
        $$

        Args:
            M (eqx.Module): matrix function
            x (ArrayLike): state

        Returns:
            Array: $\nabla \cdot M(x)$
        """
        jac_M_x = jax.jacfwd(M)(x)
        return jnp.trace(jac_M_x, axis1=1, axis2=2)

    def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Drift function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters, the first element is the temperature

        Returns:
            Array: drift vector field
        """
        temperature = args[0]
        dissipation = self.shared()[0]
        f1 = super().drift(t, x, args)
        f2 = temperature * self._matrix_div(dissipation, x)
        return f1 + f2

    def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Diffusion function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters, the first element is the temperature

        Returns:
            Array: diffusion matrix
        """
        temperature = args[0]
        dissipation = self.shared()[1]
        M_x = dissipation(x)
        sqrt_M_x = jnp.linalg.cholesky(M_x)
        return jnp.sqrt(2.0 * temperature) * sqrt_M_x
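The fluctuation-dissipation variant `OnsagerNetFD` relies on two numerical ingredients: the matrix divergence computed in `_matrix_div` (a forward-mode Jacobian followed by a trace over its last two axes) and the Cholesky factor used in `diffusion`. The sketch below checks both in plain JAX on a hypothetical state-dependent dissipation matrix, chosen only so that its divergence can be worked out by hand.

```python
import jax
import jax.numpy as jnp


def M(x):
    # Hypothetical state-dependent dissipation matrix, symmetric positive definite:
    # M(x) = diag(1 + x_i^2) + 0.5 * x x^T
    return jnp.diag(1.0 + x ** 2) + 0.5 * jnp.outer(x, x)


def matrix_div(M, x):
    # Same construction as OnsagerNetFD._matrix_div:
    # jac[i, j, k] = dM_ij/dx_k; tracing axes 1 and 2 gives sum_j dM_ij/dx_j
    jac = jax.jacfwd(M)(x)
    return jnp.trace(jac, axis1=1, axis2=2)


def matrix_div_by_hand(x):
    # The diagonal term contributes 2 x_i; the outer-product term contributes 0.5 * (1 + d) * x_i
    d = x.shape[0]
    return 2.0 * x + 0.5 * (1 + d) * x


x = jnp.array([0.3, -1.2, 0.7])
temperature = 0.1

div_M = matrix_div(M, x)
L = jnp.linalg.cholesky(M(x))        # lower-triangular factor with L @ L.T == M(x)
G = jnp.sqrt(2.0 * temperature) * L  # FD diffusion matrix, as in OnsagerNetFD.diffusion

print(jnp.allclose(div_M, matrix_div_by_hand(x), atol=1e-5))       # True
print(jnp.allclose(G @ G.T, 2.0 * temperature * M(x), atol=1e-5))  # True
```

The second check shows why a Cholesky factor is an acceptable "square root" here: the noise covariance $G G^\top = 2\epsilon M(x)$ is the same as for the symmetric square root.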
DynamicCallable = Callable[[ArrayLike, ArrayLike, ArrayLike], Array]
class SDE(equinox._module.Module):

Base class for stochastic differential equation models.

def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:

Drift function

Arguments:
  • t (ArrayLike): time
  • x (ArrayLike): state
  • args (ArrayLike): additional arguments or parameters
Returns:

Array: drift vector field

def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:

Diffusion function

Arguments:
  • t (ArrayLike): time
  • x (ArrayLike): state
  • args (ArrayLike): additional arguments or parameters
Returns:

Array: diffusion matrix of size (state_dim, bm_dim)
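Because every subclass exposes the same `drift` and `diffusion` methods, a simulator can treat all models uniformly. As a hedged illustration (the solver the package uses for simulation is not specified here), the following sketch writes a single Euler-Maruyama step against that interface. `ToySDE` and `euler_maruyama_step` are hypothetical names; note how the `(state_dim, bm_dim)` diffusion matrix multiplies a Brownian increment with `bm_dim` entries.

```python
import jax
import jax.numpy as jnp


class ToySDE:
    # Hypothetical stand-in exposing the SDE interface: drift has shape (state_dim,),
    # diffusion has shape (state_dim, bm_dim); here state_dim = 2 and bm_dim = 3
    def drift(self, t, x, args):
        return -args[0] * x

    def diffusion(self, t, x, args):
        return args[1] * jnp.ones((x.shape[0], 3)) / 3.0


def euler_maruyama_step(sde, t, x, args, dt, key):
    # Hypothetical helper: X_{n+1} = X_n + f dt + g dW, with dW ~ N(0, dt I_{bm_dim})
    g = sde.diffusion(t, x, args)
    dw = jnp.sqrt(dt) * jax.random.normal(key, (g.shape[1],))
    return x + sde.drift(t, x, args) * dt + g @ dw


sde = ToySDE()
x = jnp.array([1.0, -1.0])
args = jnp.array([0.5, 0.1])
dt = 0.01
key = jax.random.PRNGKey(0)
for n in range(100):
    key, subkey = jax.random.split(key)
    x = euler_maruyama_step(sde, n * dt, x, args, dt, subkey)
print(x.shape)  # (2,)
```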

class SDEfromFunc(SDE):

SDE model defined by providing drift and diffusion functions.

SDEfromFunc(drift_func: DynamicCallable, diffusion_func: DynamicCallable)

SDE model defined by providing drift and diffusion functions.

This implements the dynamics

$$ dX(t) = f(t, X(t), u(t)) dt + g(t, X(t), u(t)) dW(t) $$

where $f$ is the drift function and $g$ is the diffusion function. The args argument (which represents $u(t)$) is used to pass additional parameters to the drift and diffusion functions.

Arguments:
  • drift_func (Callable[[ArrayLike, ArrayLike, ArrayLike], Array]): provided drift function
  • diffusion_func (Callable[[ArrayLike, ArrayLike, ArrayLike], Array]): provided diffusion function
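For instance, the drift and diffusion functions of a hypothetical Ornstein-Uhlenbeck process $dX = -\theta X dt + s\, dW$ can be written with the `(t, x, args)` signature that `SDEfromFunc` expects, using `args` to carry $(\theta, s)$. The constructor call is shown only as a comment so that the snippet runs without `onsagernet` installed.

```python
import jax.numpy as jnp


# Hypothetical Ornstein-Uhlenbeck process dX = -theta X dt + s dW, with args = (theta, s)
def ou_drift(t, x, args):
    theta = args[0]
    return -theta * x


def ou_diffusion(t, x, args):
    s = args[1]
    # Square diffusion matrix: one independent Brownian motion per state component
    return s * jnp.eye(x.shape[0])


# With onsagernet installed, these could be wrapped as
# sde = SDEfromFunc(drift_func=ou_drift, diffusion_func=ou_diffusion)
x = jnp.array([2.0, -4.0])
args = jnp.array([0.5, 0.3])
f = ou_drift(0.0, x, args)      # drift values: [-1, 2]
g = ou_diffusion(0.0, x, args)  # shape (2, 2), equal to 0.3 * identity
```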
def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:

Drift function

Arguments:
  • t (ArrayLike): time
  • x (ArrayLike): state
  • args (ArrayLike): additional arguments or parameters
Returns:

Array: drift vector field

def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:

Diffusion function

Arguments:
  • t (ArrayLike): time
  • x (ArrayLike): state
  • args (ArrayLike): additional arguments or parameters
Returns:

Array: diffusion matrix

class ReducedSDE(equinox._module.Module):

SDE model with an encoder and a decoder, for dimensionality reduction or closure modelling.

ReducedSDE(encoder: onsagernet.transformations.Encoder, decoder: onsagernet.transformations.Decoder, sde: SDE)

SDE model with an encoder and a decoder, for dimensionality reduction or closure modelling.

The sde attribute can be any instance of the SDE class or its subclasses. The encoder must be an instance of onsagernet.transformations.Encoder (or a subclass) and the decoder must be an instance of onsagernet.transformations.Decoder (or a subclass).

Arguments:
  • encoder (Encoder): The encoder function mapping the microscopic state to the reduced state
  • decoder (Decoder): The decoder function mapping the reduced state to the microscopic state
  • sde (SDE): The stochastic dynamics for the reduced state
sde: SDE
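The division of labour between the three components can be pictured with a minimal sketch, assuming linear maps: the encoder compresses a microscopic state to a reduced state, the reduced SDE evolves that state, and the decoder lifts it back. The projection matrix `P`, the functions `encode`, `decode` and `reduced_drift` are all hypothetical and are not the interfaces of `onsagernet.transformations`.

```python
import jax.numpy as jnp

# Hypothetical linear encoder/decoder pair: microscopic dimension D = 3, reduced dimension d = 2.
# P has orthonormal columns, so decode(encode(x)) is the orthogonal projection of x onto span(P).
P = jnp.array([[1.0, 0.0],
               [0.0, 1.0],
               [0.0, 0.0]])


def encode(x):
    return P.T @ x  # (D,) -> (d,)


def decode(z):
    return P @ z  # (d,) -> (D,)


def reduced_drift(t, z, args):
    # Placeholder reduced dynamics acting on the reduced state only
    return -z


x = jnp.array([1.0, 2.0, 3.0])
z = encode(x)                                 # [1, 2]
x_rate = decode(reduced_drift(0.0, z, None))  # [-1, -2, 0] in microscopic coordinates
x_proj = decode(z)                            # [1, 2, 0]; the third component is lost
```

The lost third component is what a closure model or a richer (e.g. nonlinear) decoder would have to account for.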
class OnsagerNet(SDE):
OnsagerNet(potential: eqx.Module, dissipation: eqx.Module, conservation: eqx.Module, diffusion: eqx.Module)

Stochastic OnsagerNet model.

Let $X(t) \in \mathbb{R}^d$. The Stochastic OnsagerNet model is defined by the SDE

$$ dX(t) = - \left[ M(X(t)) + W(X(t)) \right] \nabla V(X(t), u(t)) dt + \sqrt{\epsilon} \sigma(X(t), u(t)) dW(t) $$

where

  • $M : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the dissipation matrix, which is symmetric positive semi-definite for all $x$
  • $W : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the conservation matrix, which is anti-symmetric for all $x$
  • $V : \mathbb{R}^{d} \times \mathbb{R}^{m} \to \mathbb{R}$ is the potential function, and $\nabla V$ denotes its gradient with respect to the state
  • $\sigma: \mathbb{R}^{d} \times \mathbb{R}^{m} \to \mathbb{R}^{d\times d}$ is the (square) diffusion matrix
  • $u(t) \in \mathbb{R}^{m}$ are the additional parameters for the potential and diffusion functions; note that the first component of $u(t)$ is the temperature $\epsilon$
Arguments:
  • potential (eqx.Module): potential function $V$
  • dissipation (eqx.Module): dissipation matrix $M$
  • conservation (eqx.Module): conservation matrix $W$
  • diffusion (eqx.Module): diffusion matrix $\sigma$
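The symmetry conditions on $M$ and $W$ give the deterministic part of the dynamics its dissipative structure: since $W$ is anti-symmetric, $\nabla V^\top W \nabla V = 0$, so along the drift $\frac{d}{dt} V = -\nabla V^\top M \nabla V \le 0$ whenever $M$ is positive semi-definite. A quick numerical check in plain JAX, with random matrices and a potential that are hypothetical and built only to satisfy the stated constraints:

```python
import jax
import jax.numpy as jnp

key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
d = 4
A = jax.random.normal(key_a, (d, d))
B = jax.random.normal(key_b, (d, d))
M = A @ A.T + 1e-3 * jnp.eye(d)  # symmetric positive definite dissipation matrix
W = B - B.T                      # anti-symmetric conservation matrix


def V(x, u):
    # Hypothetical potential; the parameter u is accepted but unused
    return jnp.sum(x ** 4) + jnp.sum(x ** 2)


x = jnp.array([0.5, -0.3, 0.2, 0.1])
u = jnp.array([0.1])
grad_v = jax.grad(V, argnums=0)(x, u)
drift = -(M + W) @ grad_v
dV_dt = grad_v @ drift  # rate of change of V along the deterministic part of the dynamics

print(grad_v @ W @ grad_v)  # approximately 0 (rounding only), because W is anti-symmetric
print(dV_dt <= 0)           # True
```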
potential: equinox._module.Module
dissipation: equinox._module.Module
conservation: equinox._module.Module
diffusion_func: equinox._module.Module
def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:

Drift function

Arguments:
  • t (ArrayLike): time
  • x (ArrayLike): state
  • args (ArrayLike): additional arguments or parameters, the first element is the temperature
Returns:

Array: drift vector field

def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:

Diffusion function

Arguments:
  • t (ArrayLike): time
  • x (ArrayLike): state
  • args (ArrayLike): additional arguments or parameters, the first element is the temperature
Returns:

Array: diffusion matrix
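The `drift` and `diffusion` methods act on a single state `x` of shape `(d,)`, with the temperature read from `args[0]`. When a batch of states is needed (for example, during training), they can be vectorized with `jax.vmap`, mapping over the state only and sharing `t` and `args`. A sketch with hypothetical stand-in functions that follow the same `(t, x, args)` convention:

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-ins following the (t, x, args) convention of OnsagerNet.drift / .diffusion
def drift(t, x, args):
    M = jnp.eye(x.shape[0])  # identity dissipation, no conservation term
    grad_v = x               # gradient of V(x) = 0.5 * |x|^2
    return -M @ grad_v


def diffusion(t, x, args):
    temperature = args[0]    # the first element of args is the temperature
    return jnp.sqrt(temperature) * jnp.eye(x.shape[0])


xs = jnp.arange(6.0).reshape(3, 2)  # batch of 3 states with d = 2
args = jnp.array([0.25, 1.0])       # args[0] = temperature, args[1] = an extra parameter

# Map over the state axis only; t and args are shared across the batch
batched_drift = jax.vmap(drift, in_axes=(None, 0, None))
batched_diffusion = jax.vmap(diffusion, in_axes=(None, 0, None))

drifts = batched_drift(0.0, xs, args)          # shape (3, 2)
diffusions = batched_diffusion(0.0, xs, args)  # shape (3, 2, 2)
```

With an `OnsagerNet` instance `sde`, the analogous call would be `jax.vmap(sde.drift, in_axes=(None, 0, None))`.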

class OnsagerNetFD(OnsagerNet):
class OnsagerNetFD(OnsagerNet):
    shared: eqx.nn.Shared

    def __init__(
        self, potential: eqx.Module, dissipation: eqx.Module, conservation: eqx.Module
    ) -> None:
        r"""Stochastic OnsagerNet model satisfying a fluctuation-dissipation relation.

        This is a modified version of the stochastic OnsagerNet model.
        Let $X(t) \in \mathbb{R}^d$. This model is defined by the SDE
        $$
            dX(t) = \left\{
                - \left[
                    M(X(t)) + W(X(t))
                \right] \nabla V(X(t), u(t))
                + \epsilon \nabla \cdot M(X(t))
            \right\} dt
            + \sqrt{2 \epsilon} \, [M(X(t))]^{\frac{1}{2}} dW(t)
        $$
        where

        - $M : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the dissipation matrix,
          which is symmetric positive semi-definite for all $x$
        - $W : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the conservation matrix,
          which is antisymmetric for all $x$
        - $V : \mathbb{R}^{d} \times \mathbb{R}^{m} \to \mathbb{R}$ is the potential function
        - $\nabla \cdot M$ is the matrix divergence,
          $[\nabla \cdot M(x)]_i = \sum_j \partial M_{ij} / \partial x_j$
        - $u(t) \in \mathbb{R}^{m}$ are the additional parameters for the potential and diffusion functions,
          and note that **the first component of $u(t)$ is the temperature $\epsilon$**

        The main difference from `OnsagerNet` is that there is no separate diffusion module:
        the diffusion matrix is a square root of the dissipation matrix. The implementation
        uses the lower-triangular Cholesky factor $L(x)$, which satisfies $L(x) L(x)^\top = M(x)$
        and in practice requires $M(x)$ to be positive definite rather than only semi-definite.

        Args:
            potential (eqx.Module): potential function $V$
            dissipation (eqx.Module): dissipation matrix $M$
            conservation (eqx.Module): conservation matrix $W$
        """
        self.potential = potential
        self.conservation = conservation
        self.diffusion_func = None

        # Tie the two copies of the dissipation module so that
        # the drift and the diffusion use the same parameters
        dissipation_drift = dissipation
        dissipation_diffusion = dissipation
        where = lambda shared_layers: shared_layers[0]
        get = lambda shared_layers: shared_layers[1]
        self.shared = eqx.nn.Shared(
            (dissipation_drift, dissipation_diffusion), where, get
        )

    @property
    def dissipation(self) -> eqx.Module:
        """Dissipation matrix wrapper

        Returns:
            eqx.Module: dissipation matrix module
        """
        return self.shared()[0]

    def _matrix_div(self, M: eqx.Module, x: ArrayLike) -> Array:
        r"""Computes the matrix divergence of a matrix function $M(x)$.

        This is defined in component form as
        $$
            [\nabla \cdot M(x)]_i = \sum_j \frac{\partial M_{ij}}{\partial x_j}.
        $$

        Args:
            M (eqx.Module): matrix function
            x (ArrayLike): state

        Returns:
            Array: $\nabla \cdot M(x)$
        """
        jac_M_x = jax.jacfwd(M)(x)  # jac_M_x[i, j, k] = dM_ij / dx_k
        return jnp.trace(jac_M_x, axis1=1, axis2=2)  # sum over j = k

    def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Drift function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters, the first element is the temperature

        Returns:
            Array: drift vector field
        """
        temperature = args[0]
        dissipation = self.shared()[0]
        f1 = super().drift(t, x, args)
        f2 = temperature * self._matrix_div(dissipation, x)
        return f1 + f2

    def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Diffusion function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters, the first element is the temperature

        Returns:
            Array: diffusion matrix
        """
        temperature = args[0]
        dissipation = self.shared()[1]
        M_x = dissipation(x)
        sqrt_M_x = jnp.linalg.cholesky(M_x)  # lower-triangular L with L @ L.T == M_x
        return jnp.sqrt(2.0 * temperature) * sqrt_M_x
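The `_matrix_div` method relies on the Jacobian layout returned by `jax.jacfwd`. For a matrix-valued $M$, `jax.jacfwd(M)(x)[i, j, k]` is $\partial M_{ij} / \partial x_k$, so tracing over axes 1 and 2 contracts $j$ with $k$. The following standalone sketch reproduces that computation and checks it against a hand-computed divergence. The test matrix `M_example` is hypothetical and need not be positive semi-definite, because only the divergence is being checked.

```python
import jax
import jax.numpy as jnp


def matrix_div(M, x):
    """[div M(x)]_i = sum_j dM_ij/dx_j, computed as in OnsagerNetFD._matrix_div."""
    jac = jax.jacfwd(M)(x)  # shape (d, d, d), jac[i, j, k] = dM_ij / dx_k
    return jnp.trace(jac, axis1=1, axis2=2)  # sum over j = k


def M_example(x):
    # Hypothetical test matrix:
    # div M = (d/dx0 x0^2 + d/dx1 x0 x1, d/dx0 x0 x1 + d/dx1 x1^3)
    #       = (3 x0, x1 + 3 x1^2)
    return jnp.array([[x[0] ** 2, x[0] * x[1]], [x[0] * x[1], x[1] ** 3]])


x = jnp.array([2.0, 3.0])
print(matrix_div(M_example, x))  # expected values: 6 and 30
```

Forward-mode differentiation (`jacfwd`) is a natural choice here because the input dimension $d$ is small relative to the $d^2$ outputs of $M$.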
OnsagerNetFD( potential: equinox._module.Module, dissipation: equinox._module.Module, conservation: equinox._module.Module)

Stochastic OnsagerNet model satisfying a fluctuation-dissipation relation.

This is a modified version of the stochastic OnsagerNet model. Let $X(t) \in \mathbb{R}^d$. This model is defined by the SDE

$$ dX(t) = \left\{ - \left[ M(X(t)) + W(X(t)) \right] \nabla V(X(t), u(t)) + \epsilon \nabla \cdot M(X(t)) \right\} dt + \sqrt{2 \epsilon} \, [M(X(t))]^{\frac{1}{2}} dW(t) $$

where

  • $M : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the dissipation matrix, which is symmetric positive semi-definite for all $x$
  • $W : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the conservation matrix, which is antisymmetric for all $x$
  • $V : \mathbb{R}^{d} \times \mathbb{R}^{m} \to \mathbb{R}$ is the potential function
  • $\nabla \cdot M$ is the matrix divergence, $[\nabla \cdot M(x)]_i = \sum_j \partial M_{ij} / \partial x_j$
  • $u(t) \in \mathbb{R}^{m}$ are the additional parameters for the potential and diffusion functions, and note that the first component of $u(t)$ is the temperature $\epsilon$

The main difference from OnsagerNet is that there is no separate diffusion module: the diffusion matrix is a square root of the dissipation matrix. The implementation uses the lower-triangular Cholesky factor $L(x)$, which satisfies $L(x) L(x)^\top = M(x)$ and in practice requires $M(x)$ to be positive definite rather than only semi-definite.
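The Cholesky factor used in `OnsagerNetFD.diffusion` and the symmetric square root of $M$ are different matrices. Both, however, satisfy $\sigma \sigma^\top = M$, so the Gaussian increment $\sigma\, dW(t)$ has covariance $M\, dt$ in either case and the SDE has the same law. The following is a minimal check, using a hypothetical symmetric positive definite matrix `M`:

```python
import jax.numpy as jnp

# Hypothetical symmetric positive definite dissipation matrix M(x) at some x
M = jnp.array([[2.0, 0.5], [0.5, 1.0]])

# Lower-triangular Cholesky factor, L L^T = M (what OnsagerNetFD.diffusion uses)
L = jnp.linalg.cholesky(M)

# Symmetric PSD square root from the eigendecomposition M = Q diag(lam) Q^T
lam, Q = jnp.linalg.eigh(M)
S = Q @ jnp.diag(jnp.sqrt(lam)) @ Q.T

print(L)  # lower triangular
print(S)  # symmetric, generally different from L
```

The Cholesky factor is the cheaper of the two to compute. The eigendecomposition route also tolerates zero eigenvalues.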

Arguments:
  • potential (eqx.Module): potential function $V$
  • dissipation (eqx.Module): dissipation matrix $M$
  • conservation (eqx.Module): conservation matrix $W$
shared: equinox.nn._shared.Shared
potential
conservation
diffusion_func
dissipation: equinox._module.Module

Dissipation matrix wrapper

Returns:

eqx.Module: dissipation matrix module

def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:

Drift function

Arguments:
  • t (ArrayLike): time
  • x (ArrayLike): state
  • args (ArrayLike): additional arguments or parameters, the first element is the temperature
Returns:

Array: drift vector field
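This drift adds the temperature-scaled matrix divergence to the `OnsagerNet` drift. The following standalone sketch writes out $-[M(x)+W(x)]\nabla V(x) + \epsilon \nabla \cdot M(x)$ in JAX. It assumes hypothetical toy components: a quadratic potential without $u$-dependence, a diagonal state-dependent $M$, and $W = 0$. It mirrors the structure of `OnsagerNetFD.drift` but is not the library method.

```python
import jax
import jax.numpy as jnp


def fd_drift(x, temperature, potential, dissipation, conservation):
    """Sketch of -[M(x) + W(x)] grad V(x) + temperature * div M(x)."""
    grad_V = jax.grad(potential)(x)
    base = -(dissipation(x) + conservation(x)) @ grad_V  # OnsagerNet part
    div_M = jnp.trace(jax.jacfwd(dissipation)(x), axis1=1, axis2=2)
    return base + temperature * div_M  # divergence correction


# Hypothetical toy components (not part of onsagernet)
def potential(x):
    return 0.5 * jnp.sum(x**2)  # grad V = x


def dissipation(x):
    return jnp.diag(1.0 + x**2)  # div M = 2 x


def conservation(x):
    return jnp.zeros((x.shape[0], x.shape[0]))


x = jnp.array([1.0, 2.0])
print(fd_drift(x, 0.5, potential, dissipation, conservation))
# -M x + 0.5 * 2 x = (-2 + 1, -10 + 2) = (-1, -8)
```

If $M$ does not depend on $x$, the divergence term vanishes and the drift reduces to the `OnsagerNet` drift.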

def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:

Diffusion function

Arguments:
  • t (ArrayLike): time
  • x (ArrayLike): state
  • args (ArrayLike): additional arguments or parameters, the first element is the temperature
Returns:

Array: diffusion matrix
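This diffusion ties the noise to the dissipation. With $G(x) = \sqrt{2\epsilon}\, L(x)$ and $L L^\top = M$, the noise covariance is $G G^\top = 2\epsilon M(x)$, which is the fluctuation-dissipation relation referred to in the class docstring. The following is a minimal sketch that mirrors `OnsagerNetFD.diffusion`, assuming a hypothetical diagonal `dissipation` function:

```python
import jax.numpy as jnp


def fd_diffusion(x, args, dissipation):
    """Sketch of sqrt(2 * eps) * cholesky(M(x)), with eps = args[0]."""
    temperature = args[0]
    L = jnp.linalg.cholesky(dissipation(x))  # L @ L.T == M(x)
    return jnp.sqrt(2.0 * temperature) * L


def dissipation(x):
    # Hypothetical positive definite M(x) (not part of onsagernet)
    return jnp.diag(1.0 + x**2)


x = jnp.array([1.0, 2.0])
args = jnp.array([0.25])  # temperature eps = 0.25
G = fd_diffusion(x, args, dissipation)
print(G @ G.T)  # equals 2 * 0.25 * M(x) = diag(1.0, 2.5)
```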