# onsagernet.trainers

Basic trainers for SDE models.

This module implements basic training routines for SDE models.
The base class is `SDETrainer`, which provides the base training logic.
Sub-classes are required to implement `SDETrainer.loss_func`, which is used to train the model.
## Training routines

We provide two training routines here.

- `MLETrainer`: implements the maximum likelihood loss `MLELoss`, which estimates the SDE drift and
  diffusion by maximising the likelihood under the Euler-Maruyama discretisation of the SDE
  (a transition likelihood is sketched below).
- `ClosureMLETrainer`: combines `MLELoss` with additional losses to enforce closure, namely a
  reconstruction loss `ReconLoss` and a comparison loss `CompareLoss`.
  This follows the implementation of [1].
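For reference, here is the transition likelihood that maximum likelihood estimation uses under this
discretisation. This is the standard Gaussian form; the exact normalisation and reduction used by
`MLELoss` may differ, so treat it as a sketch. The Euler-Maruyama step is

$$
x_{k+1} \approx x_k + f(x_k, \text{args})\,\Delta t_k + \sigma(x_k, \text{args})\,\Delta W_k,
\qquad \Delta W_k \sim \mathcal{N}(0, \Delta t_k I),
$$

so each transition is Gaussian and its negative log-likelihood, up to additive constants, is

$$
-\log p(x_{k+1} \mid x_k)
= \tfrac{1}{2}\, r_k^\top (\Sigma_k \Delta t_k)^{-1} r_k
+ \tfrac{1}{2}\log\det(\Sigma_k \Delta t_k),
\qquad
r_k = x_{k+1} - x_k - f(x_k, \text{args})\,\Delta t_k,
\quad
\Sigma_k = \sigma\sigma^\top.
$$

Summing over time steps and examples gives the objective minimised during training.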
The following example shows how to train an `onsagernet.dynamics.OnsagerNet` model using the `MLETrainer`.
```python
from onsagernet.dynamics import OnsagerNet
from onsagernet.trainers import MLETrainer

sde = OnsagerNet(...)
dataset = load_data(...)  # returns a datasets.Dataset object

trainer = MLETrainer(opt_options=config.train.opt, rop_options=config.train.rop)
sde, losses, _ = trainer.train(  # trains the model `sde` for 10 epochs with batch size 2
    model=sde,
    dataset=dataset,
    num_epochs=10,
    batch_size=2,  # the batch size should typically be small since each batch has shape [n_batch, n_steps, n_dim]
)
```
## Dataset format

The dataset is assumed to be a huggingface `datasets.Dataset` object with three columns: `t`, `x`, and `args`.
Here, `t` is time, `x` is state, and `args` are additional arguments.
Let `n_examples` be the number of examples, `n_dim` the dimension of the state `x`,
`n_args` the dimension of the arguments, and `n_steps` the number of time steps.
Then, the shapes of the data entries are as follows:
- `t`: (`n_examples`, `n_steps`, `1`) - the sequence of time steps sampled in the time series,
  which can vary from one example to another (but `n_steps` must be the same),
  so consider padding if this is not the case
- `x`: (`n_examples`, `n_steps`, `n_dim`) - the corresponding sequence of states
- `args`: (`n_examples`, `n_steps`, `n_args`) - the sequence of additional arguments,
  which can be constant or varying in time. If constant, you must broadcast the values
  along the time dimension so that the expected shapes are maintained.
We also assume that `n_args >= 1`, with the first dimension always being temperature,
since it is relevant for all SDE models in physical modelling.
Your custom `loss_func` routine must return a scalar loss value.
*In fact, it is not strictly necessary to use a `datasets.Dataset` object; any object that yields
the correct data format when calling `dataset.iter(batch_size)` will work.*
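To make the expected layout concrete, here is a minimal sketch that builds a toy dataset with the
three columns. The trajectory values are random placeholders and the sizes are arbitrary; only the
shapes matter.

```python
import numpy as np
from datasets import Dataset

n_examples, n_steps, n_dim, n_args = 4, 100, 2, 1  # arbitrary illustrative sizes

t = np.tile(np.linspace(0.0, 1.0, n_steps)[None, :, None], (n_examples, 1, 1))  # (4, 100, 1)
x = np.random.default_rng(0).normal(size=(n_examples, n_steps, n_dim))          # (4, 100, 2)
temperature = 1.0
args = np.full((n_examples, n_steps, n_args), temperature)  # temperature broadcast along time

dataset = Dataset.from_dict({"t": t.tolist(), "x": x.tolist(), "args": args.tolist()})

for batch in dataset.iter(batch_size=2):
    print(np.asarray(batch["x"]).shape)  # (2, 100, 2), i.e. [n_batch, n_steps, n_dim]
    break
```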
## Writing custom training routines

You can write custom training routines (with custom losses, for example) by sub-classing `SDETrainer`.
Sub-classes must implement the `SDETrainer.loss_func` method.
Given a model partitioned by `eqx.partition(model, ...)` into a trainable part `diff_model`
and a static part `static_model`, the `SDETrainer.loss_func` routine computes the loss of the
model on the given dataset.
We always assume that the batch data is a tuple of `(t, x, args)`.

See above for more on the format of the dataset, and see `MLETrainer.loss_func` or the more
complex `ClosureMLETrainer.loss_func` for examples.
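As a sketch, the following hypothetical sub-class adds an L2 penalty on the trainable parameters to
the maximum likelihood loss. The `RegularisedMLETrainer` class and the `l2_weight` key in
`loss_options` are made up for this example; the values passed to `opt_options` and `rop_options`
are illustrative, and `MLELoss` is assumed importable from `onsagernet._losses` as in the module
source.

```python
import jax
import jax.numpy as jnp
import equinox as eqx

from onsagernet.trainers import SDETrainer
from onsagernet._losses import MLELoss


class RegularisedMLETrainer(SDETrainer):
    """Hypothetical trainer: MLE loss plus an L2 penalty on the trainable parameters."""

    @eqx.filter_jit
    def loss_func(self, diff_model, static_model, t, x, args) -> float:
        model = eqx.combine(diff_model, static_model)
        mle_loss = MLELoss()(model, t, x, args)
        # sum of squares over the trainable (array) leaves only
        leaves = jax.tree_util.tree_leaves(eqx.filter(diff_model, eqx.is_array))
        l2_penalty = sum(jnp.sum(leaf**2) for leaf in leaves)
        return mle_loss + self._loss_options["l2_weight"] * l2_penalty


trainer = RegularisedMLETrainer(
    opt_options={"learning_rate": 1e-3},          # passed to optax.adam
    rop_options={"patience": 10, "factor": 0.5},  # passed to optax.contrib.reduce_on_plateau
    loss_options={"l2_weight": 1e-4},             # hypothetical key used by the custom loss
)
```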
## References

1. Chen, X. et al. *Constructing custom thermodynamics using deep learning*. Nature Computational Science **4**, 66–85 (2024).
1r""" 2# Basic trainers for SDE models 3 4This module implements basic training routines for SDE models. 5The base class is `SDETrainer`, which provides the base training logic. 6The sub-classes are required to implement the `SDETrainer.loss_func` that is used to train the model. 7 8## Training routines 9 10We provide here two training routines. 11 12- `MLETrainer`: this implements the maximum likelihood loss [`MLELoss`](./_losses.html#MLELoss) 13 to estimate SDE drift and diffusion 14 through computing the [maximum-likelihood](https://en.wikipedia.org/wiki/Maximum_likelihood_estimation) following the 15 [Euler-Maruyama discretisation](https://en.wikipedia.org/wiki/Euler–Maruyama_method) of the SDE. 16- `ClosureMLETrainer`: this combines [`MLELoss`](./_losses.html#MLELoss) with additional losses to enforce closure, 17 namely a reconstruction loss [`ReconLoss`](./_losses.html#ReconLoss) and a 18 comparison loss [`CompareLoss`](./_losses.html#CompareLoss). 19 This follows the implementation of [1]. 20 21The following example shows how to train an `onsagernet.dynamics.OnsagerNet` model using the `MLETrainer`. 22```python 23from onsagernet.dynamics import OnsagerNet 24from onsagernet.trainers import MLETrainer 25 26sde = OnsagerNet(...) 27dataset = load_data(...) # return a datasets.Dataset object 28 29trainer = MLETrainer(opt_options=config.train.opt, rop_options=config.train.rop) 30sde, losses, _ = trainer.train( # trains the model `sde` for 10 epochs with batch size 2 31 model=sde, 32 dataset=dataset, 33 num_epochs=10, 34 batch_size=2, # batch size should be typically small since this yields [n_batch, n_steps, n_dim] data 35) 36``` 37 38## Dataset format 39The dataset is assumed 40to be a huggingface [`datasets.Dataset`](https://huggingface.co/docs/datasets/en/index) 41object with three columns: `t`, `x`, and `args`. 42Here, `t` is time, `x` is state, and `args` are additional arguments. 43Let `n_examples` be the number of examples, `n_dim` be the dimension of the state x, 44`n_args` be the dimension of the arguments, and `n_steps` be the number of time steps. 45Then, the shapes of the data entries are as follows: 46- t: (`n_examples`, `n_steps`, `1`) - this is the sequence of time steps sampled in the time series, 47 which can vary from one example to another (but the `n_steps` must be the same), so consider padding 48 if this were not the case 49- x: (`n_examples`, `n_steps`, `n_dim`) - this is the corresponding sequence of states 50- args: (`n_examples`, `n_steps`, `n_args`) - this is the sequence of additional arguments, 51 which can be constant or varying in time. If constant, you must broad-cast the values 52 along the time dimension so that the expected shapes are maintained. 53 54We also assume that `n_args` >= `1`, with the first dimension always being 55temperature, since it is relevant for all SDE models in physical modelling. 56Your custom `loss_func` routine must return a scalar loss value. 57 58*In fact, it is not strictly necessary to use a [`datasets.Dataset`](https://huggingface.co/docs/datasets/en/index) 59object, but any object that can yield the correct data format by calling `dataset.iter(batch_size)`.* 60 61## Writing custom training routines 62 63You can write custom training routines (with custom losses, for example) 64by sub-classing `SDETrainer`. 65Sub-classes must implement the `SDETrainer.loss_func` method. 
66 67Given a model partitioned by `eqx.partition(model, ...)`, 68into a trainable part `diff_model` and a static part `static_model`, 69the `SDETrainer.loss_func` routine computes the loss of the model on the given dataset. 70We always assume that the batch data is a tuple of `(t, x, args)`. 71 72See above for more on the format of the dataset and 73`MLETrainer.loss_func`, or the more complex `ClosureMLETrainer.loss_func` for examples. 74 75## References 76 771. Chen, X. et al. *Constructing custom thermodynamics using deep learning*. Nature Computational Science **4**, 66–85 (2024). 78 79 80""" 81 82import os 83import jax 84import jax.numpy as jnp 85import equinox as eqx 86from abc import ABC, abstractmethod 87 88from optax import adam, chain 89from optax.contrib import reduce_on_plateau 90from optax.tree_utils import tree_get 91from jax.tree_util import tree_map 92 93from tqdm import tqdm 94 95from ._losses import MLELoss, ReconLoss, CompareLoss 96 97# ------------------------- Typing imports ------------------------- # 98 99from chex import ArrayTree # to silence pdoc warnings 100from .dynamics import SDE, ReducedSDE 101from typing import Optional, Any, Union 102from jax.typing import ArrayLike 103from jax import Array 104from datasets import Dataset 105from optax import GradientTransformation, OptState 106from logging import Logger 107 108DynamicModel = Union[SDE, ReducedSDE] 109 110# ------------------------------------------------------------------ # 111# Trainers # 112# ------------------------------------------------------------------ # 113 114 115class SDETrainer(ABC): 116 """Base class for training SDE models.""" 117 118 def __init__( 119 self, opt_options: dict, rop_options: dict, loss_options: Optional[dict] = None 120 ) -> None: 121 """SDE training routine. 122 123 Args: 124 opt_options (dict): dictionary of options for the optimiser 125 rop_options (dict): dictionary of options for the reduce-on-plateau callback 126 loss_options (Optional[dict], optional): dictionary of options for loss computation. Defaults to None. 127 """ 128 self._opt_options = opt_options 129 self._rop_options = rop_options 130 self._loss_options = loss_options 131 132 @eqx.filter_jit 133 @abstractmethod 134 def loss_func( 135 self, 136 diff_model: DynamicModel, 137 static_model: DynamicModel, 138 t: ArrayLike, 139 x: ArrayLike, 140 args: ArrayLike, 141 ) -> float: 142 """Loss function. 143 144 This must be implemented by sub-classes. 145 146 Args: 147 diff_model (DynamicModel): the trainable part of the model 148 static_model (DynamicModel): the static part of the model 149 t (ArrayLike): time 150 x (ArrayLike): state 151 args (ArrayLike): additional arguments or parameters. The first dimension is temperature. 152 153 Returns: 154 float: loss value 155 """ 156 pass 157 158 def _make_optimiser( 159 self, opt_options: dict, rop_options: dict 160 ) -> GradientTransformation: 161 """Make an optimiser. 162 163 Args: 164 opt_options (dict): optimiser options 165 rop_options (dict): reduce-on-plateau options 166 167 Returns: 168 GradientTransformation: an optimiser object from `optax` 169 """ 170 return chain(adam(**opt_options), reduce_on_plateau(**rop_options)) 171 172 @eqx.filter_jit 173 def _make_step( 174 self, 175 model: DynamicModel, 176 data: Dataset, 177 opt: GradientTransformation, 178 opt_state: OptState, 179 filter_spec: Any, 180 ) -> tuple[DynamicModel, OptState, float]: 181 """Make a training step. 
182 183 Args: 184 model (DynamicModel): the model to be trained 185 data (Dataset): the dataset object 186 opt (GradientTransformation): optimiser object 187 opt_state (OptState): optimiser state 188 filter_spec (Any): the filtering logic to determine which parts of the model to train 189 190 Returns: 191 tuple[DynamicModel, OptState, float]: trained model, optimiser state, loss value 192 """ 193 diff_model, static_model = eqx.partition(model, filter_spec) 194 195 loss_value, grads = eqx.filter_value_and_grad(self.loss_func)( 196 diff_model, static_model, *data 197 ) 198 updates, opt_state = opt.update(grads, opt_state, model, value=loss_value) 199 model = eqx.apply_updates(model, updates) 200 return model, loss_value, opt_state 201 202 def _train_epoch( 203 self, 204 model: DynamicModel, 205 dataset: Dataset, 206 batch_size: int, 207 opt: GradientTransformation, 208 opt_state: OptState, 209 filter_spec: Any, 210 ) -> tuple[DynamicModel, float, OptState]: 211 """Train the model for an epoch. 212 213 Args: 214 model (DynamicModel): the model to be trained 215 dataset (Dataset): the dataset object 216 batch_size (int): the batch size 217 opt (GradientTransformation): the optimiser object 218 opt_state (OptState): the optimiser state 219 filter_spec (Any): the filtering logic to determine which parts of the model to train 220 221 Returns: 222 tuple[DynamicModel, float, OptState]: trained model, loss value, optimiser state 223 """ 224 225 step_losses = [] 226 227 for batch in tqdm( 228 dataset.iter(batch_size), 229 total=dataset.num_rows // batch_size, 230 ): 231 data_batch = (batch["t"], batch["x"], batch["args"]) 232 model, train_loss, opt_state = self._make_step( 233 model, data_batch, opt, opt_state, filter_spec 234 ) 235 step_losses.append(train_loss) 236 epoch_loss = jnp.mean(jnp.array(step_losses)) 237 return model, epoch_loss, opt_state 238 239 def train( 240 self, 241 model: DynamicModel, 242 dataset: Dataset, 243 num_epochs: int, 244 batch_size: int, 245 logger: Optional[Logger] = None, 246 opt_state: Optional[OptState] = None, 247 filter_spec: Optional[Any] = None, 248 checkpoint_dir: Optional[str] = None, 249 checkpoint_every: Optional[int] = None, 250 ) -> tuple[DynamicModel, list[float], OptState]: 251 """The main training routine. 252 253 Args: 254 model (DynamicModel): the model to be trained 255 dataset (Dataset): the dataset 256 num_epochs (int): number of epochs to train 257 batch_size (int): the batch size 258 logger (Optional[Logger], optional): the logging object. Defaults to None. 259 opt_state (Optional[OptState], optional): the starting optimiser state. Defaults to None. 260 filter_spec (Optional[Any], optional): the filtering logic. Defaults to None. 261 checkpoint_dir (Optional[str], optional): the directory to save checkpoints. Defaults to None. 262 checkpoint_every (Optional[int], optional): checkpoints are saved every `checkpoint_every` number of epochs. Defaults to None. 
263 264 Returns: 265 tuple[DynamicModel, list[float], OptState]: trained model, list of losses, optimiser state 266 """ 267 opt = self._make_optimiser(self._opt_options, self._rop_options) 268 if opt_state is None: 269 opt_state = opt.init(eqx.filter(model, eqx.is_array)) 270 271 if filter_spec is None: 272 filter_spec = tree_map(lambda _: True, model) 273 274 losses = [] 275 for epoch in range(num_epochs): 276 model, step_loss, opt_state = self._train_epoch( 277 model=model, 278 dataset=dataset, 279 batch_size=batch_size, 280 opt=opt, 281 opt_state=opt_state, 282 filter_spec=filter_spec, 283 ) 284 losses.append(step_loss) 285 if logger: 286 lr_scale = tree_get(opt_state, "scale") 287 logger.info( 288 f"epoch={epoch:05d}, loss={step_loss:.6f}, lr_scale={lr_scale:.4f}" 289 ) 290 291 if checkpoint_dir is not None and epoch % checkpoint_every == 0: 292 model_path = os.path.join( 293 checkpoint_dir, f"model_epoch_{epoch:05d}.eqx" 294 ) 295 eqx.tree_serialise_leaves(model_path, model) 296 297 return model, losses, opt_state 298 299 300class MLETrainer(SDETrainer): 301 302 @eqx.filter_jit 303 def loss_func( 304 self, 305 diff_model: DynamicModel, 306 static_model: DynamicModel, 307 t: ArrayLike, 308 x: ArrayLike, 309 args: ArrayLike, 310 ) -> float: 311 """The MLE loss function. 312 313 See [`MLELoss`](./_losses.html#MLELoss) for more details. 314 315 Args: 316 diff_model (DynamicModel): the trainable part of the model 317 static_model (DynamicModel): the static part of the model 318 t (ArrayLike): time 319 x (ArrayLike): state 320 args (ArrayLike): additional arguments or parameters. 321 322 Returns: 323 float: the computed loss 324 """ 325 model = eqx.combine(diff_model, static_model) 326 return MLELoss()(model, t, x, args) 327 328 329class ClosureMLETrainer(MLETrainer): 330 331 @eqx.filter_jit 332 def loss_func( 333 self, 334 diff_model: DynamicModel, 335 static_model: DynamicModel, 336 t: ArrayLike, 337 x: ArrayLike, 338 args: ArrayLike, 339 ) -> float: 340 r"""The combined loss function for MLE training with closure modelling. 341 342 The losses are applied to the combine model 343 `model = eqx.combine(diff_model, static_model)` 344 with three parts 345 - [`MLELoss`](./_losses.html#MLELoss) applied to `model.sde` 346 - [`ReconLoss`](./_losses.html#ReconLoss) applied to the `model` 347 - [`CompareLoss`](./_losses.html#CompareLoss) applied to the `model` 348 349 The combined loss is given by 350 $$ 351 \text{loss} = \text{loss_sde} 352 + \text{recon_weight} \times \text{loss_recon} 353 + \text{compare_weight} \times \text{loss_compare} 354 $$ 355 356 The variables `recon_weight` and `compare_weight` are set 357 in the `loss_options` attribute. 358 359 Args: 360 diff_model (DynamicModel): the trainable part of the model 361 static_model (DynamicModel): the static part of the model 362 t (ArrayLike): time 363 x (ArrayLike): state 364 args (ArrayLike): additional arguments or parameters. 365 366 Returns: 367 float: the computed loss 368 """ 369 model = eqx.combine(diff_model, static_model) 370 z = jax.vmap(jax.vmap(model.encoder))(x) 371 loss_sde = MLELoss()(model.sde, t, z, args) 372 loss_recon = ReconLoss()(model, x) 373 loss_compare = CompareLoss()(model, x) 374 return ( 375 loss_sde 376 + self._loss_options["recon_weight"] * loss_recon 377 + self._loss_options["compare_weight"] * loss_compare 378 )
## class SDETrainer(ABC)
Base class for training SDE models.
### SDETrainer.__init__(opt_options: dict, rop_options: dict, loss_options: Optional[dict] = None)
SDE training routine.
Arguments:
- opt_options (dict): dictionary of options for the optimiser
- rop_options (dict): dictionary of options for the reduce-on-plateau callback
- loss_options (Optional[dict], optional): dictionary of options for loss computation. Defaults to None.
### SDETrainer.loss_func(diff_model, static_model, t, x, args) -> float
Loss function.
This must be implemented by sub-classes.
Arguments:
- diff_model (DynamicModel): the trainable part of the model
- static_model (DynamicModel): the static part of the model
- t (ArrayLike): time
- x (ArrayLike): state
- args (ArrayLike): additional arguments or parameters. The first dimension is temperature.
Returns:
float: loss value
### SDETrainer.train(model, dataset, num_epochs, batch_size, logger=None, opt_state=None, filter_spec=None, checkpoint_dir=None, checkpoint_every=None)
The main training routine.
Arguments:
- model (DynamicModel): the model to be trained
- dataset (Dataset): the dataset
- num_epochs (int): number of epochs to train
- batch_size (int): the batch size
- logger (Optional[Logger], optional): the logging object. Defaults to None.
- opt_state (Optional[OptState], optional): the starting optimiser state. Defaults to None.
- filter_spec (Optional[Any], optional): the filtering logic. Defaults to None.
- checkpoint_dir (Optional[str], optional): the directory to save checkpoints. Defaults to None.
- checkpoint_every (Optional[int], optional): checkpoints are saved every `checkpoint_every` epochs. Defaults to None.
Returns:
tuple[DynamicModel, list[float], OptState]: trained model, list of losses, optimiser state
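For instance, continuing the example from the module overview, periodic checkpointing might be used
as follows. The directory path and frequency are illustrative, and the checkpoint directory is
assumed to exist already.

```python
sde, losses, opt_state = trainer.train(
    model=sde,
    dataset=dataset,
    num_epochs=100,
    batch_size=2,
    checkpoint_dir="checkpoints",  # illustrative path; assumed to exist
    checkpoint_every=10,           # serialise the model every 10 epochs
)
```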
## class MLETrainer(SDETrainer)
Trainer that estimates the SDE drift and diffusion using the maximum likelihood loss `MLELoss`.
### MLETrainer.loss_func(diff_model, static_model, t, x, args) -> float
The MLE loss function.
See `MLELoss` for more details.
Arguments:
- diff_model (DynamicModel): the trainable part of the model
- static_model (DynamicModel): the static part of the model
- t (ArrayLike): time
- x (ArrayLike): state
- args (ArrayLike): additional arguments or parameters.
Returns:
float: the computed loss
## class ClosureMLETrainer(MLETrainer)
Trainer that combines `MLELoss` with the closure losses `ReconLoss` and `CompareLoss`.
### ClosureMLETrainer.loss_func(diff_model, static_model, t, x, args) -> float
The combined loss function for MLE training with closure modelling.
The losses are applied to the combined model `model = eqx.combine(diff_model, static_model)` with three parts:

- `MLELoss` applied to `model.sde`
- `ReconLoss` applied to the `model`
- `CompareLoss` applied to the `model`

The combined loss is given by

$$
\text{loss} = \text{loss}_{\text{sde}}
+ \text{recon\_weight} \times \text{loss}_{\text{recon}}
+ \text{compare\_weight} \times \text{loss}_{\text{compare}}
$$

The variables `recon_weight` and `compare_weight` are set in the `loss_options` attribute.
Arguments:
- diff_model (DynamicModel): the trainable part of the model
- static_model (DynamicModel): the static part of the model
- t (ArrayLike): time
- x (ArrayLike): state
- args (ArrayLike): additional arguments or parameters.
Returns:
float: the computed loss
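For example, a closure trainer might be constructed as follows; the option values are illustrative,
not recommended defaults.

```python
from onsagernet.trainers import ClosureMLETrainer

trainer = ClosureMLETrainer(
    opt_options={"learning_rate": 1e-3},          # passed to optax.adam
    rop_options={"patience": 10, "factor": 0.5},  # passed to optax.contrib.reduce_on_plateau
    loss_options={"recon_weight": 1.0, "compare_weight": 1.0},  # weights used in the combined loss
)
```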