onsagernet.trainers

Basic trainers for SDE models

This module implements basic training routines for SDE models. The base class is SDETrainer, which provides the base training logic. Sub-classes are required to implement SDETrainer.loss_func, which is used to train the model.

Training routines

We provide two training routines:

  • MLETrainer: estimates the SDE drift and diffusion by maximum likelihood (MLELoss), based on the Euler-Maruyama discretisation of the SDE.
  • ClosureMLETrainer: combines MLELoss with additional losses that enforce closure, namely a reconstruction loss (ReconLoss) and a comparison loss (CompareLoss). This follows the implementation of [1].

The following example shows how to train an onsagernet.dynamics.OnsagerNet model using the MLETrainer.

from onsagernet.dynamics import OnsagerNet
from onsagernet.trainers import MLETrainer

sde = OnsagerNet(...)
dataset = load_data(...)  # returns a datasets.Dataset object

trainer = MLETrainer(opt_options=config.train.opt, rop_options=config.train.rop)
sde, losses, _ = trainer.train(  # trains the model `sde` for 10 epochs with batch size 2
    model=sde,
    dataset=dataset,
    num_epochs=10,
    batch_size=2,  # the batch size should typically be small, since each batch of states has shape [n_batch, n_steps, n_dim]
)

Dataset format

The dataset is assumed to be a huggingface datasets.Dataset object with three columns: t, x, and args. Here, t is time, x is state, and args are additional arguments. Let n_examples be the number of examples, n_dim be the dimension of the state x, n_args be the dimension of the arguments, and n_steps be the number of time steps. Then, the shapes of the data entries are as follows:

  • t: (n_examples, n_steps, 1) - this is the sequence of time steps sampled in each time series, which can vary from one example to another (but n_steps must be the same across examples), so consider padding if that is not the case
  • x: (n_examples, n_steps, n_dim) - this is the corresponding sequence of states
  • args: (n_examples, n_steps, n_args) - this is the sequence of additional arguments, which can be constant or varying in time. If constant, you must broadcast the values along the time dimension so that the expected shapes are maintained.

We also assume that n_args >= 1, with the first entry of args always being temperature, since temperature is relevant for all SDE models in physical modelling. Your custom loss_func routine must return a scalar loss value.

In fact, it is not strictly necessary to use a datasets.Dataset object: any object that yields batches in this format via dataset.iter(batch_size) will work.
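
To make the expected layout concrete, the following is a minimal sketch that builds a toy dataset with the required columns and shapes using datasets.Dataset.from_dict; the synthetic trajectories and the constant temperature value are purely illustrative.

```python
import numpy as np
from datasets import Dataset

n_examples, n_steps, n_dim, n_args = 8, 100, 3, 1

# Toy trajectories with the shapes described above.
t = np.tile(np.linspace(0.0, 1.0, n_steps)[None, :, None], (n_examples, 1, 1))
x = np.random.randn(n_examples, n_steps, n_dim)
args = np.full((n_examples, n_steps, n_args), 300.0)  # first entry of args is temperature

dataset = Dataset.from_dict(
    {"t": t.tolist(), "x": x.tolist(), "args": args.tolist()}
).with_format("numpy")

batch = next(dataset.iter(batch_size=2))
print(batch["x"].shape)  # expected: (2, 100, 3)
```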

Writing custom training routines

You can write custom training routines (with custom losses, for example) by sub-classing SDETrainer. Sub-classes must implement the SDETrainer.loss_func method.

Given a model partitioned by eqx.partition(model, ...) into a trainable part diff_model and a static part static_model, the SDETrainer.loss_func routine computes the loss of the model on a batch of data. We always assume that the batch data is a tuple of (t, x, args).

See above for the format of the dataset, and see MLETrainer.loss_func or the more complex ClosureMLETrainer.loss_func for example implementations. A minimal custom trainer is sketched below.
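
For illustration, here is a sketch of such a sub-class that adds an L2 penalty on the trainable parameters to the maximum-likelihood loss. The class name and the l2_weight option are hypothetical, introduced only for this example; MLELoss is imported from the package's private _losses module, as MLETrainer does internally.

```python
import equinox as eqx
import jax.numpy as jnp
from jax.tree_util import tree_leaves

from onsagernet.trainers import SDETrainer
from onsagernet._losses import MLELoss  # the same loss used by MLETrainer


class RegularisedMLETrainer(SDETrainer):
    """Hypothetical trainer: maximum-likelihood loss plus an L2 penalty on the trainable parameters."""

    @eqx.filter_jit
    def loss_func(self, diff_model, static_model, t, x, args):
        model = eqx.combine(diff_model, static_model)
        mle_loss = MLELoss()(model, t, x, args)
        # L2 penalty over the trainable (array-valued) leaves only
        trainable = eqx.filter(diff_model, eqx.is_inexact_array)
        l2 = sum(jnp.sum(leaf**2) for leaf in tree_leaves(trainable))
        weight = (self._loss_options or {}).get("l2_weight", 0.0)
        return mle_loss + weight * l2  # must be a scalar
```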

References

  1. Chen, X. et al. Constructing custom thermodynamics using deep learning. Nature Computational Science 4, 66–85 (2024).
  1r"""
  2# Basic trainers for SDE models
  3
  4This module implements basic training routines for SDE models.
  5The base class is `SDETrainer`, which provides the base training logic.
  6The sub-classes are required to implement the `SDETrainer.loss_func` that is used to train the model.
  7
  8## Training routines
  9
 10We provide here two training routines.
 11
 12- `MLETrainer`: this implements the maximum likelihood loss [`MLELoss`](./_losses.html#MLELoss)
 13  to estimate SDE drift and diffusion
 14  through computing the [maximum-likelihood](https://en.wikipedia.org/wiki/Maximum_likelihood_estimation) following the
 15  [Euler-Maruyama discretisation](https://en.wikipedia.org/wiki/Euler–Maruyama_method) of the SDE.
 16- `ClosureMLETrainer`: this combines [`MLELoss`](./_losses.html#MLELoss) with additional losses to enforce closure,
 17  namely a reconstruction loss [`ReconLoss`](./_losses.html#ReconLoss) and a
 18  comparison loss [`CompareLoss`](./_losses.html#CompareLoss).
 19  This follows the implementation of [1].
 20
 21The following example shows how to train an `onsagernet.dynamics.OnsagerNet` model using the `MLETrainer`.
 22```python
 23from onsagernet.dynamics import OnsagerNet
 24from onsagernet.trainers import MLETrainer
 25
 26sde = OnsagerNet(...)
 27dataset = load_data(...)  # returns a datasets.Dataset object
 28
 29trainer = MLETrainer(opt_options=config.train.opt, rop_options=config.train.rop)
 30sde, losses, _ = trainer.train(  # trains the model `sde` for 10 epochs with batch size 2
 31    model=sde,
 32    dataset=dataset,
 33    num_epochs=10,
 34    batch_size=2,  # the batch size should typically be small, since each batch of states has shape [n_batch, n_steps, n_dim]
 35)
 36```
 37
 38## Dataset format
 39The dataset is assumed
 40to be a huggingface [`datasets.Dataset`](https://huggingface.co/docs/datasets/en/index)
 41object with three columns: `t`, `x`, and `args`.
 42Here, `t` is time, `x` is state, and `args` are additional arguments.
 43Let `n_examples` be the number of examples, `n_dim` be the dimension of the state x,
 44`n_args` be the dimension of the arguments, and `n_steps` be the number of time steps.
 45Then, the shapes of the data entries are as follows:
 46- t: (`n_examples`, `n_steps`, `1`) - this is the sequence of time steps sampled in each time series,
 47  which can vary from one example to another (but `n_steps` must be the same across examples), so consider
 48  padding if that is not the case
 49- x: (`n_examples`, `n_steps`, `n_dim`) - this is the corresponding sequence of states
 50- args: (`n_examples`, `n_steps`, `n_args`) - this is the sequence of additional arguments,
 51  which can be constant or varying in time. If constant, you must broadcast the values
 52  along the time dimension so that the expected shapes are maintained.
 53
 54We also assume that `n_args` >= `1`, with the first entry of `args` always being
 55temperature, since temperature is relevant for all SDE models in physical modelling.
 56Your custom `loss_func` routine must return a scalar loss value.
 57
 58*In fact, it is not strictly necessary to use a [`datasets.Dataset`](https://huggingface.co/docs/datasets/en/index)
 59object: any object that yields batches in this format via `dataset.iter(batch_size)` will work.*
 60
 61## Writing custom training routines
 62
 63You can write custom training routines (with custom losses, for example)
 64by sub-classing `SDETrainer`.
 65Sub-classes must implement the `SDETrainer.loss_func` method.
 66
 67Given a model partitioned by `eqx.partition(model, ...)` into
 68a trainable part `diff_model` and a static part `static_model`,
 69the `SDETrainer.loss_func` routine computes the loss of the model on a batch of data.
 70We always assume that the batch data is a tuple of `(t, x, args)`.
 71
 72See above for the format of the dataset, and see
 73`MLETrainer.loss_func` or the more complex `ClosureMLETrainer.loss_func` for example implementations.
 74
 75## References
 76
 771. Chen, X. et al. *Constructing custom thermodynamics using deep learning*. Nature Computational Science **4**, 66–85 (2024).
 78
 79
 80"""
 81
 82import os
 83import jax
 84import jax.numpy as jnp
 85import equinox as eqx
 86from abc import ABC, abstractmethod
 87
 88from optax import adam, chain
 89from optax.contrib import reduce_on_plateau
 90from optax.tree_utils import tree_get
 91from jax.tree_util import tree_map
 92
 93from tqdm import tqdm
 94
 95from ._losses import MLELoss, ReconLoss, CompareLoss
 96
 97# ------------------------- Typing imports ------------------------- #
 98
 99from chex import ArrayTree  # to silence pdoc warnings
100from .dynamics import SDE, ReducedSDE
101from typing import Optional, Any, Union
102from jax.typing import ArrayLike
103from jax import Array
104from datasets import Dataset
105from optax import GradientTransformation, OptState
106from logging import Logger
107
108DynamicModel = Union[SDE, ReducedSDE]
109
110# ------------------------------------------------------------------ #
111#                              Trainers                              #
112# ------------------------------------------------------------------ #
113
114
115class SDETrainer(ABC):
116    """Base class for training SDE models."""
117
118    def __init__(
119        self, opt_options: dict, rop_options: dict, loss_options: Optional[dict] = None
120    ) -> None:
121        """SDE training routine.
122
123        Args:
124            opt_options (dict): dictionary of options for the optimiser
125            rop_options (dict): dictionary of options for the reduce-on-plateau callback
126            loss_options (Optional[dict], optional): dictionary of options for loss computation. Defaults to None.
127        """
128        self._opt_options = opt_options
129        self._rop_options = rop_options
130        self._loss_options = loss_options
131
132    @eqx.filter_jit
133    @abstractmethod
134    def loss_func(
135        self,
136        diff_model: DynamicModel,
137        static_model: DynamicModel,
138        t: ArrayLike,
139        x: ArrayLike,
140        args: ArrayLike,
141    ) -> float:
142        """Loss function.
143
144        This must be implemented by sub-classes.
145
146        Args:
147            diff_model (DynamicModel): the trainable part of the model
148            static_model (DynamicModel): the static part of the model
149            t (ArrayLike): time
150            x (ArrayLike): state
151            args (ArrayLike): additional arguments or parameters. The first dimension is temperature.
152
153        Returns:
154            float: loss value
155        """
156        pass
157
158    def _make_optimiser(
159        self, opt_options: dict, rop_options: dict
160    ) -> GradientTransformation:
161        """Make an optimiser.
162
163        Args:
164            opt_options (dict): optimiser options
165            rop_options (dict): reduce-on-plateau options
166
167        Returns:
168            GradientTransformation: an optimiser object from `optax`
169        """
170        return chain(adam(**opt_options), reduce_on_plateau(**rop_options))
171
172    @eqx.filter_jit
173    def _make_step(
174        self,
175        model: DynamicModel,
176        data: Dataset,
177        opt: GradientTransformation,
178        opt_state: OptState,
179        filter_spec: Any,
180    ) -> tuple[DynamicModel, float, OptState]:
181        """Make a training step.
182
183        Args:
184            model (DynamicModel): the model to be trained
185            data (Dataset): the dataset object
186            opt (GradientTransformation): optimiser object
187            opt_state (OptState): optimiser state
188            filter_spec (Any): the filtering logic to determine which parts of the model to train
189
190        Returns:
191            tuple[DynamicModel, float, OptState]: trained model, loss value, optimiser state
192        """
193        diff_model, static_model = eqx.partition(model, filter_spec)
194
195        loss_value, grads = eqx.filter_value_and_grad(self.loss_func)(
196            diff_model, static_model, *data
197        )
198        updates, opt_state = opt.update(grads, opt_state, model, value=loss_value)
199        model = eqx.apply_updates(model, updates)
200        return model, loss_value, opt_state
201
202    def _train_epoch(
203        self,
204        model: DynamicModel,
205        dataset: Dataset,
206        batch_size: int,
207        opt: GradientTransformation,
208        opt_state: OptState,
209        filter_spec: Any,
210    ) -> tuple[DynamicModel, float, OptState]:
211        """Train the model for an epoch.
212
213        Args:
214            model (DynamicModel): the model to be trained
215            dataset (Dataset): the dataset object
216            batch_size (int): the batch size
217            opt (GradientTransformation): the optimiser object
218            opt_state (OptState): the optimiser state
219            filter_spec (Any): the filtering logic to determine which parts of the model to train
220
221        Returns:
222            tuple[DynamicModel, float, OptState]: trained model, loss value, optimiser state
223        """
224
225        step_losses = []
226
227        for batch in tqdm(
228            dataset.iter(batch_size),
229            total=dataset.num_rows // batch_size,
230        ):
231            data_batch = (batch["t"], batch["x"], batch["args"])
232            model, train_loss, opt_state = self._make_step(
233                model, data_batch, opt, opt_state, filter_spec
234            )
235            step_losses.append(train_loss)
236        epoch_loss = jnp.mean(jnp.array(step_losses))
237        return model, epoch_loss, opt_state
238
239    def train(
240        self,
241        model: DynamicModel,
242        dataset: Dataset,
243        num_epochs: int,
244        batch_size: int,
245        logger: Optional[Logger] = None,
246        opt_state: Optional[OptState] = None,
247        filter_spec: Optional[Any] = None,
248        checkpoint_dir: Optional[str] = None,
249        checkpoint_every: Optional[int] = None,
250    ) -> tuple[DynamicModel, list[float], OptState]:
251        """The main training routine.
252
253        Args:
254            model (DynamicModel): the model to be trained
255            dataset (Dataset): the dataset
256            num_epochs (int): number of epochs to train
257            batch_size (int): the batch size
258            logger (Optional[Logger], optional): the logging object. Defaults to None.
259            opt_state (Optional[OptState], optional): the starting optimiser state. Defaults to None.
260            filter_spec (Optional[Any], optional): the filtering logic. Defaults to None.
261            checkpoint_dir (Optional[str], optional): the directory to save checkpoints. Defaults to None.
262            checkpoint_every (Optional[int], optional): checkpoints are saved every `checkpoint_every` number of epochs. Defaults to None.
263
264        Returns:
265            tuple[DynamicModel, list[float], OptState]: trained model, list of losses, optimiser state
266        """
267        opt = self._make_optimiser(self._opt_options, self._rop_options)
268        if opt_state is None:
269            opt_state = opt.init(eqx.filter(model, eqx.is_array))
270
271        if filter_spec is None:
272            filter_spec = tree_map(lambda _: True, model)
273
274        losses = []
275        for epoch in range(num_epochs):
276            model, step_loss, opt_state = self._train_epoch(
277                model=model,
278                dataset=dataset,
279                batch_size=batch_size,
280                opt=opt,
281                opt_state=opt_state,
282                filter_spec=filter_spec,
283            )
284            losses.append(step_loss)
285            if logger:
286                lr_scale = tree_get(opt_state, "scale")
287                logger.info(
288                    f"epoch={epoch:05d}, loss={step_loss:.6f}, lr_scale={lr_scale:.4f}"
289                )
290
291            if checkpoint_dir is not None and epoch % checkpoint_every == 0:
292                model_path = os.path.join(
293                    checkpoint_dir, f"model_epoch_{epoch:05d}.eqx"
294                )
295                eqx.tree_serialise_leaves(model_path, model)
296
297        return model, losses, opt_state
298
299
300class MLETrainer(SDETrainer):
301
302    @eqx.filter_jit
303    def loss_func(
304        self,
305        diff_model: DynamicModel,
306        static_model: DynamicModel,
307        t: ArrayLike,
308        x: ArrayLike,
309        args: ArrayLike,
310    ) -> float:
311        """The MLE loss function.
312
313        See [`MLELoss`](./_losses.html#MLELoss) for more details.
314
315        Args:
316            diff_model (DynamicModel): the trainable part of the model
317            static_model (DynamicModel): the static part of the model
318            t (ArrayLike): time
319            x (ArrayLike): state
320            args (ArrayLike): additional arguments or parameters.
321
322        Returns:
323            float: the computed loss
324        """
325        model = eqx.combine(diff_model, static_model)
326        return MLELoss()(model, t, x, args)
327
328
329class ClosureMLETrainer(MLETrainer):
330
331    @eqx.filter_jit
332    def loss_func(
333        self,
334        diff_model: DynamicModel,
335        static_model: DynamicModel,
336        t: ArrayLike,
337        x: ArrayLike,
338        args: ArrayLike,
339    ) -> float:
340        r"""The combined loss function for MLE training with closure modelling.
341
342        The losses are applied to the combine model
343        `model = eqx.combine(diff_model, static_model)`
344        with three parts
345        - [`MLELoss`](./_losses.html#MLELoss) applied to `model.sde`
346        - [`ReconLoss`](./_losses.html#ReconLoss) applied to the `model`
347        - [`CompareLoss`](./_losses.html#CompareLoss) applied to the `model`
348
349        The combined loss is given by
350        $$
351            \text{loss} = \text{loss_sde}
352            + \text{recon_weight} \times \text{loss_recon}
353            + \text{compare_weight} \times \text{loss_compare}
354        $$
355
356        The variables `recon_weight` and `compare_weight` are set
357        in the `loss_options` attribute.
358
359        Args:
360            diff_model (DynamicModel): the trainable part of the model
361            static_model (DynamicModel): the static part of the model
362            t (ArrayLike): time
363            x (ArrayLike): state
364            args (ArrayLike): additional arguments or parameters.
365
366        Returns:
367            float: the computed loss
368        """
369        model = eqx.combine(diff_model, static_model)
370        z = jax.vmap(jax.vmap(model.encoder))(x)
371        loss_sde = MLELoss()(model.sde, t, z, args)
372        loss_recon = ReconLoss()(model, x)
373        loss_compare = CompareLoss()(model, x)
374        return (
375            loss_sde
376            + self._loss_options["recon_weight"] * loss_recon
377            + self._loss_options["compare_weight"] * loss_compare
378        )
class SDETrainer(abc.ABC):

Base class for training SDE models.

SDETrainer(opt_options: dict, rop_options: dict, loss_options: Optional[dict] = None)

SDE training routine.

Arguments:
  • opt_options (dict): dictionary of options for the optimiser
  • rop_options (dict): dictionary of options for the reduce-on-plateau callback
  • loss_options (Optional[dict], optional): dictionary of options for loss computation. Defaults to None.
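
The two dictionaries are forwarded directly to optax.adam and optax.contrib.reduce_on_plateau (see _make_optimiser in the source listing above), so any keyword accepted by those constructors can be used, for example with the MLETrainer sub-class. The particular values below are illustrative only:

```python
from onsagernet.trainers import MLETrainer

opt_options = {"learning_rate": 1e-3}          # keyword arguments for optax.adam
rop_options = {"factor": 0.5, "patience": 10}  # keyword arguments for optax.contrib.reduce_on_plateau

trainer = MLETrainer(opt_options=opt_options, rop_options=rop_options)
```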
@eqx.filter_jit
@abstractmethod
def loss_func(self, diff_model: DynamicModel, static_model: DynamicModel, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> float:

Loss function.

This must be implemented by sub-classes.

Arguments:
  • diff_model (DynamicModel): the trainable part of the model
  • static_model (DynamicModel): the static part of the model
  • t (ArrayLike): time
  • x (ArrayLike): state
  • args (ArrayLike): additional arguments or parameters. The first dimension is temperature.
Returns:

float: loss value

def train(self, model: DynamicModel, dataset: Dataset, num_epochs: int, batch_size: int, logger: Optional[Logger] = None, opt_state: Optional[OptState] = None, filter_spec: Optional[Any] = None, checkpoint_dir: Optional[str] = None, checkpoint_every: Optional[int] = None) -> tuple[DynamicModel, list[float], OptState]:

The main training routine.

Arguments:
  • model (DynamicModel): the model to be trained
  • dataset (Dataset): the dataset
  • num_epochs (int): number of epochs to train
  • batch_size (int): the batch size
  • logger (Optional[Logger], optional): the logging object. Defaults to None.
  • opt_state (Optional[OptState], optional): the starting optimiser state. Defaults to None.
  • filter_spec (Optional[Any], optional): the filtering logic. Defaults to None.
  • checkpoint_dir (Optional[str], optional): the directory to save checkpoints. Defaults to None.
  • checkpoint_every (Optional[int], optional): checkpoints are saved every checkpoint_every number of epochs. Defaults to None.
Returns:

tuple[DynamicModel, list[float], OptState]: trained model, list of losses, optimiser state
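
The filter_spec argument follows the standard Equinox partitioning pattern: a pytree of booleans with the same structure as the model, where True marks a leaf as trainable (the default marks every leaf trainable). The sketch below freezes the encoder of a ReducedSDE-style model while training everything else; it assumes the sde, dataset and trainer objects from the earlier example.

```python
import equinox as eqx
from jax.tree_util import tree_map

# Start with everything trainable, then freeze the encoder sub-module.
filter_spec = tree_map(lambda _: True, sde)
filter_spec = eqx.tree_at(
    lambda tree: tree.encoder,
    filter_spec,
    replace=tree_map(lambda _: False, sde.encoder),
)

sde, losses, opt_state = trainer.train(
    model=sde,
    dataset=dataset,
    num_epochs=10,
    batch_size=2,
    filter_spec=filter_spec,
)
```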

class MLETrainer(SDETrainer):

Trainer that estimates the SDE drift and diffusion by maximum likelihood, using MLELoss with the Euler-Maruyama discretisation of the SDE.

@eqx.filter_jit
def loss_func(self, diff_model: DynamicModel, static_model: DynamicModel, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> float:

The MLE loss function.

See MLELoss for more details.

Arguments:
  • diff_model (DynamicModel): the trainable part of the model
  • static_model (DynamicModel): the static part of the model
  • t (ArrayLike): time
  • x (ArrayLike): state
  • args (ArrayLike): additional arguments or parameters.
Returns:

float: the computed loss
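
For a quick sanity check outside the training loop, the loss can be evaluated directly on a single batch by partitioning the model with eqx.partition, mirroring what the training step does internally. This sketch assumes the trainer, sde and dataset objects from the earlier examples:

```python
import equinox as eqx
import jax.numpy as jnp

batch = next(dataset.iter(batch_size=2))
t, x, args = (jnp.asarray(batch[k]) for k in ("t", "x", "args"))

# Partition into trainable and static parts, then evaluate the loss on one batch.
diff_model, static_model = eqx.partition(sde, eqx.is_array)
loss = trainer.loss_func(diff_model, static_model, t, x, args)
print(float(loss))
```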

Inherited members (from SDETrainer): SDETrainer.__init__, SDETrainer.train
class ClosureMLETrainer(MLETrainer):

Trainer that combines the maximum-likelihood loss MLELoss with a reconstruction loss ReconLoss and a comparison loss CompareLoss to enforce closure, following [1].

@eqx.filter_jit
def loss_func(self, diff_model: DynamicModel, static_model: DynamicModel, t: ArrayLike, x: ArrayLike, args: ArrayLike) -> float:

The combined loss function for MLE training with closure modelling.

The losses are applied to the combined model model = eqx.combine(diff_model, static_model) and have three parts:

  • MLELoss applied to model.sde
  • ReconLoss applied to the model
  • CompareLoss applied to the model

The combined loss is given by loss = loss_sde + recon_weight × loss_recon + compare_weight × loss_compare.

The variables recon_weight and compare_weight are set in the loss_options attribute.

Arguments:
  • diff_model (DynamicModel): the trainable part of the model
  • static_model (DynamicModel): the static part of the model
  • t (ArrayLike): time
  • x (ArrayLike): state
  • args (ArrayLike): additional arguments or parameters.
Returns:

float: the computed loss
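
The recon_weight and compare_weight entries are read from loss_options, so they must be supplied when the trainer is constructed. The numerical weights and optimiser settings below are illustrative only:

```python
from onsagernet.trainers import ClosureMLETrainer

trainer = ClosureMLETrainer(
    opt_options={"learning_rate": 1e-3},          # forwarded to optax.adam
    rop_options={"factor": 0.5, "patience": 10},  # forwarded to optax.contrib.reduce_on_plateau
    loss_options={"recon_weight": 1.0, "compare_weight": 1.0},
)
```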

Inherited members (from SDETrainer): SDETrainer.__init__, SDETrainer.train