onsagernet._losses

Custom loss functions

This module contains custom loss functions for training the models.

The class Loss is an abstract base class that defines the interface for loss functions. Subclasses must implement the Loss.compute_sample_loss method.

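For example, a minimal subclass only needs to implement compute_sample_loss on a single time slice; batching over samples and time steps is handled by Loss.__call__. The sketch below (a hypothetical mean-squared drift penalty, not part of this module) illustrates the pattern:

    import jax.numpy as jnp
    from onsagernet._losses import Loss

    class DriftMSELoss(Loss):
        """Hypothetical example: penalise the squared drift magnitude."""

        def compute_sample_loss(self, model, t, x, args):
            # Called on a single (t, x, args) slice; Loss.__call__ adds the
            # sample and time batch dimensions via jax.vmap.
            return jnp.mean(model.drift(t, x, args) ** 2)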

  1"""
  2# Custom loss functions
  3
  4This module contains custom loss functions for training the models.
  5
  6The case class `Loss` is an abstract class that defines the interface for loss functions.
  7Sub-classes must implement the `Loss.compute_sample_loss` method.
  8
  9# """

import jax
import equinox as eqx
import jax.numpy as jnp
from abc import abstractmethod
from jax.scipy.stats import multivariate_normal
from jax.typing import ArrayLike
from jax import Array
from .dynamics import SDE, ReducedSDE


class Loss(eqx.Module):
    """Base class for losses

    Computes the loss over a batch of samples.

    Args:
        model (eqx.Module): the model to be trained
        data_arrs (tuple[ArrayLike, ...]): input data arrays of the format (t, x, args)

    Returns:
        float: loss value
    """

    @abstractmethod
    def compute_sample_loss(
        self, model: eqx.Module, *processed_data_arrs: ArrayLike
    ) -> Array:
        """Computes the loss for a single sample.

        Args:
            model (eqx.Module): the model to be trained
            processed_data_arrs (tuple[ArrayLike, ...]): processed data arrays

        Returns:
            Array: the loss value
        """
        pass

    def _process_data_arrs(
        self, *data_arrs: ArrayLike
    ) -> tuple[ArrayLike, ...]:
        """Process the data arrays to a format that can be used by the loss function.

        The input data arrays are in the standard format of (t, x, args).
        The format of the output arrays depends on the loss.

        Args:
            data_arrs (tuple[ArrayLike, ...]): input data arrays

        Returns:
            tuple[ArrayLike, ...]: processed data arrays
        """
        return data_arrs

    @eqx.filter_jit
    def __call__(self, model: eqx.Module, *data_arrs: ArrayLike) -> float:
        data_arrs_processed = self._process_data_arrs(*data_arrs)
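        # The model is broadcast (in_axes entry None) while each data array is
        # mapped over its leading axis: the inner vmap runs over time steps and
        # the outer vmap over samples, i.e. arrays are (n_samples, n_steps, dim).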
        in_axes = (None,) + len(data_arrs_processed) * (0,)
        compute_loss_over_time = jax.vmap(self.compute_sample_loss, in_axes=in_axes)
        compute_loss = jax.vmap(compute_loss_over_time, in_axes=in_axes)
        sample_losses = compute_loss(model, *data_arrs_processed)

        return jnp.mean(sample_losses)


class MLELoss(Loss):
    r"""Computes the maximum likelihood estimation loss for the SDE

    The loss is the negative log-likelihood of the data given the model.
    An Euler-Maruyama discretization of the SDE
    $$
        dX(t) = f(t, X(t), \theta) dt + g(t, X(t), \theta) dW(t)
    $$
    gives
    $$
        X(t + \Delta t) = X(t) + f(t, X(t), \theta) \Delta t + g(t, X(t), \theta) \Delta W(t),
    $$
    where $\Delta W(t) \sim N(0, \Delta t)$.
    Thus, the negative log-likelihood of $X(t + \Delta t)$ given $X(t)$ is
    $$
        -\log p(X(t + \Delta t) | X(t), \theta) = \frac{d}{2} \log(2\pi)
        + \frac{1}{2} \log(\det(\Sigma))
        + \frac{1}{2} (X(t + \Delta t) - \mu)^T \Sigma^{-1} (X(t + \Delta t) - \mu),
    $$
    where $d$ is the state dimension, $\mu = X(t) + f(t, X(t), \theta) \Delta t$,
    and $\Sigma = g(t, X(t), \theta) g(t, X(t), \theta)^T \Delta t$.
    We use `jax.scipy.stats.multivariate_normal` to compute the log-likelihood.
    """

    def _process_data_arrs(
        self, t: ArrayLike, x: ArrayLike, args: ArrayLike
    ) -> tuple[Array, ...]:
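        # Split each trajectory into (current, next) pairs along the time axis
        # so that consecutive states can be compared pointwise.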
        return t[:, :-1, :], t[:, 1:, :], x[:, :-1, :], x[:, 1:, :], args[:, :-1, :]

    def compute_sample_loss(
        self,
        model: SDE,
        t: ArrayLike,
        t_plus: ArrayLike,
        x: ArrayLike,
        x_next: ArrayLike,
        args: ArrayLike,
    ) -> float:
        """Computes the loss for a single sample.

        Args:
            model (SDE): the model to be trained
            t (ArrayLike): time
            t_plus (ArrayLike): time shifted by one step
            x (ArrayLike): state
            x_next (ArrayLike): state shifted by one step
            args (ArrayLike): arguments for the current time step

        Returns:
            float: loss value
        """
        drift = model.drift(t, x, args)
        diffusion = model.diffusion(t, x, args)
        dt = t_plus - t
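        # Work with the rescaled increment dX/dt, which under Euler-Maruyama is
        # Gaussian with mean f(t, X, theta) and covariance g g^T / dt.  This
        # differs from the likelihood of X(t + dt) only by a theta-independent
        # constant, so it does not change the optimisation.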
        data = (x_next - x) / dt
        mean = drift
        cov = (1 / dt) * diffusion @ diffusion.T
        return -multivariate_normal.logpdf(data, mean, cov)


class ReconLoss(Loss):
    r"""Computes the reconstruction loss for the encoder and decoder components.

    The loss is the mean squared reconstruction error
    $$
        \frac{1}{d} \| X - \text{decoder}(\text{encoder}(X)) \|^2,
    $$
    where $d$ is the state dimension.
    """

    def compute_sample_loss(self, model: ReducedSDE, x: ArrayLike) -> float:
        """Computes the loss for a single sample.

        Args:
            model (ReducedSDE): the model to be trained
            x (ArrayLike): state

        Returns:
            float: loss value
        """
        z = model.encoder(x)
        x_recon = model.decoder(z)
        return jnp.mean((x - x_recon) ** 2)


class CompareLoss(Loss):
    r"""Computes the comparison loss for the encoder and decoder components against PCA.

    This loss is defined by
    $$
        \max(0, \log(\text{recon_loss_model}) - \log(\text{recon_loss_pca})),
    $$
    where `recon_loss_model` is the reconstruction loss for the model
    and `recon_loss_pca` is the reconstruction loss for PCA.
    """

    def compute_sample_loss(self, model: ReducedSDE, x: ArrayLike) -> float:
        """Computes the loss for a single sample.

        Args:
            model (ReducedSDE): the model to be trained
            x (ArrayLike): state

        Returns:
            float: loss value
        """
        x_recon_model = model.decoder(model.encoder(x))
        x_recon_pca = model.decoder.inverse_closure_transform.inverse_pca_transform(
            model.encoder.closure_transform.pca_transform(x)
        )

        recon_loss_model = jnp.mean((x - x_recon_model) ** 2)
        recon_loss_pca = jnp.mean((x - x_recon_pca) ** 2)

185        return jax.nn.relu(jnp.log(recon_loss_model) - jnp.log(recon_loss_pca))