onsagernet.models
Custom equinox modules for OnsagerNet components
This module contains custom equinox modules for the components of the OnsagerNet model, which are used in various examples provided in the repository.
For new applications, it is suggested to try the simple models here first and then build upon them, adapting if necessary to the specific problem at hand.
1""" 2# Custom equinox modules for OnsagerNet components 3 4This module contains custom equinox modules for the components of the OnsagerNet model, 5which are used in various examples provided in the repository. 6 7For new applications, it is suggested to try the simple models here first 8and then build upon them, adapting if necessary to the specific problem at hand. 9 10 11 12 13 14""" 15 16import jax 17import jax.numpy as jnp 18import equinox as eqx 19 20from ._activations import get_activation 21from ._layers import ConstantLayer 22 23# ------------------------- Typing imports ------------------------- # 24 25from jax import Array 26from jax.typing import ArrayLike 27from typing import Callable 28from jax.random import PRNGKey 29 30 31# ------------------------------------------------------------------ # 32# Template models # 33# ------------------------------------------------------------------ # 34 35 36class MLP(eqx.Module): 37 """Multi-layer perceptron.""" 38 39 layers: list[eqx.nn.Linear] 40 activation: Callable[[ArrayLike], Array] 41 42 def __init__( 43 self, key: PRNGKey, dim: int, units: list[int], activation: str 44 ) -> None: 45 r"""Multi-layer perceptron. 46 47 Example: 48 `mlp = MLP(key=jax.random.PRNGKey(0), dim=2, units=[32, 32, 1], activation='tanh')` 49 gives a 2-hidden-layer MLP 50 51 $$2 \to 32 \to 32 \to 1$$ 52 53 with tanh activation. 54 55 Args: 56 key (PRNGKey): random key 57 dim (int): dimension of the input 58 units (list[int]): layer sizes 59 activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`) 60 """ 61 num_layers = len(units) 62 units = [dim] + units 63 keys = jax.random.split(key, num_layers) 64 self.layers = [ 65 eqx.nn.Linear(units[i], units[i + 1], key=key) for i, key in enumerate(keys) 66 ] 67 self.activation = get_activation(activation) 68 69 def __call__(self, x: ArrayLike) -> Array: 70 h = x 71 for layer in self.layers[:-1]: 72 h = self.activation(layer(h)) 73 output = self.layers[-1](h) 74 return output 75 76 77# ------------------------------------------------------------------ # 78# Potential networks # 79# ------------------------------------------------------------------ # 80 81 82class PotentialMLP(MLP): 83 """Potential network based on a multi-layer perceptron.""" 84 85 alpha: float 86 dim: int 87 param_dim: int 88 89 def __init__( 90 self, 91 key: PRNGKey, 92 dim: int, 93 units: list[int], 94 activation: str, 95 alpha: float, 96 param_dim: int = 0, 97 ) -> None: 98 r"""Potential network based on a multi-layer perceptron. 99 100 This implements the potential function 101 $$ 102 V(x, args) = \alpha \|(x, args)\|^2 + \text{MLP}(x, args) 103 $$ 104 where $x$ is the input and $u$ are additional parameters. 105 The constant $\alpha \geq 0$ is a regularisation term, 106 which gives a quadratic growth to ensure that the potential is integrable. 107 We are tacitly assuming that MLP is of sub-quadratic growth, 108 so only choose activation functions that have this property 109 (most activation functions are either bounded or of linear growth). 110 111 Args: 112 key (PRNGKey): random key 113 dim (int): dimension of the input 114 units (list[int]): layer sizes 115 activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`) 116 alpha (float): regulariser 117 param_dim (int, optional): dimensions of the parameters. Defaults to 0. 
118 """ 119 self.dim = dim 120 units = units + [1] 121 self.param_dim = param_dim 122 super().__init__(key, dim + param_dim, units, activation) 123 self.alpha = alpha 124 125 def __call__(self, x: ArrayLike, args: ArrayLike) -> Array: 126 if self.param_dim > 0: 127 x = jnp.concatenate([x, args[1:]], axis=0) 128 output = super().__call__(x) + self.alpha * (x @ x) 129 return jnp.squeeze(output) 130 131 132class PotentialResMLP(MLP): 133 r"""Potential network with a residual connection.""" 134 135 alpha: float 136 gamma_layer: eqx.nn.Linear 137 dim: int 138 param_dim: int 139 140 def __init__( 141 self, 142 key: PRNGKey, 143 dim: int, 144 units: list[int], 145 activation: str, 146 n_pot: int, 147 alpha: float, 148 param_dim: int = 0, 149 ) -> None: 150 r"""Potential network with a residual connection. 151 152 This implements the modified potential function 153 $$ 154 V(x, args) = \alpha \|(x, args)\|^2 155 + \frac{1}{2} 156 \| \text{MLP}(x, args)+ \Gamma (x, args) \|^2 157 $$ 158 where 159 160 - $\phi$ is a MLP of dim + param_dim -> n_pot 161 - $\Gamma$ ia matrix of size [n_pot, dim + para_dim] 162 - $\alpha > 0$ is a scalar regulariser 163 164 Args: 165 key (PRNGKey): random key 166 dim (int): dimension of the input 167 units (list[int]): layer sizes 168 activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`) 169 n_pot (int): size of the MLP part of the potential 170 alpha (float): regulariser 171 param_dim (int, optional): dimension of the parameters. Defaults to 0. 172 """ 173 self.dim = dim 174 self.param_dim = param_dim 175 units = units + [n_pot] 176 mlp_key, gamma_key = jax.random.split(key) 177 super().__init__(mlp_key, dim + param_dim, units, activation) 178 self.alpha = alpha 179 self.gamma_layer = eqx.nn.Linear( 180 dim + param_dim, n_pot, key=gamma_key, use_bias=False 181 ) 182 183 def __call__(self, x: ArrayLike, args: ArrayLike) -> Array: 184 if self.param_dim > 0: 185 x = jnp.concatenate([x, args[1:]], axis=0) 186 output_phi = super().__call__(x) 187 output_gamma = self.gamma_layer(x) 188 output_combined = (output_phi + output_gamma) @ (output_phi + output_gamma) 189 regularisation = self.alpha * (x @ x) 190 return 0.5 * output_combined + regularisation 191 192 193# ------------------------------------------------------------------ # 194# Dissipation networks # 195# ------------------------------------------------------------------ # 196 197 198class DissipationMatrixMLP(MLP): 199 """Dissipation matrix network based on a multi-layer perceptron.""" 200 201 alpha: float 202 is_bounded: bool 203 dim: int 204 205 def __init__( 206 self, 207 key: PRNGKey, 208 dim: int, 209 units: list[int], 210 activation: str, 211 alpha: float, 212 is_bounded: bool = True, 213 ) -> None: 214 r"""Dissipation matrix network based on a multi-layer perceptron. 215 216 The MLP maps $x$ of dimension `dim` to a matrix $L(x)$ of size `dim` x `dim`, 217 and then reshapes it to a `dim` x `dim` matrix. 218 Then, the output matrix is given by 219 $$ 220 M(x) = \alpha I + L(x) L(x)^\top. 221 $$ 222 If `is_bounded` is set to `True`, then the output is element-wise bounded 223 by applying a `jax.nn.tanh` activation to the output matrix $L$. 
224 225 Args: 226 key (PRNGKey): random key 227 dim (int): dimension of the input 228 units (list[int]): layer sizes 229 activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`) 230 alpha (float): regulariser 231 is_bounded (bool, optional): whether to give a element-wise bounded output. Defaults to True. 232 """ 233 self.dim = dim 234 units = units + [dim * dim] 235 super().__init__(key, dim, units, activation) 236 self.alpha = alpha 237 self.is_bounded = is_bounded 238 239 def __call__(self, x: ArrayLike) -> Array: 240 L = super().__call__(x).reshape(self.dim, self.dim) 241 if self.is_bounded: 242 L = jax.nn.tanh(L) 243 return self.alpha * jnp.eye(self.dim) + L @ L.T 244 245 246# ------------------------------------------------------------------ # 247# Conservation networks # 248# ------------------------------------------------------------------ # 249 250 251class ConservationMatrixMLP(MLP): 252 """Conservation matrix network based on a multi-layer perceptron.""" 253 254 is_bounded: bool 255 dim: int 256 257 def __init__( 258 self, 259 key: PRNGKey, 260 dim: int, 261 activation: str, 262 units: list[int], 263 is_bounded: bool = True, 264 ) -> None: 265 r"""Conservation matrix network based on a multi-layer perceptron. 266 267 The MLP maps $x$ of dimension `dim` to a matrix $L(x)$ of size `dim` x `dim`, 268 and then reshapes it to a `dim` x `dim` matrix. 269 Then, the output matrix is given by 270 $$ 271 W(x) = L(x) - L(x)^\top. 272 $$ 273 If `is_bounded` is set to `True`, then the output is element-wise bounded 274 by applying a `jax.nn.tanh` activation to the output matrix $L$. 275 276 Args: 277 key (PRNGKey): random key 278 dim (int): dimension of the input 279 activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`) 280 units (list[int]): layer sizes 281 is_bounded (bool, optional): whether to give a element-wise bounded output. Defaults to True. 282 """ 283 self.dim = dim 284 units = units + [dim * dim] 285 super().__init__(key, dim, units, activation) 286 self.is_bounded = is_bounded 287 288 def __call__(self, x: ArrayLike) -> Array: 289 L = super().__call__(x).reshape(self.dim, self.dim) 290 if self.is_bounded: 291 L = jax.nn.tanh(L) 292 return L - L.T 293 294 295# ------------------------------------------------------------------ # 296# Diffusion networks # 297# ------------------------------------------------------------------ # 298 299 300class DiffusionMLP(MLP): 301 """Diffusion matrix network based on a multi-layer perceptron.""" 302 303 alpha: float 304 dim: int 305 param_dim: int 306 307 def __init__( 308 self, 309 key: PRNGKey, 310 dim: int, 311 units: list[int], 312 activation: str, 313 alpha: float, 314 param_dim: int = 0, 315 ) -> None: 316 r"""Diffusion matrix network based on a multi-layer perceptron. 317 318 This implements the diffusion matrix function 319 $$ 320 \sigma(x, args) = \text{Chol}(\alpha I + \text{MLP}(x, args)) 321 $$ 322 where $\text{Chol}$ is the Cholesky decomposition. 323 Here, MLP maps $(x, args)$ of dimension `dim` + `param_dim` to a matrix of size `dim` x `dim`, 324 325 Args: 326 key (PRNGKey): random key 327 dim (int): dimension of the input 328 units (list[int]): layer sizes 329 activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`) 330 alpha (float): regulariser 331 param_dim (int, optional): dimension of the parameters. Defaults to 0. 
332 """ 333 self.dim = dim 334 self.param_dim = param_dim 335 units = units + [dim * dim] 336 super().__init__(key, dim + param_dim, units, activation) 337 self.alpha = alpha 338 339 def __call__(self, x: ArrayLike, args: ArrayLike) -> Array: 340 if self.param_dim > 0: 341 x = jnp.concatenate([x, args[1:]], axis=0) 342 sigma = super().__call__(x).reshape(self.dim, self.dim) 343 sigma_squared_regularised = self.alpha * jnp.eye(self.dim) + sigma @ sigma.T 344 return jnp.linalg.cholesky(sigma_squared_regularised) 345 346 347class DiffusionDiagonalMLP(MLP): 348 """Diagonal diffusion matrix network based on a multi-layer perceptron.""" 349 350 alpha: float 351 dim: int 352 param_dim: int 353 354 def __init__( 355 self, 356 key: PRNGKey, 357 dim: int, 358 units: list[int], 359 activation: str, 360 alpha: float, 361 param_dim: int = 0, 362 ) -> None: 363 r"""Diagonal diffusion matrix network based on a multi-layer perceptron. 364 365 This implements the diffusion matrix function 366 $$ 367 \sigma(x, args) = \text{diag}(\alpha + \text{MLP}(x, args)^2)^{\frac{1}{2}}. 368 $$ 369 Here, MLP maps $(x, args)$ of dimension `dim` + `param_dim` to a vector of size `dim`. 370 371 Args: 372 key (PRNGKey): random key 373 dim (int): dimension of the input 374 units (list[int]): layer sizes 375 activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`) 376 alpha (float): regulariser 377 param_dim (int, optional): dimension of the parameters. Defaults to 0. 378 """ 379 self.dim = dim 380 self.param_dim = param_dim 381 units = units + [dim] 382 super().__init__(key, dim + param_dim, units, activation) 383 self.alpha = alpha 384 385 def __call__(self, x: ArrayLike, args: ArrayLike) -> Array: 386 if self.param_dim > 0: 387 x = jnp.concatenate([x, args[1:]], axis=0) 388 sigma_diag = super().__call__(x) 389 sigma_diag_regularised = jnp.sqrt(self.alpha + sigma_diag**2) 390 return jnp.diag(sigma_diag_regularised) 391 392 393class DiffusionDiagonalConstant(eqx.Module): 394 """Diagonal diffusion matrix network based on a constant layer.""" 395 396 alpha: float 397 constant_layer: ConstantLayer 398 dim: int 399 400 def __init__(self, key: PRNGKey, dim: int, alpha: float) -> None: 401 r"""Diagonal diffusion matrix network based on a constant layer. 402 403 This implements the diffusion matrix function that is constant 404 $$ 405 \sigma(x, args) = \text{diag}(\alpha + \text{Constant}^2)^{\frac{1}{2}}. 406 $$ 407 where $\text{Constant}$ is a vector of size `dim`. 408 Note that by constant we mean that it does not depend on the input $x$ or the parameters $args$, 409 but it can be trained. 
410 411 Args: 412 key (PRNGKey): random key 413 dim (int): dimension of the input 414 alpha (float): regulariser 415 """ 416 self.dim = dim 417 self.alpha = alpha 418 self.constant_layer = ConstantLayer(dim, key) 419 420 def __call__(self, x: ArrayLike, args: ArrayLike) -> Array: 421 sigma_diag = self.constant_layer() 422 sigma_squared_regularised = jnp.sqrt(self.alpha + sigma_diag**2) 423 return jnp.diag(sigma_squared_regularised) 424 425 426# ------------------------------------------------------------------ # 427# Dimensionality transforms # 428# ------------------------------------------------------------------ # 429 430 431class PCATransform(eqx.Module): 432 """PCA transform.""" 433 434 mean: ArrayLike 435 components: ArrayLike 436 scaling: ArrayLike 437 centre: bool 438 439 def __init__( 440 self, 441 mean: ArrayLike, 442 components: ArrayLike, 443 scaling: ArrayLike, 444 centre: bool = False, 445 ) -> None: 446 r"""PCA transform. 447 448 Transforms the input vector $x$ via 449 $$ 450 x \mapsto \text{components} ( x ) / \sqrt{\text{scaling}}. 451 $$ 452 If `centre` is set to `True`, then the input is first centered by subtracting the `mean`. 453 454 Args: 455 mean (ArrayLike): mean of the data used to fit the PCA 456 components (ArrayLike): PCA components 457 scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance) 458 centre (bool, optional): whether to center the data using `mean`. Defaults to False. 459 """ 460 self.mean = jnp.array(mean) 461 self.components = jnp.array(components) 462 self.scaling = jnp.array(scaling) 463 self.centre = centre 464 465 def __call__(self, x: ArrayLike) -> Array: 466 if self.centre: 467 x = x - self.mean 468 x_reduced = self.components @ x 469 return x_reduced / jnp.sqrt(self.scaling) 470 471 472class InversePCATransform(PCATransform): 473 """Inverse PCA transform.""" 474 475 def __init__( 476 self, mean: ArrayLike, components: ArrayLike, scaling: ArrayLike 477 ) -> None: 478 r"""Inverse PCA transform. 479 480 Transforms the input vector $z$ via 481 $$ 482 z \mapsto \text{components}^\top ( z \sqrt{\text{scaling}} ) 483 $$ 484 If `centre` is set to `True`, `mean` is added to the output. 485 486 Args: 487 mean (ArrayLike): mean of the data used to fit the PCA 488 components (ArrayLike): PCA components 489 scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance) 490 """ 491 super().__init__(mean, components, scaling) 492 493 def __call__(self, z: ArrayLike) -> Array: 494 x_reconstructed = self.components.T @ (jnp.sqrt(self.scaling) * z) 495 if self.centre: 496 x_reconstructed = x_reconstructed + self.mean 497 return x_reconstructed 498 499 500class PCAResNetTransform(PCATransform): 501 mlp: MLP 502 mlp_scale: float 503 mlp_input_scale: float 504 505 def __init__( 506 self, 507 mean: ArrayLike, 508 components: ArrayLike, 509 scaling: ArrayLike, 510 key: PRNGKey, 511 units: list[int], 512 activation: str, 513 mlp_scale: float, 514 mlp_input_scale: float, 515 ) -> None: 516 r"""PCA-ResNet transform. 517 518 This combines the PCA transform with an MLP to give a ResNet-like architecture. 519 $$ 520 x \mapsto \text{PCA}(x) + \text{mlp\_scale} \times \text{MLP}(\text{mlp\_input\_scale} \times x). 521 $$ 522 The input scale adjusts the input (often not order 1) so that the MLP can learn the correct scale. 523 524 Args: 525 mean (ArrayLike): mean of the data used to fit the PCA 526 components (ArrayLike): PCA components 527 scaling (ArrayLike): scaling of the PCA transform (e.g. 
explained variance) 528 key (PRNGKey): random key 529 units (list[int]): layer sizes of the MLP 530 activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`) 531 mlp_scale (float): scale of the MLP output 532 mlp_input_scale (float): scale of the input to the MLP 533 """ 534 super().__init__(mean, components, scaling) 535 units = units + [components.shape[0]] 536 dim = components.shape[1] 537 self.mlp = MLP(key, dim, units, activation) 538 self.mlp_scale = mlp_scale 539 self.mlp_input_scale = mlp_input_scale 540 541 def __call__(self, x: ArrayLike) -> Array: 542 pca_features = super().__call__(x) 543 mlp_features = self.mlp(self.mlp_input_scale * x) 544 return pca_features + self.mlp_scale * mlp_features 545 546 def pca_transform(self, x: ArrayLike) -> Array: 547 """Perform the PCA transform. 548 549 Args: 550 x (ArrayLike): state 551 552 Returns: 553 Array: pca features 554 """ 555 return super().__call__(x) 556 557 558class InversePCAResNetTransform(InversePCATransform): 559 """Inverse PCA-ResNet transform.""" 560 561 mlp: MLP 562 mlp_scale: float 563 564 def __init__( 565 self, 566 mean: ArrayLike, 567 components: ArrayLike, 568 scaling: ArrayLike, 569 key: PRNGKey, 570 units: list[int], 571 activation: str, 572 mlp_scale: float, 573 ) -> None: 574 r"""Inverse PCA-ResNet transform. 575 576 This combines the inverse PCA transform with an MLP to give a ResNet-like architecture. 577 $$ 578 z \mapsto \text{PCA}^{-1}(z) + \text{mlp\_scale} \times \text{MLP}(z). 579 $$ 580 581 Args: 582 mean (ArrayLike): mean of the data used to fit the PCA 583 components (ArrayLike): PCA components 584 scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance) 585 key (PRNGKey): random key 586 units (list[int]): layer sizes of the MLP 587 activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`) 588 mlp_scale (float): scale of the MLP output 589 """ 590 super().__init__(mean, components, scaling) 591 units = units + [components.shape[1]] 592 dim = components.shape[0] 593 self.mlp = MLP(key, dim, units, activation) 594 self.mlp_scale = mlp_scale 595 596 def __call__(self, z: ArrayLike) -> Array: 597 inverse_pca_recon = super().__call__(z) 598 mlp_recon = self.mlp(z) 599 return inverse_pca_recon + self.mlp_scale * mlp_recon 600 601 def inverse_pca_transform(self, z: ArrayLike) -> Array: 602 """Inverse PCA transform. 603 604 Args: 605 z (ArrayLike): reduced state 606 607 Returns: 608 Array: reconstructed state 609 """ 610 return super().__call__(z)
Template models

```python
class MLP(eqx.Module):
    """Multi-layer perceptron."""

    layers: list[eqx.nn.Linear]
    activation: Callable[[ArrayLike], Array]

    def __init__(
        self, key: PRNGKey, dim: int, units: list[int], activation: str
    ) -> None:
        r"""Multi-layer perceptron.

        Example:
            `mlp = MLP(key=jax.random.PRNGKey(0), dim=2, units=[32, 32, 1], activation='tanh')`
            gives a 2-hidden-layer MLP

            $$2 \to 32 \to 32 \to 1$$

            with tanh activation.

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            units (list[int]): layer sizes
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
        """
        num_layers = len(units)
        units = [dim] + units
        keys = jax.random.split(key, num_layers)
        self.layers = [
            eqx.nn.Linear(units[i], units[i + 1], key=key) for i, key in enumerate(keys)
        ]
        self.activation = get_activation(activation)

    def __call__(self, x: ArrayLike) -> Array:
        h = x
        for layer in self.layers[:-1]:
            h = self.activation(layer(h))
        output = self.layers[-1](h)
        return output
```
Multi-layer perceptron.
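A minimal usage sketch, following the constructor example in the docstring, assuming the package is importable as `onsagernet.models` and that `'tanh'` is an available activation. The model acts on a single (unbatched) input, since `eqx.nn.Linear` operates on vectors, so `jax.vmap` handles batches:

```python
import jax
import jax.numpy as jnp
from onsagernet.models import MLP

key = jax.random.PRNGKey(0)
mlp = MLP(key=key, dim=2, units=[32, 32, 1], activation="tanh")

x = jnp.array([0.1, -0.3])        # a single input of dimension 2
y = mlp(x)                        # output of shape (1,)

x_batch = jnp.ones((64, 2))       # batches are mapped over with vmap
y_batch = jax.vmap(mlp)(x_batch)  # output of shape (64, 1)
```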
Potential networks

```python
class PotentialMLP(MLP):
    """Potential network based on a multi-layer perceptron."""

    alpha: float
    dim: int
    param_dim: int

    def __init__(
        self,
        key: PRNGKey,
        dim: int,
        units: list[int],
        activation: str,
        alpha: float,
        param_dim: int = 0,
    ) -> None:
        r"""Potential network based on a multi-layer perceptron.

        This implements the potential function
        $$
        V(x, args) = \alpha \|(x, args)\|^2 + \text{MLP}(x, args)
        $$
        where $x$ is the input and $args$ are additional parameters.
        The constant $\alpha \geq 0$ is a regularisation term,
        which adds quadratic growth to ensure that the potential is integrable.
        We are tacitly assuming that the MLP is of sub-quadratic growth,
        so only choose activation functions that have this property
        (most activation functions are either bounded or of linear growth).

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            units (list[int]): layer sizes
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            alpha (float): regulariser
            param_dim (int, optional): dimension of the parameters. Defaults to 0.
        """
        self.dim = dim
        units = units + [1]
        self.param_dim = param_dim
        super().__init__(key, dim + param_dim, units, activation)
        self.alpha = alpha

    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
        if self.param_dim > 0:
            x = jnp.concatenate([x, args[1:]], axis=0)
        output = super().__call__(x) + self.alpha * (x @ x)
        return jnp.squeeze(output)
```
Potential network based on a multi-layer perceptron.
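A usage sketch with illustrative sizes. Note from the source that when `param_dim > 0` only `args[1:]` is concatenated to the state, so `args` is expected to carry one extra leading entry (its meaning is not specified here); since the output is a scalar, `jax.grad` gives the force term directly:

```python
import jax
import jax.numpy as jnp
from onsagernet.models import PotentialMLP

key = jax.random.PRNGKey(0)
potential = PotentialMLP(
    key=key, dim=3, units=[64, 64], activation="tanh", alpha=0.01, param_dim=2
)

x = jnp.ones(3)
args = jnp.array([1.0, 0.5, 2.0])  # args[0] is skipped; args[1:] supplies the param_dim=2 parameters
v = potential(x, args)             # scalar potential value

# Gradient of the potential with respect to the state x, e.g. for a drift term
dv_dx = jax.grad(potential)(x, args)
```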
```python
class PotentialResMLP(MLP):
    r"""Potential network with a residual connection."""

    alpha: float
    gamma_layer: eqx.nn.Linear
    dim: int
    param_dim: int

    def __init__(
        self,
        key: PRNGKey,
        dim: int,
        units: list[int],
        activation: str,
        n_pot: int,
        alpha: float,
        param_dim: int = 0,
    ) -> None:
        r"""Potential network with a residual connection.

        This implements the modified potential function
        $$
        V(x, args) = \alpha \|(x, args)\|^2
        + \frac{1}{2}
        \| \phi(x, args) + \Gamma (x, args) \|^2
        $$
        where

        - $\phi$ is an MLP mapping dim + param_dim -> n_pot
        - $\Gamma$ is a matrix of size [n_pot, dim + param_dim]
        - $\alpha > 0$ is a scalar regulariser

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            units (list[int]): layer sizes
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            n_pot (int): size of the MLP part of the potential
            alpha (float): regulariser
            param_dim (int, optional): dimension of the parameters. Defaults to 0.
        """
        self.dim = dim
        self.param_dim = param_dim
        units = units + [n_pot]
        mlp_key, gamma_key = jax.random.split(key)
        super().__init__(mlp_key, dim + param_dim, units, activation)
        self.alpha = alpha
        self.gamma_layer = eqx.nn.Linear(
            dim + param_dim, n_pot, key=gamma_key, use_bias=False
        )

    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
        if self.param_dim > 0:
            x = jnp.concatenate([x, args[1:]], axis=0)
        output_phi = super().__call__(x)
        output_gamma = self.gamma_layer(x)
        output_combined = (output_phi + output_gamma) @ (output_phi + output_gamma)
        regularisation = self.alpha * (x @ x)
        return 0.5 * output_combined + regularisation
```
Potential network with a residual connection.
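A minimal sketch with illustrative sizes; by construction both terms of $V$ are non-negative, so the output is a non-negative scalar:

```python
import jax
import jax.numpy as jnp
from onsagernet.models import PotentialResMLP

key = jax.random.PRNGKey(1)
potential = PotentialResMLP(
    key=key, dim=2, units=[32], activation="tanh", n_pot=8, alpha=0.1, param_dim=0
)

x = jnp.array([0.5, -1.0])
v = potential(x, None)  # args is unused when param_dim == 0
print(float(v) >= 0.0)  # True: both terms of V are non-negative by construction
```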
Dissipation networks

```python
class DissipationMatrixMLP(MLP):
    """Dissipation matrix network based on a multi-layer perceptron."""

    alpha: float
    is_bounded: bool
    dim: int

    def __init__(
        self,
        key: PRNGKey,
        dim: int,
        units: list[int],
        activation: str,
        alpha: float,
        is_bounded: bool = True,
    ) -> None:
        r"""Dissipation matrix network based on a multi-layer perceptron.

        The MLP maps $x$ of dimension `dim` to a vector of size `dim * dim`,
        which is reshaped to a `dim` x `dim` matrix $L(x)$.
        Then, the output matrix is given by
        $$
        M(x) = \alpha I + L(x) L(x)^\top.
        $$
        If `is_bounded` is set to `True`, then the output is element-wise bounded
        by applying a `jax.nn.tanh` activation to the matrix $L$.

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            units (list[int]): layer sizes
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            alpha (float): regulariser
            is_bounded (bool, optional): whether to give an element-wise bounded output. Defaults to True.
        """
        self.dim = dim
        units = units + [dim * dim]
        super().__init__(key, dim, units, activation)
        self.alpha = alpha
        self.is_bounded = is_bounded

    def __call__(self, x: ArrayLike) -> Array:
        L = super().__call__(x).reshape(self.dim, self.dim)
        if self.is_bounded:
            L = jax.nn.tanh(L)
        return self.alpha * jnp.eye(self.dim) + L @ L.T
```
Dissipation matrix network based on a multi-layer perceptron.
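A sketch (with illustrative sizes) checking the structural guarantees that follow from $M = \alpha I + LL^\top$: the output is symmetric with eigenvalues bounded below by $\alpha$:

```python
import jax
import jax.numpy as jnp
from onsagernet.models import DissipationMatrixMLP

key = jax.random.PRNGKey(0)
M_net = DissipationMatrixMLP(key=key, dim=3, units=[32, 32], activation="tanh", alpha=0.1)

x = jnp.array([0.2, -0.1, 0.7])
M = M_net(x)                          # shape (3, 3)

print(jnp.allclose(M, M.T))           # True: symmetric
print(jnp.linalg.eigvalsh(M).min())   # >= alpha = 0.1, up to floating-point error
```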
Conservation networks

```python
class ConservationMatrixMLP(MLP):
    """Conservation matrix network based on a multi-layer perceptron."""

    is_bounded: bool
    dim: int

    def __init__(
        self,
        key: PRNGKey,
        dim: int,
        activation: str,
        units: list[int],
        is_bounded: bool = True,
    ) -> None:
        r"""Conservation matrix network based on a multi-layer perceptron.

        The MLP maps $x$ of dimension `dim` to a vector of size `dim * dim`,
        which is reshaped to a `dim` x `dim` matrix $L(x)$.
        Then, the output matrix is given by
        $$
        W(x) = L(x) - L(x)^\top.
        $$
        If `is_bounded` is set to `True`, then the output is element-wise bounded
        by applying a `jax.nn.tanh` activation to the matrix $L$.

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            units (list[int]): layer sizes
            is_bounded (bool, optional): whether to give an element-wise bounded output. Defaults to True.
        """
        self.dim = dim
        units = units + [dim * dim]
        super().__init__(key, dim, units, activation)
        self.is_bounded = is_bounded

    def __call__(self, x: ArrayLike) -> Array:
        L = super().__call__(x).reshape(self.dim, self.dim)
        if self.is_bounded:
            L = jax.nn.tanh(L)
        return L - L.T
```
Conservation matrix network based on a multi-layer perceptron.
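A sketch verifying antisymmetry of the output, which is what makes the conservative term vanish in quadratic forms ($v^\top W v = 0$ for any $v$); note the constructor takes `activation` before `units`:

```python
import jax
import jax.numpy as jnp
from onsagernet.models import ConservationMatrixMLP

key = jax.random.PRNGKey(0)
W_net = ConservationMatrixMLP(key=key, dim=3, activation="tanh", units=[32, 32])

x = jnp.array([0.2, -0.1, 0.7])
W = W_net(x)

print(jnp.allclose(W, -W.T))  # True: antisymmetric, so v @ W @ v == 0 for any v
```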
Diffusion networks

```python
class DiffusionMLP(MLP):
    """Diffusion matrix network based on a multi-layer perceptron."""

    alpha: float
    dim: int
    param_dim: int

    def __init__(
        self,
        key: PRNGKey,
        dim: int,
        units: list[int],
        activation: str,
        alpha: float,
        param_dim: int = 0,
    ) -> None:
        r"""Diffusion matrix network based on a multi-layer perceptron.

        This implements the diffusion matrix function
        $$
        \sigma(x, args) = \text{Chol}\left(\alpha I + \text{MLP}(x, args) \, \text{MLP}(x, args)^\top\right)
        $$
        where $\text{Chol}$ is the Cholesky decomposition and the MLP output
        is reshaped to a `dim` x `dim` matrix.
        Here, the MLP maps $(x, args)$ of dimension `dim` + `param_dim` to a vector of size `dim * dim`.

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            units (list[int]): layer sizes
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            alpha (float): regulariser
            param_dim (int, optional): dimension of the parameters. Defaults to 0.
        """
        self.dim = dim
        self.param_dim = param_dim
        units = units + [dim * dim]
        super().__init__(key, dim + param_dim, units, activation)
        self.alpha = alpha

    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
        if self.param_dim > 0:
            x = jnp.concatenate([x, args[1:]], axis=0)
        sigma = super().__call__(x).reshape(self.dim, self.dim)
        sigma_squared_regularised = self.alpha * jnp.eye(self.dim) + sigma @ sigma.T
        return jnp.linalg.cholesky(sigma_squared_regularised)
```
Diffusion matrix network based on a multi-layer perceptron.
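A sketch showing that the returned matrix is a lower-triangular Cholesky factor whose square recovers the regularised, positive-definite diffusion matrix:

```python
import jax
import jax.numpy as jnp
from onsagernet.models import DiffusionMLP

key = jax.random.PRNGKey(0)
sigma_net = DiffusionMLP(key=key, dim=2, units=[32], activation="tanh", alpha=0.05, param_dim=1)

x = jnp.array([0.3, -0.2])
args = jnp.array([300.0, 1.5])  # args[0] is skipped; args[1] is the single parameter
sigma = sigma_net(x, args)      # lower-triangular (2, 2) Cholesky factor

D = sigma @ sigma.T             # recovers alpha * I + (reshaped MLP output) @ (reshaped MLP output).T
print(jnp.allclose(D, D.T))     # True: D is symmetric positive definite
```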
```python
class DiffusionDiagonalMLP(MLP):
    """Diagonal diffusion matrix network based on a multi-layer perceptron."""

    alpha: float
    dim: int
    param_dim: int

    def __init__(
        self,
        key: PRNGKey,
        dim: int,
        units: list[int],
        activation: str,
        alpha: float,
        param_dim: int = 0,
    ) -> None:
        r"""Diagonal diffusion matrix network based on a multi-layer perceptron.

        This implements the diffusion matrix function
        $$
        \sigma(x, args) = \text{diag}(\alpha + \text{MLP}(x, args)^2)^{\frac{1}{2}}.
        $$
        Here, MLP maps $(x, args)$ of dimension `dim` + `param_dim` to a vector of size `dim`.

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            units (list[int]): layer sizes
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            alpha (float): regulariser
            param_dim (int, optional): dimension of the parameters. Defaults to 0.
        """
        self.dim = dim
        self.param_dim = param_dim
        units = units + [dim]
        super().__init__(key, dim + param_dim, units, activation)
        self.alpha = alpha

    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
        if self.param_dim > 0:
            x = jnp.concatenate([x, args[1:]], axis=0)
        sigma_diag = super().__call__(x)
        sigma_diag_regularised = jnp.sqrt(self.alpha + sigma_diag**2)
        return jnp.diag(sigma_diag_regularised)
```
Diagonal diffusion matrix network based on a multi-layer perceptron.
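A minimal sketch; the output is a diagonal matrix with entries $\sqrt{\alpha + \text{MLP}(x)^2} \geq \sqrt{\alpha}$:

```python
import jax
import jax.numpy as jnp
from onsagernet.models import DiffusionDiagonalMLP

key = jax.random.PRNGKey(0)
sigma_net = DiffusionDiagonalMLP(key=key, dim=3, units=[16], activation="tanh", alpha=0.01)

sigma = sigma_net(jnp.ones(3), None)  # args is unused when param_dim == 0
off_diag = sigma - jnp.diag(jnp.diag(sigma))
print(jnp.count_nonzero(off_diag))    # 0: all off-diagonal entries vanish
```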
```python
class DiffusionDiagonalConstant(eqx.Module):
    """Diagonal diffusion matrix network based on a constant layer."""

    alpha: float
    constant_layer: ConstantLayer
    dim: int

    def __init__(self, key: PRNGKey, dim: int, alpha: float) -> None:
        r"""Diagonal diffusion matrix network based on a constant layer.

        This implements a constant diffusion matrix function
        $$
        \sigma(x, args) = \text{diag}(\alpha + \text{Constant}^2)^{\frac{1}{2}},
        $$
        where $\text{Constant}$ is a vector of size `dim`.
        Note that by constant we mean that it does not depend on the input $x$ or the parameters $args$,
        but it can be trained.

        Args:
            key (PRNGKey): random key
            dim (int): dimension of the input
            alpha (float): regulariser
        """
        self.dim = dim
        self.alpha = alpha
        self.constant_layer = ConstantLayer(dim, key)

    def __call__(self, x: ArrayLike, args: ArrayLike) -> Array:
        sigma_diag = self.constant_layer()
        sigma_diag_regularised = jnp.sqrt(self.alpha + sigma_diag**2)
        return jnp.diag(sigma_diag_regularised)
```
Diagonal diffusion matrix network based on a constant layer.
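A minimal sketch; the diffusion is state-independent, although its entries remain trainable parameters of the module:

```python
import jax
import jax.numpy as jnp
from onsagernet.models import DiffusionDiagonalConstant

key = jax.random.PRNGKey(0)
sigma_net = DiffusionDiagonalConstant(key=key, dim=2, alpha=0.01)

# The same diagonal matrix is returned regardless of the state or parameters
print(jnp.allclose(sigma_net(jnp.zeros(2), None), sigma_net(jnp.ones(2), None)))  # True
```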
Dimensionality transforms

```python
class PCATransform(eqx.Module):
    """PCA transform."""

    mean: ArrayLike
    components: ArrayLike
    scaling: ArrayLike
    centre: bool

    def __init__(
        self,
        mean: ArrayLike,
        components: ArrayLike,
        scaling: ArrayLike,
        centre: bool = False,
    ) -> None:
        r"""PCA transform.

        Transforms the input vector $x$ via
        $$
        x \mapsto \text{components} ( x ) / \sqrt{\text{scaling}}.
        $$
        If `centre` is set to `True`, then the input is first centered by subtracting the `mean`.

        Args:
            mean (ArrayLike): mean of the data used to fit the PCA
            components (ArrayLike): PCA components
            scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance)
            centre (bool, optional): whether to center the data using `mean`. Defaults to False.
        """
        self.mean = jnp.array(mean)
        self.components = jnp.array(components)
        self.scaling = jnp.array(scaling)
        self.centre = centre

    def __call__(self, x: ArrayLike) -> Array:
        if self.centre:
            x = x - self.mean
        x_reduced = self.components @ x
        return x_reduced / jnp.sqrt(self.scaling)
```
PCA transform.
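A sketch of how the transform might be built from a fitted scikit-learn `PCA`; using scikit-learn is an assumption about the intended workflow, and any arrays with the right shapes work equally well:

```python
import jax.numpy as jnp
import numpy as np
from sklearn.decomposition import PCA
from onsagernet.models import PCATransform

data = np.random.default_rng(0).normal(size=(1000, 10))  # (n_samples, n_features)
pca = PCA(n_components=3).fit(data)

transform = PCATransform(
    mean=pca.mean_,
    components=pca.components_,       # shape (3, 10)
    scaling=pca.explained_variance_,  # shape (3,)
    centre=True,
)
z = transform(jnp.asarray(data[0]))   # reduced coordinates, shape (3,)
```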
```python
class InversePCATransform(PCATransform):
    """Inverse PCA transform."""

    def __init__(
        self, mean: ArrayLike, components: ArrayLike, scaling: ArrayLike
    ) -> None:
        r"""Inverse PCA transform.

        Transforms the input vector $z$ via
        $$
        z \mapsto \text{components}^\top ( z \sqrt{\text{scaling}} ).
        $$
        If `centre` is set to `True`, `mean` is added to the output.

        Args:
            mean (ArrayLike): mean of the data used to fit the PCA
            components (ArrayLike): PCA components
            scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance)
        """
        super().__init__(mean, components, scaling)

    def __call__(self, z: ArrayLike) -> Array:
        x_reconstructed = self.components.T @ (jnp.sqrt(self.scaling) * z)
        if self.centre:
            x_reconstructed = x_reconstructed + self.mean
        return x_reconstructed
```
Inverse PCA transform.
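A sketch pairing the forward and inverse transforms (again assuming scikit-learn for the fitted PCA). The inverse constructed here keeps `centre=False` from the parent, so it does not re-add the mean; with a full-rank PCA the round trip recovers the centred input:

```python
import jax.numpy as jnp
import numpy as np
from sklearn.decomposition import PCA
from onsagernet.models import PCATransform, InversePCATransform

data = np.random.default_rng(0).normal(size=(500, 4))
pca = PCA(n_components=4).fit(data)  # full-rank PCA, so the round trip is exact

forward = PCATransform(pca.mean_, pca.components_, pca.explained_variance_, centre=True)
inverse = InversePCATransform(pca.mean_, pca.components_, pca.explained_variance_)

x = jnp.asarray(data[0])
x_rec = inverse(forward(x))          # reconstructs x - mean (this inverse does not re-add the mean)
print(jnp.allclose(x_rec, x - pca.mean_, atol=1e-5))
```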
```python
class PCAResNetTransform(PCATransform):
    """PCA-ResNet transform."""

    mlp: MLP
    mlp_scale: float
    mlp_input_scale: float

    def __init__(
        self,
        mean: ArrayLike,
        components: ArrayLike,
        scaling: ArrayLike,
        key: PRNGKey,
        units: list[int],
        activation: str,
        mlp_scale: float,
        mlp_input_scale: float,
    ) -> None:
        r"""PCA-ResNet transform.

        This combines the PCA transform with an MLP to give a ResNet-like architecture
        $$
        x \mapsto \text{PCA}(x) + \text{mlp\_scale} \times \text{MLP}(\text{mlp\_input\_scale} \times x).
        $$
        The input scale rescales the input (which is often not of order 1)
        so that the MLP can learn at the correct scale.

        Args:
            mean (ArrayLike): mean of the data used to fit the PCA
            components (ArrayLike): PCA components
            scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance)
            key (PRNGKey): random key
            units (list[int]): layer sizes of the MLP
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            mlp_scale (float): scale of the MLP output
            mlp_input_scale (float): scale of the input to the MLP
        """
        super().__init__(mean, components, scaling)
        units = units + [components.shape[0]]
        dim = components.shape[1]
        self.mlp = MLP(key, dim, units, activation)
        self.mlp_scale = mlp_scale
        self.mlp_input_scale = mlp_input_scale

    def __call__(self, x: ArrayLike) -> Array:
        pca_features = super().__call__(x)
        mlp_features = self.mlp(self.mlp_input_scale * x)
        return pca_features + self.mlp_scale * mlp_features

    def pca_transform(self, x: ArrayLike) -> Array:
        """Perform the PCA transform only, without the MLP correction.

        Args:
            x (ArrayLike): state

        Returns:
            Array: PCA features
        """
        return super().__call__(x)
```
PCA-ResNet transform.
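A construction sketch, again assuming PCA quantities from scikit-learn; `units` lists only the hidden layers, since the output layer of size `components.shape[0]` is appended internally, and the scale values below are illustrative:

```python
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.decomposition import PCA
from onsagernet.models import PCAResNetTransform

data = np.random.default_rng(0).normal(size=(1000, 10))
pca = PCA(n_components=3).fit(data)

encoder = PCAResNetTransform(
    mean=pca.mean_,
    components=pca.components_,
    scaling=pca.explained_variance_,
    key=jax.random.PRNGKey(0),
    units=[64, 64],      # hidden layers; the output layer of size 3 is appended internally
    activation="tanh",
    mlp_scale=0.1,       # keeps the learned correction small relative to the PCA part
    mlp_input_scale=1.0,
)

x = jnp.asarray(data[0])
z = encoder(x)                      # reduced coordinates, shape (3,)
z_pca_only = encoder.pca_transform(x)  # the linear PCA part alone
```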
```python
class InversePCAResNetTransform(InversePCATransform):
    """Inverse PCA-ResNet transform."""

    mlp: MLP
    mlp_scale: float

    def __init__(
        self,
        mean: ArrayLike,
        components: ArrayLike,
        scaling: ArrayLike,
        key: PRNGKey,
        units: list[int],
        activation: str,
        mlp_scale: float,
    ) -> None:
        r"""Inverse PCA-ResNet transform.

        This combines the inverse PCA transform with an MLP to give a ResNet-like architecture
        $$
        z \mapsto \text{PCA}^{-1}(z) + \text{mlp\_scale} \times \text{MLP}(z).
        $$

        Args:
            mean (ArrayLike): mean of the data used to fit the PCA
            components (ArrayLike): PCA components
            scaling (ArrayLike): scaling of the PCA transform (e.g. explained variance)
            key (PRNGKey): random key
            units (list[int]): layer sizes of the MLP
            activation (str): activation function (can be any in `jax.nn` or custom ones defined in `onsagernet._activations`)
            mlp_scale (float): scale of the MLP output
        """
        super().__init__(mean, components, scaling)
        units = units + [components.shape[1]]
        dim = components.shape[0]
        self.mlp = MLP(key, dim, units, activation)
        self.mlp_scale = mlp_scale

    def __call__(self, z: ArrayLike) -> Array:
        inverse_pca_recon = super().__call__(z)
        mlp_recon = self.mlp(z)
        return inverse_pca_recon + self.mlp_scale * mlp_recon

    def inverse_pca_transform(self, z: ArrayLike) -> Array:
        """Perform the inverse PCA transform only, without the MLP correction.

        Args:
            z (ArrayLike): reduced state

        Returns:
            Array: reconstructed state
        """
        return super().__call__(z)
```
Inverse PCA-ResNet transform.
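A sketch of the matching decoder; paired with `PCAResNetTransform` it gives a trainable encoder/decoder built around the linear PCA reconstruction (scikit-learn and the parameter choices here are illustrative assumptions):

```python
import jax
import jax.numpy as jnp
import numpy as np
from sklearn.decomposition import PCA
from onsagernet.models import InversePCAResNetTransform

data = np.random.default_rng(0).normal(size=(1000, 10))
pca = PCA(n_components=3).fit(data)

decoder = InversePCAResNetTransform(
    mean=pca.mean_,
    components=pca.components_,
    scaling=pca.explained_variance_,
    key=jax.random.PRNGKey(1),
    units=[64, 64],   # hidden layers; the output layer of size 10 is appended internally
    activation="tanh",
    mlp_scale=0.1,
)

z = jnp.zeros(3)                           # a point in the reduced space
x_rec = decoder(z)                         # reconstruction in the full 10-dimensional space
x_lin = decoder.inverse_pca_transform(z)   # the linear part of the reconstruction alone
```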