onsagernet.transformations
Dimensionality transformations
This module contains classes for dimensionality transformations, such as encoders and decoders. They are intended for settings where reduced dynamics are sought but only microscopic data are available.
The main classes are Encoder and Decoder, which are abstract classes that should be implemented by the user. The ClosureEncoder and ClosureDecoder classes are concrete implementations of the Encoder and Decoder classes, respectively. They map the data into a reduced space (encoding) and back to the microscopic space (decoding), and the reduced dynamics are trained in the reduced space. In this case the reduced space consists of both known macroscopic coordinates and learned closure coordinates.
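The layout of the reduced space can be illustrated with a minimal sketch that uses plain JAX functions instead of the library's modules. The maps phi_star, phi_hat and phi_hat_inv are hypothetical toys chosen so that the round trip is exact: the macroscopic coordinate is the mean of x, and the closure coordinates are simply a copy of x.

```python
import jax.numpy as jnp


# Hypothetical toy maps (not part of onsagernet)
def phi_star(x):
    """Known macroscopic coordinate: the mean of x, kept as a length-1 array."""
    return jnp.mean(x, keepdims=True)


def phi_hat(x):
    """Closure coordinates: a copy of x, so no information is lost."""
    return x


def phi_hat_inv(z_closure):
    """Inverse of phi_hat."""
    return z_closure


def encode(x):
    # z = [phi_star(x), phi_hat(x)]: macroscopic coordinates come first
    return jnp.concatenate([phi_star(x), phi_hat(x)])


def decode(z, macroscopic_dim):
    # Only the closure part z[macroscopic_dim:] is used for reconstruction
    return phi_hat_inv(z[macroscopic_dim:])


x = jnp.array([1.0, 2.0, 3.0])
z = encode(x)                           # [2., 1., 2., 3.]
x_rec = decode(z, macroscopic_dim=1)    # [1., 2., 3.]
```

In practice the closure coordinates would be lower-dimensional than x; the identity closure here only makes the ordering (macroscopic first, closure second) easy to check.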
Encoder and Decoder objects are used, for example, in onsagernet.dynamics.ReducedSDE for closure modelling:
```python
from onsagernet.transformations import ClosureEncoder, ClosureDecoder
from onsagernet.dynamics import ReducedSDE, OnsagerNet

encoder = ClosureEncoder(...)
decoder = ClosureDecoder(...)
sde = OnsagerNet(...)

reduced_sde = ReducedSDE(encoder, decoder, sde)  # can be used to train, predict, etc.
```
As in onsagernet.dynamics, only the model assembly logic is provided here. Some example implementations of the model architecture are provided in the onsagernet.models module.
```python
import equinox as eqx
import jax.numpy as jnp
from abc import abstractmethod

# ------------------------- Typing imports ------------------------- #
from jax import Array
from jax.typing import ArrayLike
```
```python
class Encoder(eqx.Module):
    """The base class for encoders."""

    @abstractmethod
    def __call__(self, x: ArrayLike) -> Array:
        pass
```
The base class for encoders.
```python
class Decoder(eqx.Module):
    """The base class for decoders."""

    @abstractmethod
    def __call__(self, z: ArrayLike) -> Array:
        pass
```
The base class for decoders.
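Encoder and Decoder only fix the call signature, so a user implementation essentially provides __call__. The sketch below avoids depending on equinox: it uses abc.ABC stand-ins with the same abstract __call__, and a hypothetical LinearEncoder/LinearDecoder pair built from a matrix W and its pseudo-inverse. With the real base classes, which subclass eqx.Module, W would instead be declared as an annotated class field, as the source does for closure_transform.

```python
from abc import ABC, abstractmethod

import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


# Stand-ins for the abstract interface (the real classes subclass eqx.Module)
class Encoder(ABC):
    @abstractmethod
    def __call__(self, x: ArrayLike) -> Array: ...


class Decoder(ABC):
    @abstractmethod
    def __call__(self, z: ArrayLike) -> Array: ...


class LinearEncoder(Encoder):
    """Hypothetical encoder x -> W x."""

    def __init__(self, W: ArrayLike) -> None:
        self.W = jnp.asarray(W)  # shape (reduced_dim, micro_dim)

    def __call__(self, x: ArrayLike) -> Array:
        return self.W @ jnp.asarray(x)


class LinearDecoder(Decoder):
    """Hypothetical decoder z -> pinv(W) z."""

    def __init__(self, W: ArrayLike) -> None:
        self.W_pinv = jnp.linalg.pinv(jnp.asarray(W))  # shape (micro_dim, reduced_dim)

    def __call__(self, z: ArrayLike) -> Array:
        return self.W_pinv @ jnp.asarray(z)


# Keep the first two of three microscopic coordinates
W = jnp.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
encoder, decoder = LinearEncoder(W), LinearDecoder(W)

x = jnp.array([1.0, 2.0, 0.0])
z = encoder(x)      # [1., 2.]
x_rec = decoder(z)  # [1., 2., 0.]: exact here because x[2] == 0
```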
```python
class EncoderfromFunc(Encoder):
    """Encoder constructed from a given closure transform."""

    closure_transform: eqx.Module

    def __init__(self, closure_transform: eqx.Module) -> None:
        r"""Encoder constructed from a given closure transform.

        Takes a given closure transformation $x\mapsto z$
        to define the encoder.

        Args:
            closure_transform (eqx.Module): a given transformation $x\mapsto z$
        """
        self.closure_transform = closure_transform

    def __call__(self, x: ArrayLike) -> Array:
        return self.closure_transform(x)
```
Encoder constructed from a given closure transform.
EncoderfromFunc(closure_transform: eqx.Module)
Encoder constructed from a given closure transform.
Takes a given closure transformation $x\mapsto z$ to define the encoder.
Arguments:
- closure_transform (eqx.Module): a given transformation $x\mapsto z$
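EncoderfromFunc simply delegates to the callable it wraps. A minimal stand-in with the same delegation logic is sketched below; EncoderFromFuncSketch and the feature map closure_transform are hypothetical. Because the wrapped map acts on one state at a time, jax.vmap is one way to encode a batch.

```python
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


class EncoderFromFuncSketch:
    """Hypothetical stand-in for EncoderfromFunc: delegates to a callable x -> z."""

    def __init__(self, closure_transform) -> None:
        self.closure_transform = closure_transform

    def __call__(self, x: ArrayLike) -> Array:
        return self.closure_transform(x)


def closure_transform(x: ArrayLike) -> Array:
    """Hypothetical closure map R^n -> R^2 (two summary features of x)."""
    x = jnp.asarray(x)
    return jnp.stack([jnp.sum(jnp.sin(x)), jnp.sum(jnp.cos(x))])


encoder = EncoderFromFuncSketch(closure_transform)

x = jnp.zeros(5)
z = encoder(x)  # [0., 5.] since sin(0) = 0 and cos(0) = 1

X = jnp.zeros((10, 5))    # a batch of 10 states
Z = jax.vmap(encoder)(X)  # shape (10, 2)
```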
```python
class DecoderfromFunc(Decoder):
    """Decoder constructed from a given inverse closure transform."""

    inverse_closure_transform: eqx.Module

    def __init__(self, inverse_closure_transform: eqx.Module) -> None:
        r"""Decoder constructed from a given inverse closure transform.

        Takes a given inverse closure transformation $z\mapsto x$
        to define the decoder.

        Args:
            inverse_closure_transform (eqx.Module): a given transformation $z\mapsto x$
        """
        self.inverse_closure_transform = inverse_closure_transform

    def __call__(self, z: ArrayLike) -> Array:
        return self.inverse_closure_transform(z)
```
Decoder constructed from a given inverse closure transform.
DecoderfromFunc(inverse_closure_transform: eqx.Module)
Decoder constructed from a given inverse closure transform.
Takes a given inverse closure transformation $z\mapsto x$ to define the decoder.
Arguments:
- inverse_closure_transform (eqx.Module): a given transformation $z\mapsto x$
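The ClosureEncoder description mentions PCA as one source of closure coordinates. A sketch, assuming a linear PCA closure fitted by an SVD of centred data: DecoderFromFuncSketch is a hypothetical stand-in for DecoderfromFunc, and the toy data are constructed to lie exactly on a two-dimensional affine subspace of R^4, so decoding recovers the state up to floating-point round-off.

```python
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


class DecoderFromFuncSketch:
    """Hypothetical stand-in for DecoderfromFunc: delegates to a callable z -> x."""

    def __init__(self, inverse_closure_transform) -> None:
        self.inverse_closure_transform = inverse_closure_transform

    def __call__(self, z: ArrayLike) -> Array:
        return self.inverse_closure_transform(z)


# Toy microscopic data lying exactly on a 2D affine subspace of R^4
key = jax.random.PRNGKey(0)
latent = jax.random.normal(key, (100, 2))
basis_true = jnp.array([[1.0, 0.0, 1.0, 0.0], [0.0, 1.0, 0.0, -1.0]])
X = latent @ basis_true + 3.0  # shape (100, 4)

# PCA via SVD of the centred data; keep the top-2 principal directions
mean = X.mean(axis=0)
_, _, Vt = jnp.linalg.svd(X - mean, full_matrices=False)
V = Vt[:2]  # shape (2, 4), orthonormal rows

closure_transform = lambda x: V @ (x - mean)          # x -> z
inverse_closure_transform = lambda z: V.T @ z + mean  # z -> x

decoder = DecoderFromFuncSketch(inverse_closure_transform)
x = X[0]
x_rec = decoder(closure_transform(x))  # recovers x up to round-off
```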
```python
class ClosureEncoder(EncoderfromFunc):
    """Closure encoder which combines known macroscopic coordinates
    with learned (or PCA) closure coordinates.
    """

    macroscopic_transform: eqx.Module

    def __init__(
        self, macroscopic_transform: eqx.Module, closure_transform: eqx.Module
    ) -> None:
        r"""Closure encoder which combines known macroscopic coordinates
        with learned (or PCA) closure coordinates.

        $$
        x \mapsto z = [\varphi^*(x), \hat\varphi(x)]
        $$

        where $\varphi^*$ is the known macroscopic transformation and
        $\hat\varphi$ is the learned closure transformation.

        Args:
            macroscopic_transform (eqx.Module): the known macroscopic transformation
            closure_transform (eqx.Module): the learned closure transformation
        """
        self.macroscopic_transform = macroscopic_transform
        self.closure_transform = closure_transform

    def __call__(self, x: ArrayLike) -> Array:
        """Combines the macroscopic and closure coordinates.

        Args:
            x (ArrayLike): microscopic state

        Returns:
            Array: reduced state
        """
        macroscopic_coords = self.macroscopic_transform(x)
        closure_coords = self.closure_transform(x)
        reduced_coords = jnp.concatenate([macroscopic_coords, closure_coords])
        return reduced_coords
```
Closure encoder which combines known macroscopic coordinates with learned (or PCA) closure coordinates.
ClosureEncoder(macroscopic_transform: eqx.Module, closure_transform: eqx.Module)
Closure encoder which combines known macroscopic coordinates with learned (or PCA) closure coordinates.
$$ x \mapsto z = [\varphi^*(x), \hat\varphi(x)] $$
where $\varphi^*$ is the known macroscopic transformation and $\hat\varphi$ is the learned closure transformation.
Arguments:
- macroscopic_transform (eqx.Module): the known macroscopic transformation
- closure_transform (eqx.Module): the learned closure transformation
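Because jnp.concatenate joins along axis 0 by default, the encoder as written acts on a single state, producing $z$ of length $m + k$ when $\varphi^*$ returns $m$ entries and $\hat\varphi$ returns $k$. The sketch below mirrors that logic in a hypothetical stand-in, ClosureEncoderSketch, with toy maps (mean and variance as macroscopic coordinates, the first three entries of x as closure coordinates), and uses jax.vmap for a batch.

```python
import jax
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


class ClosureEncoderSketch:
    """Hypothetical stand-in mirroring ClosureEncoder.__call__ above."""

    def __init__(self, macroscopic_transform, closure_transform) -> None:
        self.macroscopic_transform = macroscopic_transform
        self.closure_transform = closure_transform

    def __call__(self, x: ArrayLike) -> Array:
        # Same logic as the source: z = [phi_star(x), phi_hat(x)] along axis 0
        return jnp.concatenate(
            [self.macroscopic_transform(x), self.closure_transform(x)]
        )


# Hypothetical maps: phi_star = (mean, variance) of x; phi_hat = first 3 entries
macroscopic_transform = lambda x: jnp.array([jnp.mean(x), jnp.var(x)])
closure_transform = lambda x: x[:3]

encoder = ClosureEncoderSketch(macroscopic_transform, closure_transform)

x = jnp.arange(6.0)  # one microscopic state in R^6
z = encoder(x)       # [2.5, 35/12, 0., 1., 2.], length m + k = 2 + 3

# The per-sample maps are written for a single state, so batches go through vmap
X = jnp.ones((8, 6))
Z = jax.vmap(encoder)(X)  # shape (8, 5)
```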
```python
class ClosureDecoder(DecoderfromFunc):
    """Decodes from a closure encoder model output."""

    macroscopic_dim: int

    def __init__(
        self, inverse_closure_transform: eqx.Module, macroscopic_dim: int
    ) -> None:
        r"""Closure decoder which extracts the closure coordinates from the reduced state
        and then applies the inverse closure transformation to reconstruct the microscopic state.

        $$
        z[\text{macroscopic_dim}:] \mapsto x
        $$

        It is assumed that the first `macroscopic_dim` coordinates are the known
        macroscopic coordinates and the rest are the learned closure coordinates.

        Args:
            inverse_closure_transform (eqx.Module): transformation from closure coordinates to microscopic state
            macroscopic_dim (int): the dimension of the known macroscopic state
        """
        self.inverse_closure_transform = inverse_closure_transform
        self.macroscopic_dim = macroscopic_dim

    def __call__(self, z: ArrayLike) -> Array:
        """Extracts the closure coordinates and applies the inverse closure transformation.

        Args:
            z (ArrayLike): reduced coordinates

        Returns:
            Array: reconstructed microscopic state
        """
        z_closure = z[self.macroscopic_dim :]
        return self.inverse_closure_transform(z_closure)
```
Decodes from a closure encoder model output.
ClosureDecoder(inverse_closure_transform: eqx.Module, macroscopic_dim: int)
Closure decoder which extracts the closure coordinates from the reduced state and then applies the inverse closure transformation to reconstruct the microscopic state.
$$ z[\text{macroscopic_dim}:] \mapsto x $$
It is assumed that the first macroscopic_dim coordinates are the known macroscopic coordinates and the rest are the learned closure coordinates.
Arguments:
- inverse_closure_transform (eqx.Module): transformation from closure coordinates to microscopic state
- macroscopic_dim (int): the dimension of the known macroscopic state
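Since the decoder slices off the first macroscopic_dim entries, its output depends only on the closure coordinates. The sketch below makes this visible with a hypothetical stand-in, ClosureDecoderSketch, and a toy inverse closure map from R^2 to R^3: two reduced states that differ only in their macroscopic entry decode to the same microscopic state.

```python
import jax.numpy as jnp
from jax import Array
from jax.typing import ArrayLike


class ClosureDecoderSketch:
    """Hypothetical stand-in mirroring ClosureDecoder.__call__ above."""

    def __init__(self, inverse_closure_transform, macroscopic_dim: int) -> None:
        self.inverse_closure_transform = inverse_closure_transform
        self.macroscopic_dim = macroscopic_dim

    def __call__(self, z: ArrayLike) -> Array:
        # Drop the first macroscopic_dim entries and decode the closure part only
        z_closure = jnp.asarray(z)[self.macroscopic_dim :]
        return self.inverse_closure_transform(z_closure)


# Hypothetical inverse closure map R^2 -> R^3
inverse_closure_transform = lambda zc: jnp.array([zc[0], zc[1], zc[0] + zc[1]])

decoder = ClosureDecoderSketch(inverse_closure_transform, macroscopic_dim=1)

z_a = jnp.array([10.0, 1.0, 2.0])  # macroscopic coordinate 10
z_b = jnp.array([-5.0, 1.0, 2.0])  # same closure part, different macroscopic coordinate
x_a, x_b = decoder(z_a), decoder(z_b)  # both [1., 2., 3.]
```

One consequence of this design is that the closure coordinates must carry enough information to reconstruct x on their own; the macroscopic coordinates are not used by the decoder.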