onsagernet._losses
Custom loss functions
This module contains custom loss functions for training the models.
The class Loss is an abstract class that defines the interface for loss functions.
Subclasses must implement the Loss.compute_sample_loss method.
1""" 2# Custom loss functions 3 4This module contains custom loss functions for training the models. 5 6The case class `Loss` is an abstract class that defines the interface for loss functions. 7Sub-classes must implement the `Loss.compute_sample_loss` method. 8 9# """ 10 11import jax 12import equinox as eqx 13import jax.numpy as jnp 14from abc import abstractmethod 15from jax.scipy.stats import multivariate_normal 16from jax.typing import ArrayLike 17from jax import Array 18from .dynamics import SDE, ReducedSDE 19 20 21class Loss(eqx.Module): 22 """Base class for losses 23 24 Computes the loss over a batch of samples.pdoc --math my_module 25 26 Args: 27 model (eqx.Module): the model to be trained 28 data_arrs (tuple[ArrayLike, ...]): input data arrays of the format (t, x, args) 29 30 Returns: 31 float: loss value 32 """ 33 34 @abstractmethod 35 def compute_sample_loss( 36 self, model: eqx.Module, *processed_data_arrs: tuple[ArrayLike, ...] 37 ) -> Array: 38 """Computes the loss for a single sample. 39 40 Args: 41 model (eqx.Module): the model to be trained 42 processed_data_arrs (tuple[ArrayLike, ...]): processed data arrays 43 44 Returns: 45 Array: the loss value 46 """ 47 pass 48 49 def _process_data_arrs( 50 self, *data_arrs: tuple[ArrayLike, ...] 51 ) -> tuple[ArrayLike, ...]: 52 """Process the data arrays to a format that can be used by the loss function. 53 54 The input data arrays are in the standard format of (t, x, args). 55 The format of the output arrays depend on the loss. 56 57 Args: 58 data_arrs (tuple[ArrayLike, ...]): input data arrays 59 60 Returns: 61 tuple[ArrayLike, ...]: processed data arrays 62 """ 63 return data_arrs 64 65 @eqx.filter_jit 66 def __call__(self, model: eqx.Module, *data_arrs: tuple[ArrayLike, ...]) -> float: 67 data_arrs_processed = self._process_data_arrs(*data_arrs) 68 in_axes = (None,) + len(data_arrs_processed) * (0,) 69 compute_loss_over_time = jax.vmap(self.compute_sample_loss, in_axes=in_axes) 70 compute_loss = jax.vmap(compute_loss_over_time, in_axes=in_axes) 71 sample_losses = compute_loss(model, *data_arrs_processed) 72 73 return jnp.mean(sample_losses) 74 75 76class MLELoss(Loss): 77 r"""Computes the maximum likelihood estimation loss for the SDE 78 79 The loss is the negative log-likelihood of the data given the model. 80 By an Euler-Maruyama discretization of the SDE 81 $$ 82 dX(t) = f(t, X(t), \theta) dt + g(t, X(t), \theta) dW(t), 83 $$ 84 which gives 85 $$ 86 X(t + \Delta t) = X(t) + f(t, X(t), \theta) \Delta t + g(t, X(t), \theta) \Delta W(t), 87 $$ 88 where $\Delta W(t) \sim N(0, \Delta t)$. 89 Thus, the negative log-likelihood $X(t+\Delta t)$ is given by 90 $$ 91 -\log p(X(t + \Delta t) | X(t), \theta) = \frac{1}{2} \log(2\pi) 92 + \frac{1}{2} \log(\text{det}(\Sigma)) 93 + \frac{1}{2} (X(t + \Delta t) - X(t))^T \Sigma^{-1} (X(t + \Delta t) - X(t)). 94 $$ 95 We will use `scipy.stats.multivariate_normal` to compute the log-likelihood. 96 """ 97 98 def _process_data_arrs(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array: 99 return t[:, :-1, :], t[:, 1:, :], x[:, :-1, :], x[:, 1:, :], args[:, :-1, :] 100 101 def compute_sample_loss( 102 self, 103 model: SDE, 104 t: ArrayLike, 105 t_plus: ArrayLike, 106 x: ArrayLike, 107 x_next: ArrayLike, 108 args: ArrayLike, 109 ) -> float: 110 """Computes the loss for a single sample. 
111 112 Args: 113 model (SDE): the model to be trained 114 t (ArrayLike): time 115 t_plus (ArrayLike): time shifted by one step 116 x (ArrayLike): state 117 x_next (ArrayLike): state shifted by one step 118 args (ArrayLike): arguments for the current time step 119 120 Returns: 121 float: loss value 122 """ 123 drift = model.drift(t, x, args) 124 diffusion = model.diffusion(t, x, args) 125 dt = t_plus - t 126 data = (x_next - x) / dt 127 mean = drift 128 cov = (1 / dt) * diffusion @ diffusion.T 129 return -multivariate_normal.logpdf(data, mean, cov) 130 131 132class ReconLoss(Loss): 133 r"""Computes the reconstruction loss for the encoder and decoder components. 134 135 The loss is defined by 136 $$ 137 \| X - \text{decoder}(\text{encoder}(X)) \|^2. 138 $$ 139 """ 140 141 def compute_sample_loss(self, model: ReducedSDE, x: ArrayLike) -> float: 142 """Computes the loss for a single sample. 143 144 Args: 145 model (ReducedSDE): the model to be trained 146 x (ArrayLike): state 147 148 Returns: 149 float: loss value 150 """ 151 z = model.encoder(x) 152 x_recon = model.decoder(z) 153 return jnp.mean((x - x_recon) ** 2) 154 155 156class CompareLoss(Loss): 157 r"""Computes the comparison loss for the encoder and decoder components against PCA. 158 159 This loss is defined by 160 $$ 161 \max(0, \log(\text{recon_loss_model}) - \log(\text{recon_loss_pca})), 162 $$ 163 where `recon_loss_model` is the reconstruction loss for the model 164 and `recon_loss_pca` is the reconstruction loss for PCA. 165 """ 166 167 def compute_sample_loss(self, model: ReducedSDE, x: ArrayLike) -> float: 168 """Computes the loss for a single sample. 169 170 Args: 171 model (ReducedSDE): the model to be trained 172 x (ArrayLike): state 173 174 Returns: 175 float: loss value 176 """ 177 x_recon_model = model.decoder(model.encoder(x)) 178 x_recon_pca = model.decoder.inverse_closure_transform.inverse_pca_transform( 179 model.encoder.closure_transform.pca_transform(x) 180 ) 181 182 recon_loss_model = jnp.mean((x - x_recon_model) ** 2) 183 recon_loss_pca = jnp.mean((x - x_recon_pca) ** 2) 184 185 return jax.nn.relu(jnp.log(recon_loss_model) - jnp.log(recon_loss_pca))
# Imports used by the loss classes in this module.
import jax
import equinox as eqx
import jax.numpy as jnp
from abc import abstractmethod
from jax.scipy.stats import multivariate_normal
from jax.typing import ArrayLike
from jax import Array
from .dynamics import SDE, ReducedSDE


class Loss(eqx.Module):
    """Base class for losses

    Computes the loss over a batch of samples.

    Args:
        model (eqx.Module): the model to be trained
        data_arrs (tuple[ArrayLike, ...]): input data arrays of the format (t, x, args)

    Returns:
        float: loss value
    """

    @abstractmethod
    def compute_sample_loss(
        self, model: eqx.Module, *processed_data_arrs: tuple[ArrayLike, ...]
    ) -> Array:
        """Computes the loss for a single sample.

        Args:
            model (eqx.Module): the model to be trained
            processed_data_arrs (tuple[ArrayLike, ...]): processed data arrays

        Returns:
            Array: the loss value
        """
        pass

    def _process_data_arrs(
        self, *data_arrs: tuple[ArrayLike, ...]
    ) -> tuple[ArrayLike, ...]:
        """Process the data arrays to a format that can be used by the loss function.

        The input data arrays are in the standard format of (t, x, args).
        The format of the output arrays depends on the loss.

        Args:
            data_arrs (tuple[ArrayLike, ...]): input data arrays

        Returns:
            tuple[ArrayLike, ...]: processed data arrays
        """
        return data_arrs

    @eqx.filter_jit
    def __call__(self, model: eqx.Module, *data_arrs: tuple[ArrayLike, ...]) -> float:
        # vmap compute_sample_loss over time steps, then over trajectories in the batch.
        data_arrs_processed = self._process_data_arrs(*data_arrs)
        in_axes = (None,) + len(data_arrs_processed) * (0,)
        compute_loss_over_time = jax.vmap(self.compute_sample_loss, in_axes=in_axes)
        compute_loss = jax.vmap(compute_loss_over_time, in_axes=in_axes)
        sample_losses = compute_loss(model, *data_arrs_processed)

        return jnp.mean(sample_losses)
Base class for losses
Computes the loss over a batch of samples.
Arguments:
- model (eqx.Module): the model to be trained
- data_arrs (tuple[ArrayLike, ...]): input data arrays of the format (t, x, args)
Returns:
float: loss value
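To define a new loss, it is enough to subclass Loss and implement compute_sample_loss for a single time step of a single trajectory; Loss.__call__ then maps it over time steps and trajectories via the two nested jax.vmap calls and averages the result. The sketch below is only an illustration of this pattern (a hypothetical drift-magnitude penalty, not part of this module), assuming a model with a drift(t, x, args) method as in SDE:

from jax import Array
from jax.typing import ArrayLike
import jax.numpy as jnp

from onsagernet._losses import Loss
from onsagernet.dynamics import SDE


class DriftPenaltyLoss(Loss):
    """Hypothetical example loss: penalises large drift values."""

    def compute_sample_loss(
        self, model: SDE, t: ArrayLike, x: ArrayLike, args: ArrayLike
    ) -> Array:
        # Receives a single (t, x, args) slice; batching is handled by Loss.__call__.
        drift = model.drift(t, x, args)
        return jnp.mean(drift**2)


# Usage: t, x and args are arrays with leading axes (n_trajectories, n_steps).
# loss_value = DriftPenaltyLoss()(model, t, x, args)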
class MLELoss(Loss):
    r"""Computes the maximum likelihood estimation loss for the SDE

    The loss is the negative log-likelihood of the data given the model.
    By an Euler-Maruyama discretization of the SDE
    $$
    dX(t) = f(t, X(t), \theta) dt + g(t, X(t), \theta) dW(t),
    $$
    which gives
    $$
    X(t + \Delta t) = X(t) + f(t, X(t), \theta) \Delta t + g(t, X(t), \theta) \Delta W(t),
    $$
    where $\Delta W(t) \sim N(0, \Delta t)$.
    Thus, writing $\Delta X = X(t + \Delta t) - X(t)$, $f = f(t, X(t), \theta)$ and
    $\Sigma = \Delta t \, g(t, X(t), \theta) g(t, X(t), \theta)^T$,
    the negative log-likelihood of $X(t + \Delta t)$ given $X(t)$ is
    $$
    -\log p(X(t + \Delta t) | X(t), \theta) = \frac{d}{2} \log(2\pi)
    + \frac{1}{2} \log(\text{det}(\Sigma))
    + \frac{1}{2} (\Delta X - f \Delta t)^T \Sigma^{-1} (\Delta X - f \Delta t),
    $$
    where $d$ is the state dimension.
    We use `jax.scipy.stats.multivariate_normal` to compute the log-likelihood.
    """

    def _process_data_arrs(self, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> Array:
        # Pair each time step with the next one so that transitions can be scored.
        return t[:, :-1, :], t[:, 1:, :], x[:, :-1, :], x[:, 1:, :], args[:, :-1, :]

    def compute_sample_loss(
        self,
        model: SDE,
        t: ArrayLike,
        t_plus: ArrayLike,
        x: ArrayLike,
        x_next: ArrayLike,
        args: ArrayLike,
    ) -> float:
        """Computes the loss for a single sample.

        Args:
            model (SDE): the model to be trained
            t (ArrayLike): time
            t_plus (ArrayLike): time shifted by one step
            x (ArrayLike): state
            x_next (ArrayLike): state shifted by one step
            args (ArrayLike): arguments for the current time step

        Returns:
            float: loss value
        """
        drift = model.drift(t, x, args)
        diffusion = model.diffusion(t, x, args)
        dt = t_plus - t
        data = (x_next - x) / dt
        mean = drift
        cov = (1 / dt) * diffusion @ diffusion.T
        return -multivariate_normal.logpdf(data, mean, cov)
Computes the maximum likelihood estimation loss for the SDE
The loss is the negative log-likelihood of the data given the model.
By an Euler-Maruyama discretization of the SDE
$$
dX(t) = f(t, X(t), \theta) dt + g(t, X(t), \theta) dW(t),
$$
which gives
$$
X(t + \Delta t) = X(t) + f(t, X(t), \theta) \Delta t + g(t, X(t), \theta) \Delta W(t),
$$
where $\Delta W(t) \sim N(0, \Delta t)$.
Thus, writing $\Delta X = X(t + \Delta t) - X(t)$, $f = f(t, X(t), \theta)$ and
$\Sigma = \Delta t \, g(t, X(t), \theta) g(t, X(t), \theta)^T$,
the negative log-likelihood of $X(t + \Delta t)$ given $X(t)$ is
$$
-\log p(X(t + \Delta t) | X(t), \theta) = \frac{d}{2} \log(2\pi)
+ \frac{1}{2} \log(\text{det}(\Sigma))
+ \frac{1}{2} (\Delta X - f \Delta t)^T \Sigma^{-1} (\Delta X - f \Delta t),
$$
where $d$ is the state dimension.
We use jax.scipy.stats.multivariate_normal to compute the log-likelihood.
Computes the loss for a single sample.
Arguments:
- model (SDE): the model to be trained
- t (ArrayLike): time
- t_plus (ArrayLike): time shifted by one step
- x (ArrayLike): state
- x_next (ArrayLike): state shifted by one step
- args (ArrayLike): arguments for the current time step
Returns:
float: loss value
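Note that the implementation scores the scaled increment (x_next - x) / dt with mean drift and covariance (1 / dt) * diffusion @ diffusion.T, rather than x_next itself; the two negative log-likelihoods differ only by the parameter-independent constant $d \log(\Delta t)$, so they are equivalent training objectives. A minimal, self-contained sketch of what one step of this loss computes (illustrative numbers only, with plain arrays standing in for the model's drift and diffusion outputs):

import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal

dt = 0.01
x = jnp.array([0.5, -0.2])         # state at time t
x_next = jnp.array([0.48, -0.17])  # state at time t + dt
f = jnp.array([-1.0, 2.0])         # stands in for model.drift(t, x, args)
g = jnp.array([[0.3, 0.0],
               [0.1, 0.2]])        # stands in for model.diffusion(t, x, args)

# What MLELoss.compute_sample_loss evaluates: NLL of the scaled increment.
nll_scaled = -multivariate_normal.logpdf((x_next - x) / dt, f, (g @ g.T) / dt)

# Transition NLL of x_next given x under the Euler-Maruyama discretization.
nll_transition = -multivariate_normal.logpdf(x_next, x + f * dt, dt * (g @ g.T))

# The two differ by the parameter-independent constant -d * log(dt).
d = x.shape[0]
print(nll_scaled - nll_transition, -d * jnp.log(dt))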
class ReconLoss(Loss):
    r"""Computes the reconstruction loss for the encoder and decoder components.

    The loss is defined by
    $$
    \| X - \text{decoder}(\text{encoder}(X)) \|^2.
    $$
    """

    def compute_sample_loss(self, model: ReducedSDE, x: ArrayLike) -> float:
        """Computes the loss for a single sample.

        Args:
            model (ReducedSDE): the model to be trained
            x (ArrayLike): state

        Returns:
            float: loss value
        """
        z = model.encoder(x)
        x_recon = model.decoder(z)
        return jnp.mean((x - x_recon) ** 2)
Computes the reconstruction loss for the encoder and decoder components.
The loss is defined by
$$
\| X - \text{decoder}(\text{encoder}(X)) \|^2.
$$
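Because every Loss instance is callable on batched data, it can be dropped directly into a standard Equinox training step. The following is a minimal sketch, not prescribed by this module: it assumes Optax for optimisation, an already-constructed ReducedSDE named model, and trajectory data x of shape (n_trajectories, n_steps, dim).

import equinox as eqx
import optax

from onsagernet._losses import ReconLoss

recon_loss = ReconLoss()
optimiser = optax.adam(1e-3)
opt_state = optimiser.init(eqx.filter(model, eqx.is_inexact_array))


@eqx.filter_jit
def make_step(model, opt_state, x):
    # Differentiate the batch-averaged loss with respect to the model parameters.
    loss_value, grads = eqx.filter_value_and_grad(recon_loss)(model, x)
    updates, opt_state = optimiser.update(
        grads, opt_state, eqx.filter(model, eqx.is_inexact_array)
    )
    model = eqx.apply_updates(model, updates)
    return model, opt_state, loss_value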
class CompareLoss(Loss):
    r"""Computes the comparison loss for the encoder and decoder components against PCA.

    This loss is defined by
    $$
    \max(0, \log(\text{recon_loss_model}) - \log(\text{recon_loss_pca})),
    $$
    where `recon_loss_model` is the reconstruction loss for the model
    and `recon_loss_pca` is the reconstruction loss for PCA.
    """

    def compute_sample_loss(self, model: ReducedSDE, x: ArrayLike) -> float:
        """Computes the loss for a single sample.

        Args:
            model (ReducedSDE): the model to be trained
            x (ArrayLike): state

        Returns:
            float: loss value
        """
        x_recon_model = model.decoder(model.encoder(x))
        x_recon_pca = model.decoder.inverse_closure_transform.inverse_pca_transform(
            model.encoder.closure_transform.pca_transform(x)
        )

        recon_loss_model = jnp.mean((x - x_recon_model) ** 2)
        recon_loss_pca = jnp.mean((x - x_recon_pca) ** 2)

        return jax.nn.relu(jnp.log(recon_loss_model) - jnp.log(recon_loss_pca))
Computes the comparison loss for the encoder and decoder components against PCA.
This loss is defined by
$$
\max(0, \log(\text{recon_loss_model}) - \log(\text{recon_loss_pca})),
$$
where recon_loss_model is the reconstruction loss for the model and recon_loss_pca is the reconstruction loss for PCA.
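With illustrative numbers: if recon_loss_model = 0.10 and recon_loss_pca = 0.05, the loss is $\max(0, \log 0.10 - \log 0.05) = \log 2 \approx 0.69$; if instead recon_loss_model = 0.04, then $\log 0.04 - \log 0.05 \approx -0.22$ is clipped to zero. The encoder/decoder pair is therefore only penalised when it reconstructs worse than the PCA baseline, which is exactly what the jax.nn.relu on the log-ratio in the code above implements.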