onsagernet.dynamics
Dynamical models
This module defines the dynamical models used in the OnsagerNet and related architectures. Note that the architectures of the individual model components are not defined here; this module only handles the logic for assembling those components into a model.
The following is an example of how to use the `OnsagerNet` model to create the stochastic OnsagerNet dynamics
$$ dX(t) = - \left[ M(X(t)) + W(X(t)) \right] \nabla V(X(t), u(t)) \, dt + \sqrt{\epsilon} \, \sigma(X(t), u(t)) \, dW(t), \qquad X(t) \in \mathbb{R}^d, \quad u(t) \in \mathbb{R}^m. $$
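```python
import equinox as eqx
from onsagernet.dynamics import OnsagerNet

# The call signatures below follow how OnsagerNet invokes its components:
# potential(x, args), dissipation(x), conservation(x), diffusion(x, args).


class MyPotential(eqx.Module):
    """Implement your V function here

    This should be a function (d + m) -> scalar (it is differentiated with jax.grad)
    """
    def __call__(self, x, args):
        raise NotImplementedError


class MyDissipation(eqx.Module):
    """Implement your M function here

    This should be a function (d) -> (d, d)
    """
    def __call__(self, x):
        raise NotImplementedError


class MyConservation(eqx.Module):
    """Implement your W function here

    This should be a function (d) -> (d, d)
    """
    def __call__(self, x):
        raise NotImplementedError


class MyDiffusion(eqx.Module):
    """Implement your sigma function here

    This should be a function (d + m) -> (d, d)
    """
    def __call__(self, x, args):
        raise NotImplementedError


potential = MyPotential()
dissipation = MyDissipation()
conservation = MyConservation()
diffusion = MyDiffusion()

sde = OnsagerNet(
    potential=potential,
    dissipation=dissipation,
    conservation=conservation,
    diffusion=diffusion,
)
```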
The `sde` instance can then be used to simulate the dynamics of the system or to perform training.
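As a rough illustration of what simulating the dynamics involves, the following sketch discretizes the stochastic OnsagerNet SDE with the Euler-Maruyama scheme. It uses plain JAX functions in place of `eqx.Module` components; the component functions (`potential`, `dissipation`, `conservation`, `diffusion_func`) and the `euler_maruyama` helper are hypothetical. The `drift` and `diffusion` helpers mirror the formulas in `OnsagerNet.drift` and `OnsagerNet.diffusion`, including the convention that `args[0]` is the temperature $\epsilon$. The choice of integrator is an assumption; this page does not say which solver the library uses.

```python
import jax
import jax.numpy as jnp

# Hypothetical stand-ins for the four OnsagerNet components (d = 2, m = 1).


def potential(x, args):
    # V(x, u): a quadratic well; must return a scalar for jax.grad
    return 0.5 * jnp.sum(x**2)


def dissipation(x):
    # M(x): symmetric positive definite
    return jnp.eye(2)


def conservation(x):
    # W(x): anti-symmetric
    return jnp.array([[0.0, 1.0], [-1.0, 0.0]])


def diffusion_func(x, args):
    # sigma(x, u): square diffusion matrix
    return jnp.eye(2)


def drift(t, x, args):
    # -(M(x) + W(x)) grad_x V(x, u), as in OnsagerNet.drift
    dvdx = jax.grad(potential, argnums=0)(x, args)
    return -(dissipation(x) + conservation(x)) @ dvdx


def diffusion(t, x, args):
    # sqrt(epsilon) * sigma(x, u), with epsilon = args[0]
    return jnp.sqrt(args[0]) * diffusion_func(x, args)


def euler_maruyama(key, x0, args, dt, n_steps):
    """X_{n+1} = X_n + f dt + g sqrt(dt) xi_n, with xi_n ~ N(0, I)."""

    def step(x, inputs):
        key_n, t = inputs
        g = diffusion(t, x, args)
        xi = jax.random.normal(key_n, (g.shape[1],))
        x_next = x + drift(t, x, args) * dt + (g @ xi) * jnp.sqrt(dt)
        return x_next, x_next

    keys = jax.random.split(key, n_steps)
    ts = dt * jnp.arange(n_steps)
    _, path = jax.lax.scan(step, x0, (keys, ts))
    return path


x0 = jnp.array([1.0, 0.0])
args = jnp.array([0.01])  # u(t) = (epsilon,)
path = euler_maruyama(jax.random.PRNGKey(0), x0, args, dt=0.01, n_steps=500)
```

Here `jax.lax.scan` keeps the time loop inside a single traced computation, which is the usual JAX idiom for fixed-step integrators.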
- Some simple definitions of the potential, dissipation, conservation, and diffusion functions are provided in `onsagernet.models`.
- The `ReducedSDE` class combines an `SDE` component with a dimensionality reduction component consisting of an `onsagernet.transformations.Encoder` and an `onsagernet.transformations.Decoder`.
- Standard training routines are provided in `onsagernet.trainers`.
```python
import jax
import jax.numpy as jnp
import equinox as eqx
from abc import abstractmethod

# ------------------------- Typing imports ------------------------- #

from typing import Callable
from jax.typing import ArrayLike
from jax import Array
from .transformations import Encoder, Decoder

DynamicCallable = Callable[[ArrayLike, ArrayLike, ArrayLike], Array]
```
```python
class SDE(eqx.Module):
    """Base class for stochastic differential equation models."""

    @abstractmethod
    def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Drift function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters

        Returns:
            Array: drift vector field
        """
        pass

    @abstractmethod
    def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Diffusion function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters

        Returns:
            Array: diffusion matrix of size (state_dim, bm_dim)
        """
        pass
```
Base class for stochastic differential equation models.
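The `SDE` base class only fixes an interface: a `drift(t, x, args)` method returning a vector field and a `diffusion(t, x, args)` method returning a matrix of size `(state_dim, bm_dim)`, which need not be square. A minimal sketch of that contract, using `abc.ABC` and plain JAX in place of `eqx.Module` (so `ToySDE` is a stand-in, not the library class); the `NoisyOscillator` model and its parameters are hypothetical.

```python
from abc import ABC, abstractmethod

import jax.numpy as jnp


class ToySDE(ABC):
    """Stand-in for the SDE interface (hypothetical, not onsagernet.dynamics.SDE)."""

    @abstractmethod
    def drift(self, t, x, args):
        """Return the drift vector field, shape (state_dim,)."""

    @abstractmethod
    def diffusion(self, t, x, args):
        """Return the diffusion matrix, shape (state_dim, bm_dim)."""


class NoisyOscillator(ToySDE):
    """Hypothetical damped oscillator with state x = (position, velocity).

    args = (damping, noise_scale); one Brownian motion drives the velocity only.
    """

    def drift(self, t, x, args):
        damping = args[0]
        return jnp.array([x[1], -x[0] - damping * x[1]])

    def diffusion(self, t, x, args):
        noise_scale = args[1]
        # (state_dim, bm_dim) = (2, 1): noise enters only the velocity equation
        return jnp.array([[0.0], [noise_scale]])


sde = NoisyOscillator()
x = jnp.array([1.0, 2.0])
args = jnp.array([0.5, 0.3])
```

With the library, the same two methods would instead be implemented on a subclass of `SDE`.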
```python
class SDEfromFunc(SDE):
    """SDE model defined by providing drift and diffusion functions."""

    _drift_func: DynamicCallable
    _diffusion_func: DynamicCallable

    def __init__(
        self, drift_func: DynamicCallable, diffusion_func: DynamicCallable
    ) -> None:
        r"""SDE model defined by providing drift and diffusion functions.

        This implements the dynamics
        $$
            dX(t) = f(t, X(t), u(t)) dt + g(t, X(t), u(t)) dW(t)
        $$
        where $f$ is the drift function and $g$ is the diffusion function.
        The `args` argument (which represents $u(t)$)
        is used to pass additional parameters to the drift and diffusion functions.

        Args:
            drift_func (Callable[[ArrayLike, ArrayLike, ArrayLike], Array]): provided drift function
            diffusion_func (Callable[[ArrayLike, ArrayLike, ArrayLike], Array]): provided diffusion function
        """

        self._drift_func = drift_func
        self._diffusion_func = diffusion_func

    def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Drift function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters

        Returns:
            Array: drift vector field
        """
        return self._drift_func(t, x, args)

    def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Diffusion function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters

        Returns:
            Array: diffusion matrix
        """
        return self._diffusion_func(t, x, args)
```
SDE model defined by providing drift and diffusion functions.
This implements the dynamics
$$
dX(t) = f(t, X(t), u(t)) dt + g(t, X(t), u(t)) dW(t)
$$
where $f$ is the drift function and $g$ is the diffusion function.
The `args` argument (which represents $u(t)$) is used to pass additional parameters to the drift and diffusion functions.
Arguments:
- drift_func (Callable[[ArrayLike, ArrayLike, ArrayLike], Array]): provided drift function
- diffusion_func (Callable[[ArrayLike, ArrayLike, ArrayLike], Array]): provided diffusion function
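For example, an Ornstein-Uhlenbeck process $dX = -\theta X \, dt + s \, dW$ can be written as a pair of functions with the `(t, x, args)` signature of `DynamicCallable`, with $u = (\theta, s)$ passed through `args`. This model is a hypothetical choice for illustration. The sketch only builds the two functions (the commented line shows how they would be handed to `SDEfromFunc`, which is not imported here) and checks them against the stationary covariance $\Sigma = \frac{s^2}{2\theta} I$, which solves the Lyapunov equation $A\Sigma + \Sigma A^\top + g g^\top = 0$ with $A = -\theta I$.

```python
import jax.numpy as jnp


def ou_drift(t, x, args):
    # f(t, x, u) = -theta * x, with args = u = (theta, s)
    theta = args[0]
    return -theta * x


def ou_diffusion(t, x, args):
    # g(t, x, u) = s * I: a square (d, d) diffusion matrix
    s = args[1]
    return s * jnp.eye(x.shape[0])


def lyapunov_residual(cov, args):
    # A @ cov + cov @ A.T + g @ g.T for the linear drift A = -theta * I
    d = cov.shape[0]
    A = -args[0] * jnp.eye(d)
    g = ou_diffusion(0.0, jnp.zeros(d), args)
    return A @ cov + cov @ A.T + g @ g.T


# With onsagernet installed, the pair would be passed as
# SDEfromFunc(drift_func=ou_drift, diffusion_func=ou_diffusion)
args = jnp.array([2.0, 0.5])  # theta = 2.0, s = 0.5
x = jnp.array([1.0, -3.0])
stationary_cov = args[1] ** 2 / (2.0 * args[0]) * jnp.eye(2)
```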
```python
class ReducedSDE(eqx.Module):
    """SDE model with encoder and decoder with dimensionality reduction or closure modelling."""

    encoder: Encoder
    decoder: Decoder
    sde: SDE

    def __init__(self, encoder: Encoder, decoder: Decoder, sde: SDE) -> None:
        """SDE model with encoder and decoder with dimensionality reduction or closure modelling.

        The `sde` attribute can be any model of the [SDE](#SDE) class or its sub-classes.
        The `encoder` must be a (sub-)class of `onsagernet.transformations.Encoder` and
        the `decoder` must be a (sub-)class of `onsagernet.transformations.Decoder`.

        Args:
            encoder (Encoder): The encoder function mapping the microscopic state to the reduced state
            decoder (Decoder): The decoder function mapping the reduced state to the microscopic state
            sde (SDE): The stochastic dynamics for the reduced state
        """
        self.encoder = encoder
        self.decoder = decoder
        self.sde = sde
```
SDE model with encoder and decoder with dimensionality reduction or closure modelling.
The `sde` attribute can be any model of the `SDE` class or its subclasses. The `encoder` must be an instance of `onsagernet.transformations.Encoder` (or a subclass), and the `decoder` must be an instance of `onsagernet.transformations.Decoder` (or a subclass).
Arguments:
- encoder (Encoder): The encoder function mapping the microscopic state to the reduced state
- decoder (Decoder): The decoder function mapping the reduced state to the microscopic state
- sde (SDE): The stochastic dynamics for the reduced state
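The composition that `ReducedSDE` bundles can be pictured as: encode the microscopic state, evolve the reduced state with the SDE, and decode the result. The sketch illustrates one such step with Euler-Maruyama. The linear projection `P`, the reduced drift and diffusion, and `reduced_step` are all hypothetical, and the encode-step-decode order is an assumption, since this page does not describe how `ReducedSDE` is used during simulation or training.

```python
import jax
import jax.numpy as jnp

# Hypothetical linear encoder/decoder: keep the first two of three microscopic
# coordinates (a stand-in for onsagernet.transformations.Encoder/Decoder).
P = jnp.array([[1.0, 0.0, 0.0],
               [0.0, 1.0, 0.0]])  # shape (reduced_dim, micro_dim)


def encoder(x_micro):
    return P @ x_micro


def decoder(z):
    return P.T @ z


def reduced_drift(t, z, args):
    # Hypothetical reduced dynamics: linear relaxation toward the origin
    return -z


def reduced_diffusion(t, z, args):
    # args[0] plays the role of the temperature
    return jnp.sqrt(args[0]) * jnp.eye(z.shape[0])


def reduced_step(key, x_micro, t, dt, args):
    """Encode, take one Euler-Maruyama step in reduced space, then decode."""
    z = encoder(x_micro)
    g = reduced_diffusion(t, z, args)
    xi = jax.random.normal(key, (g.shape[1],))
    z_next = z + reduced_drift(t, z, args) * dt + (g @ xi) * jnp.sqrt(dt)
    return decoder(z_next)


x_micro = jnp.array([1.0, 2.0, 3.0])
x_next = reduced_step(jax.random.PRNGKey(0), x_micro, 0.0, 0.1, jnp.array([0.0]))
```

Because this `P` simply discards the third coordinate, the decoded state has a zero third entry; a learned decoder would instead reconstruct the microscopic state from the reduced one.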
```python
class OnsagerNet(SDE):
    potential: eqx.Module
    dissipation: eqx.Module
    conservation: eqx.Module
    diffusion_func: eqx.Module

    def __init__(
        self,
        potential: eqx.Module,
        dissipation: eqx.Module,
        conservation: eqx.Module,
        diffusion: eqx.Module,
    ) -> None:
        r"""Stochastic OnsagerNet model.

        Let $X(t) \in \mathbb{R}^d$. The Stochastic OnsagerNet model is defined by the SDE
        $$
            dX(t) = -
            \left[
                M(X(t)) + W(x(t))
            \right] \nabla V(x(t), u(t)) dt
            + \sqrt{\epsilon} \sigma(x(t), u(t)) dW(t)
        $$
        where

        - $M : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the dissipation matrix, which is symmetric positive semi-definite for all $x$
        - $W : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the conservation matrix, which is anti-symmetric for all $x$
        - $V : \mathbb{R}^{d} \to \mathbb{R}$ is the potential function
        - $\sigma: \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the (square) diffusion matrix
        - $u(t)$ are the additional parameters for the potential and diffusion functions, and note that **the first dimension of $u(t)$ is the temperature $\epsilon$**

        Args:
            potential (eqx.Module): potential function $V$
            dissipation (eqx.Module): dissipation matrix $M$
            conservation (eqx.Module): conservation matrix $W$
            diffusion (eqx.Module): diffusion matrix $\sigma$
        """
        self.potential = potential
        self.dissipation = dissipation
        self.conservation = conservation
        self.diffusion_func = diffusion

    def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Drift function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters, the first element is the temperature

        Returns:
            Array: drift vector field
        """
        dvdx = jax.grad(self.potential, argnums=0)(x, args)
        return -(self.dissipation(x) + self.conservation(x)) @ dvdx

    def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Diffusion function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters, the first element is the temperature

        Returns:
            Array: diffusion matrix
        """
        temperature = args[0]
        return jnp.sqrt(temperature) * self.diffusion_func(x, args)
```
Stochastic OnsagerNet model.
Let $X(t) \in \mathbb{R}^d$. The Stochastic OnsagerNet model is defined by the SDE
$$ dX(t) = - \left[ M(X(t)) + W(X(t)) \right] \nabla V(X(t), u(t)) \, dt + \sqrt{\epsilon} \, \sigma(X(t), u(t)) \, dW(t) $$
where
- $M : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the dissipation matrix, which is symmetric positive semi-definite for all $x$
- $W : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the conservation matrix, which is anti-symmetric for all $x$
- $V : \mathbb{R}^{d} \times \mathbb{R}^{m} \to \mathbb{R}$ is the potential function, which also depends on the parameters $u(t)$
- $\sigma: \mathbb{R}^{d} \times \mathbb{R}^{m} \to \mathbb{R}^{d\times d}$ is the (square) diffusion matrix
- $u(t) \in \mathbb{R}^m$ are the additional parameters for the potential and diffusion functions; note that the first component of $u(t)$ is the temperature $\epsilon$
Arguments:
- potential (eqx.Module): potential function $V$
- dissipation (eqx.Module): dissipation matrix $M$
- conservation (eqx.Module): conservation matrix $W$
- diffusion (eqx.Module): diffusion matrix $\sigma$
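The structural assumptions on $M$ and $W$ are what make the noise-free part of these dynamics dissipative: since $W$ is anti-symmetric, $\nabla V^\top W \nabla V = 0$, so along $\dot{x} = -(M + W)\nabla V$ one gets $\frac{d}{dt} V = -\nabla V^\top M \nabla V \le 0$. The sketch checks this numerically. It builds the constraints in by construction as $M = AA^\top + \alpha I$ and $W = B - B^\top$; this parametrization is an assumption for illustration, not necessarily what `onsagernet.models` does, and the matrices, potential, and helper names are hypothetical. For simplicity the potential ignores $u(t)$.

```python
import jax
import jax.numpy as jnp

d = 3
key_a, key_b = jax.random.split(jax.random.PRNGKey(0))
A = jax.random.normal(key_a, (d, d))
B = jax.random.normal(key_b, (d, d))


def dissipation(x):
    # Symmetric positive definite by construction: A A^T + 1e-3 I
    return A @ A.T + 1e-3 * jnp.eye(d)


def conservation(x):
    # Anti-symmetric by construction: B - B^T
    return B - B.T


def potential(x, args):
    # Hypothetical scalar potential (ignores args for simplicity)
    return 0.5 * jnp.sum(x**2) + 0.25 * jnp.sum(x**4)


def drift(t, x, args):
    # Same assembly as OnsagerNet.drift: -(M(x) + W(x)) grad_x V(x, u)
    dvdx = jax.grad(potential, argnums=0)(x, args)
    return -(dissipation(x) + conservation(x)) @ dvdx


def energy_rate(x, args):
    """dV/dt along the noise-free flow, i.e. grad V . drift."""
    dvdx = jax.grad(potential, argnums=0)(x, args)
    return dvdx @ drift(0.0, x, args)


x = jnp.array([0.3, -1.2, 0.7])
args = jnp.array([0.1])
dvdx = jax.grad(potential)(x, args)
```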
```python
class OnsagerNetFD(OnsagerNet):
    shared: eqx.nn.Shared

    def __init__(
        self, potential: eqx.Module, dissipation: eqx.Module, conservation: eqx.Module
    ) -> None:
        r"""Stochastic OnsagerNet model satisfying a fluctuation-dissipation relation.

        This is a modified version of the Stochastic OnsagerNet model.
        Let $X(t) \in \mathbb{R}^d$. This model is defined by the SDE
        $$
            dX(t) = -
            \left[
                M(X(t)) + W(x(t))
            \right] \nabla V(x(t), u(t)) dt
            + \sqrt{2 \epsilon} [M(x(t))]^\frac{1}{2}dW(t)
        $$
        where

        - $M : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the dissipation matrix,
          which is symmetric positive semi-definite for all $x$
        - $W : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the conservation matrix,
          which is anti-symmetric for all $x$
        - $V : \mathbb{R}^{d} \to \mathbb{R}$ is the potential function
        - $u(t)$ are the additional parameters for the potential and diffusion functions,
          and note that **the first dimension of $u(t)$ is the temperature $\epsilon$**

        Notice that the main difference with `OnsagerNet` is that the
        diffusion matrix is now given by a (positive semi-definite) square root of the dissipation matrix.

        Args:
            potential (eqx.Module): potential function $V$
            dissipation (eqx.Module): dissipation matrix $M$
            conservation (eqx.Module): conservation matrix $W$
        """
        self.potential = potential
        self.conservation = conservation
        self.diffusion_func = None

        # Share the dissipation module
        dissipation_drift = dissipation
        dissipation_diffusion = dissipation
        where = lambda shared_layers: shared_layers[0]
        get = lambda shared_layers: shared_layers[1]
        self.shared = eqx.nn.Shared(
            (dissipation_drift, dissipation_diffusion), where, get
        )

    @property
    def dissipation(self) -> eqx.Module:
        """Dissipation matrix wrapper

        Returns:
            eqx.Module: dissipation matrix module
        """
        return self.shared()[0]

    def _matrix_div(self, M: eqx.Module, x: ArrayLike) -> Array:
        r"""Computes the matrix divergence of a matrix function $M(x)$.

        This is defined in component form as
        $$
            [\nabla \cdot M(x)]_i = \sum_j \frac{\partial M_{ij}}{\partial x_j}.
        $$

        Args:
            M (eqx.Module): matrix function
            x (ArrayLike): state

        Returns:
            Array: \nabla \cdot M(x)
        """
        # jac_M_x[i, j, k] = dM_ij / dx_k, so the trace over (j, k) gives sum_j dM_ij / dx_j
        jac_M_x = jax.jacfwd(M)(x)
        return jnp.trace(jac_M_x, axis1=1, axis2=2)

    def drift(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Drift function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters, the first element is the temperature

        Returns:
            Array: drift vector field
        """
        temperature = args[0]
        dissipation = self.shared()[0]
        f1 = super().drift(t, x, args)
        # Divergence correction: temperature * (div M)(x)
        f2 = temperature * self._matrix_div(dissipation, x)
        return f1 + f2

    def diffusion(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        """Diffusion function

        Args:
            t (ArrayLike): time
            x (ArrayLike): state
            args (ArrayLike): additional arguments or parameters, the first element is the temperature

        Returns:
            Array: diffusion matrix
        """
        temperature = args[0]
        dissipation = self.shared()[1]
        M_x = dissipation(x)
        # Cholesky factor L (lower triangular, L @ L.T == M_x); requires M_x to be
        # positive definite and gives the same noise covariance as the symmetric root
        sqrt_M_x = jnp.linalg.cholesky(M_x)
        return jnp.sqrt(2.0 * temperature) * sqrt_M_x
```
Stochastic OnsagerNet model satisfying a fluctuation-dissipation relation.
This is a modified version of the Stochastic OnsagerNet model. Let $X(t) \in \mathbb{R}^d$. This model is defined by the SDE
$$ dX(t) = - \left[ M(X(t)) + W(X(t)) \right] \nabla V(X(t), u(t)) \, dt + \sqrt{2 \epsilon} \, [M(X(t))]^{\frac{1}{2}} \, dW(t) $$
where
- $M : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the dissipation matrix, which is symmetric positive semi-definite for all $x$
- $W : \mathbb{R}^{d} \to \mathbb{R}^{d\times d}$ is the conservation matrix, which is anti-symmetric for all $x$
- $V : \mathbb{R}^{d} \times \mathbb{R}^{m} \to \mathbb{R}$ is the potential function
- $u(t)$ are the additional parameters for the potential and diffusion functions; note that the first component of $u(t)$ is the temperature $\epsilon$
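Notice that the main difference from `OnsagerNet` is that the diffusion matrix is now given by a square root of the dissipation matrix. In the implementation this square root is the Cholesky factor $L(x)$, a lower-triangular matrix with $L(x) L(x)^\top = M(x)$; it yields the same noise covariance $2\epsilon M(x)$ as the symmetric square root, but requires $M(x)$ to be positive definite. In addition, the `drift` method adds the correction term $\epsilon \, \nabla \cdot M(X(t))$ to the drift in the SDE above, where $[\nabla \cdot M(x)]_i = \sum_j \partial M_{ij} / \partial x_j$.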
Arguments:
- potential (eqx.Module): potential function $V$
- dissipation (eqx.Module): dissipation matrix $M$
- conservation (eqx.Module): conservation matrix $W$
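The two ingredients that distinguish `OnsagerNetFD` can be checked in isolation: the matrix divergence computed with `jax.jacfwd` followed by a trace over the last two axes (as in `_matrix_div`), and the Cholesky-based diffusion whose covariance $\sigma\sigma^\top$ equals $2\epsilon M(x)$. The state-dependent $M(x)$ here is a hypothetical positive definite example with known divergence $[\nabla \cdot M(x)]_0 = 2x_0 + 0.5$ and $[\nabla \cdot M(x)]_1 = 2x_1$. The functions are plain-JAX stand-ins rather than the library's methods, and they take $W = 0$ and $V(x) = \frac{1}{2}\|x\|^2$ for simplicity.

```python
import jax
import jax.numpy as jnp


def dissipation(x):
    # Hypothetical symmetric positive definite M(x) for d = 2
    return jnp.array([[1.0 + x[0] ** 2, 0.5 * x[1]],
                      [0.5 * x[1], 2.0 + x[1] ** 2]])


def matrix_div(M, x):
    # jacfwd(M)(x)[i, j, k] = dM_ij / dx_k; the trace over (j, k) gives sum_j dM_ij / dx_j
    return jnp.trace(jax.jacfwd(M)(x), axis1=1, axis2=2)


def potential(x, args):
    return 0.5 * jnp.sum(x**2)


def fd_drift(t, x, args):
    # -(M + W) grad V + epsilon * div M, with W = 0 in this sketch
    temperature = args[0]
    dvdx = jax.grad(potential, argnums=0)(x, args)
    return -dissipation(x) @ dvdx + temperature * matrix_div(dissipation, x)


def fd_diffusion(t, x, args):
    # sqrt(2 epsilon) L(x) with L L^T = M(x) (Cholesky, needs M(x) positive definite)
    temperature = args[0]
    return jnp.sqrt(2.0 * temperature) * jnp.linalg.cholesky(dissipation(x))


x = jnp.array([1.0, 2.0])
args = jnp.array([0.5])  # epsilon = 0.5
```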
```python
@property
def dissipation(self) -> eqx.Module:
    """Dissipation matrix wrapper

    Returns:
        eqx.Module: dissipation matrix module
    """
    return self.shared()[0]
```
Dissipation matrix wrapper
Returns:
eqx.Module: dissipation matrix module