Loss
The loss module provides standard and dynamic loss functions for training.
Dynamic Loss
The DynamicLoss class implements the variance-weighted loss used in Samudra 2. It maintains per-channel scaling weights updated via an exponential moving average of inverse prediction error:
- Channels with lower error receive higher weight (weights scale with the inverse error), so no channel's contribution is dwarfed by the others; this prevents the model from neglecting slow-evolving deep-ocean fields
- A configurable limit parameter (default: 20) clamps the maximum ratio between channel weights to prevent extreme imbalance
- Uses a rolling window of 25 steps (N_WINDOW) to smooth the scale estimates
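The weighting described above can be written as follows. This is a sketch under assumptions the source does not spell out: the moving average is an EMA with decay α = 1/N_WINDOW applied to the per-channel loss ℓ_c and then inverted (the text can also be read as averaging the inverse error directly), limit caps each weight at limit times the smallest weight, and the weights are normalised to average 1 over the C channels.

```latex
\begin{aligned}
\bar{\ell}_c^{(t)} &= (1-\alpha)\,\bar{\ell}_c^{(t-1)} + \alpha\,\ell_c^{(t)}, \qquad \alpha = \frac{1}{N_{\mathrm{window}}} = \frac{1}{25} \\
\tilde{w}_c &= \frac{1}{\bar{\ell}_c^{(t)}} \\
w_c &= \frac{\min\!\left(\tilde{w}_c,\ \mathrm{limit}\cdot\min_k \tilde{w}_k\right)}{\frac{1}{C}\sum_{j=1}^{C}\min\!\left(\tilde{w}_j,\ \mathrm{limit}\cdot\min_k \tilde{w}_k\right)} \\
\mathcal{L}^{(t)} &= \frac{1}{C}\sum_{c=1}^{C} w_c\,\ell_c^{(t)}
\end{aligned}
```

For two channels with running losses 1.0 and 0.01 and limit = 20, the raw weights are 1 and 100; capping at 20 × 1 gives 1 and 20, and normalising gives 2/21 and 40/21, so the low-error channel is weighted exactly 20 times higher.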
Configuration:
# Standard MSE (Samudra v1)
loss: mse

# Dynamic variance-weighted loss (Samudra 2)
loss:
  type: dynamic
  metric: mse
  limit: 20
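The loss key takes either a bare metric name or a mapping of options. A minimal sketch of normalising the two forms after the YAML has been loaded into plain Python objects (for example with PyYAML's yaml.safe_load). LossSpec and parse_loss_config are hypothetical names, and defaulting metric to "mse" when it is omitted is an assumption; only the limit default of 20 comes from this page.

```python
from dataclasses import dataclass
from typing import Optional, Union


@dataclass(frozen=True)
class LossSpec:
    """Hypothetical normalised view of the `loss:` config entry."""

    kind: str                      # e.g. "mse" or "dynamic"
    metric: str                    # per-channel metric, e.g. "mse"
    limit: Optional[float] = None  # only meaningful for the dynamic loss


def parse_loss_config(entry: Union[str, dict]) -> LossSpec:
    if isinstance(entry, str):
        # Shorthand form, e.g. `loss: mse`: a plain, unweighted metric.
        return LossSpec(kind=entry, metric=entry)
    kind = entry["type"]
    if kind == "dynamic":
        # Mapping form; fall back to the documented limit default (20) when omitted.
        return LossSpec(
            kind=kind,
            metric=entry.get("metric", "mse"),  # assumed default
            limit=float(entry.get("limit", 20)),
        )
    return LossSpec(kind=kind, metric=entry.get("metric", kind))


# Usage with the two documented forms, as a YAML loader would return them.
print(parse_loss_config("mse"))
print(parse_loss_config({"type": "dynamic", "metric": "mse", "limit": 20}))
```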
API Reference
ocean_emulators.utils.loss
DynamicLoss(loss_fn, *, limit, device, num_channels)
A loss function that scales each channel to contribute equally to the loss.
This uses a rolling estimate of the loss of each channel to scale each channel's loss, discouraging the model from focusing on only a few channels.
See: https://openathena.slack.com/archives/C08CYM42DT3/p1752275713570969
Source code in src/ocean_emulators/utils/loss.py
N_WINDOW = 25 (class attribute, instance attribute)
Rolling window size to average over, measured approximately in training steps.
update(loss_per_channel)
Given the unscaled per-channel loss, update the per-channel scale.
state_dict()
load_state_dict(state)
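update, state_dict and load_state_dict suggest that DynamicLoss carries running state between steps, and that this state must be checkpointed so a resumed run keeps its channel weights. A NumPy sketch of that lifecycle under explicit assumptions: DynamicLossSketch and its scaled_loss helper are hypothetical, the decay 1/N_WINDOW, the cap at limit times the smallest weight, the mean-1 normalisation and the state keys are guesses, and the real class's loss_fn and device arguments are omitted.

```python
import numpy as np


class DynamicLossSketch:
    """Hypothetical stand-in for DynamicLoss; not the module's implementation."""

    N_WINDOW = 25  # ~steps; using alpha = 1 / N_WINDOW as an EMA decay is an assumption

    def __init__(self, num_channels: int, limit: float = 20.0):
        self.limit = limit
        self.running_loss = None            # EMA of each channel's unscaled loss
        self.scale = np.ones(num_channels)  # uniform weights until the first update

    def update(self, loss_per_channel) -> None:
        """Fold one step's unscaled per-channel loss into the running estimate."""
        loss = np.asarray(loss_per_channel, dtype=float)
        if self.running_loss is None:
            self.running_loss = loss.copy()  # seed with the first observation
        else:
            alpha = 1.0 / self.N_WINDOW
            self.running_loss = (1.0 - alpha) * self.running_loss + alpha * loss
        raw = 1.0 / np.maximum(self.running_loss, 1e-12)  # inverse error, guarded against 0
        capped = np.minimum(raw, self.limit * raw.min())   # max/min weight ratio <= limit
        self.scale = capped / capped.mean()               # weights average to 1

    def scaled_loss(self, loss_per_channel) -> float:
        """Hypothetical helper: combine per-channel losses with the current weights."""
        return float(np.mean(self.scale * np.asarray(loss_per_channel, dtype=float)))

    def state_dict(self) -> dict:
        # Without the running estimate, a resumed run would restart from uniform weights.
        return {
            "running_loss": None if self.running_loss is None else self.running_loss.copy(),
            "scale": self.scale.copy(),
        }

    def load_state_dict(self, state: dict) -> None:
        rl = state["running_loss"]
        self.running_loss = None if rl is None else np.array(rl, dtype=float)
        self.scale = np.array(state["scale"], dtype=float)


# Usage: two channels whose losses differ by 100x.
dl = DynamicLossSketch(num_channels=2, limit=20)
dl.update([1.0, 0.01])
print(dl.scale)  # low-error channel weighted 20x higher (capped by limit)

# Checkpoint round trip, e.g. when resuming training.
resumed = DynamicLossSketch(num_channels=2, limit=20)
resumed.load_state_dict(dl.state_dict())
```

Capping relative to the smallest weight targets the failure mode where a near-constant field with almost zero error would otherwise receive an unbounded weight.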
GradientLoss(loss_fn, *, gradient_weight, pad_mode)
Combine a base loss with a gradient matching penalty.
Applies the provided per-channel loss metric, then adds an L1 penalty on spatial gradients, scaled by gradient_weight.
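A NumPy sketch of this pattern, assuming inputs shaped (batch, channel, lat, lon), forward differences along the last two axes, and a pad_mode forwarded to the padding step (here np.pad) so the gradients keep the input's shape. GradientLossSketch, spatial_diffs and gradient_l1 are hypothetical names, summing the two directional penalties is an assumption, and the module itself presumably operates on PyTorch tensors. The gradient_l1_loss and decomposed_mae_gradient_weighted entries further down appear to be functional forms of the same idea, the latter with MAE as the base loss.

```python
from typing import Callable

import numpy as np

PerChannelLoss = Callable[[np.ndarray, np.ndarray], np.ndarray]


def _channel_mean(x: np.ndarray) -> np.ndarray:
    # Average over every axis except the channel axis (assumed to be axis 1).
    return x.mean(axis=tuple(i for i in range(x.ndim) if i != 1))


def spatial_diffs(x: np.ndarray, pad_mode: str = "constant"):
    """Forward differences along the last two axes, padded back to x's shape."""
    lead = [(0, 0)] * (x.ndim - 2)
    dy = np.pad(np.diff(x, axis=-2), lead + [(0, 1), (0, 0)], mode=pad_mode)
    dx = np.pad(np.diff(x, axis=-1), lead + [(0, 0), (0, 1)], mode=pad_mode)
    return dy, dx


def gradient_l1(pred: np.ndarray, target: np.ndarray, pad_mode: str = "constant") -> np.ndarray:
    """Per-channel L1 mismatch between predicted and target spatial gradients."""
    py, px = spatial_diffs(pred, pad_mode)
    ty, tx = spatial_diffs(target, pad_mode)
    return _channel_mean(np.abs(py - ty)) + _channel_mean(np.abs(px - tx))


class GradientLossSketch:
    """Hypothetical analogue of GradientLoss: base per-channel loss plus weighted gradient penalty."""

    def __init__(self, loss_fn: PerChannelLoss, *, gradient_weight: float, pad_mode: str = "constant"):
        self.loss_fn = loss_fn
        self.gradient_weight = gradient_weight
        self.pad_mode = pad_mode

    def __call__(self, pred: np.ndarray, target: np.ndarray) -> np.ndarray:
        base = self.loss_fn(pred, target)
        return base + self.gradient_weight * gradient_l1(pred, target, self.pad_mode)


def per_channel_mae(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    return _channel_mean(np.abs(pred - target))


# Usage: a zero prediction against a field that ramps 0, 1, 2 along the last axis.
pred = np.zeros((1, 1, 3, 3))
target = np.broadcast_to(np.arange(3.0), (1, 1, 3, 3)).copy()
print(GradientLossSketch(per_channel_mae, gradient_weight=0.5)(pred, target))                   # 1.0 + 0.5 * 2/3
print(GradientLossSketch(per_channel_mae, gradient_weight=0.5, pad_mode="edge")(pred, target))  # 1.0 + 0.5 * 1
```

In this sketch pad_mode only changes the boundary: "constant" gives the padded last row and column a zero gradient, while "edge" repeats the final difference, so boundary gradients count fully.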
decomposed_mse(pred, target)
Standard MSE (L2) loss computed per channel.
decomposed_mae(pred, target)
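The "decomposed" losses appear to return one value per channel rather than a scalar, presumably the kind of input DynamicLoss.update(loss_per_channel) consumes. A NumPy sketch of the MSE and MAE variants, assuming a (batch, channel, lat, lon) layout with channels on axis 1 and ignoring any land or ocean mask the real implementation may apply; the _sketch names are hypothetical.

```python
import numpy as np


def _reduce_except_channel(x: np.ndarray) -> np.ndarray:
    # Mean over every axis except axis 1 (assumed to be the channel axis).
    return x.mean(axis=tuple(i for i in range(x.ndim) if i != 1))


def decomposed_mse_sketch(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Per-channel mean squared error (L2)."""
    return _reduce_except_channel((pred - target) ** 2)


def decomposed_mae_sketch(pred: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Per-channel mean absolute error (L1)."""
    return _reduce_except_channel(np.abs(pred - target))


# Usage: three channels whose errors are 1, 2 and 3 everywhere.
target = np.zeros((2, 3, 4, 4))
pred = target + np.array([1.0, 2.0, 3.0])[None, :, None, None]
print(decomposed_mse_sketch(pred, target))  # [1. 4. 9.]
print(decomposed_mae_sketch(pred, target))  # [1. 2. 3.]
```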
decomposed_mse_diff_weighted(pred, target)
MSE loss with weighted differences.
decomposed_mse_scaled(pred, target, scaling)
MSE loss with scaled residuals.
decomposed_mse_mae(pred, target)
Combined MSE and MAE loss.
gradient_l1_loss(pred, target, pad_mode)
L1 loss on spatial gradients, averaged per channel.
decomposed_mae_gradient_weighted(pred, target, gradient_weight, pad_mode='constant')
MAE loss with spatial gradient matching penalty.