Skip to content

Loss

The loss module provides standard and dynamic loss functions for training.

Dynamic Loss

The DynamicLoss class implements the variance-weighted loss used in Samudra 2. It maintains per-channel scaling weights updated via an exponential moving average of inverse prediction error:

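  • Each channel's weight is proportional to the inverse of its loss, so channels with lower error receive higher weight and every channel contributes roughly equally to the total loss; this prevents the model from neglecting slow-evolving deep-ocean fields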
  • A configurable limit parameter (default: 20) caps the ratio between the largest and smallest per-step target weight (weights are clamped to [min, min × limit]) to prevent extreme imbalance
  • Smooths the scale estimates with an exponential moving average whose effective window is about 25 steps (N_WINDOW)

Configuration:

# Standard MSE (Samudra v1)
loss: mse

# Dynamic variance-weighted loss (Samudra 2)
loss:
  type: dynamic
  metric: mse
  limit: 20

API Reference

ocean_emulators.utils.loss

DynamicLoss(loss_fn, *, limit, device, num_channels)

A loss function that scales each channel to contribute equally to the loss.

This uses a rolling estimate of the loss of each channel to scale each channel's loss, discouraging the model from focusing on only a few channels.

See: https://openathena.slack.com/archives/C08CYM42DT3/p1752275713570969

Source code in src/ocean_emulators/utils/loss.py
def __init__(
    self,
    loss_fn: LossFnWithContext,
    *,
    limit: float | None,
    device: torch.device,
    num_channels: int,
):
    """Create a dynamic per-channel loss wrapper around ``loss_fn``.

    Args:
        loss_fn: Underlying loss function; stored as ``self.loss_fn``.
        limit: Maximum allowed ratio between the largest and smallest
            per-channel target weight computed in ``update``, or ``None``
            to disable clamping. Must be at least 1 when given.
        device: Device on which the per-channel scale tensor is kept.
        num_channels: Number of channels; the scale starts as a vector of
            ones of this length.

    Raises:
        ValueError: If ``limit`` is not ``None`` and is smaller than 1.
    """
    # ``update`` clamps the target weights to ``[min, min * limit]``. A limit
    # below 1 makes that interval empty, and a non-positive limit would give
    # zero or negative channel weights, so reject such values early.
    if limit is not None and limit < 1:
        raise ValueError(f"limit must be >= 1 or None, got {limit!r}")
    self.loss_fn = loss_fn
    self._device = device
    # Start from uniform weights; ``update`` moves them towards 1 / loss.
    self._per_channel_scale: Float[torch.Tensor, " var"] = torch.ones(
        num_channels, device=self._device
    )
    self._limit = limit
    self._limit = limit

N_WINDOW = 25 class-attribute instance-attribute

Rolling window size to average over. (~number of steps)

update(loss_per_channel)

Given the unscaled per-channel loss, update the per-channel scale.

Source code in src/ocean_emulators/utils/loss.py
def update(
    self,
    loss_per_channel: Float[torch.Tensor, " hist*var"],
) -> None:
    """Given the unscaled per-channel loss, update the per-channel scale.

    The new target weight of each channel is computed as follows:
      1. Take ``1 / loss`` and average it over the ``hist`` entries.
      2. Average it across ranks when more than one process is running.
      3. If ``limit`` is set, clamp it to at most ``limit`` times the
         smallest target weight.

    The running scale is then moved towards this target with an
    exponential moving average of weight ``1 / N_WINDOW``.

    Args:
        loss_per_channel: Per-channel losses flattened over history and
            channels; its number of elements must be a multiple of the
            number of channels. It is detached, so no gradient flows
            through this update.
    """
    # Local import is needed to prevent a circular import error.
    from ocean_emulators.utils.distributed import all_reduce_mean, get_world_size

    loss = loss_per_channel.detach()
    # Work in at least float32. In float16 the 1e-8 floor below rounds to 0
    # (so 1 / loss becomes inf), and 1 / loss overflows for any loss below
    # ~1.5e-5. float64 inputs keep their precision.
    loss = loss.to(torch.promote_types(loss.dtype, torch.float32))
    # Guard against division by zero for channels with exactly zero loss.
    loss = torch.where(loss == 0, 1e-8, loss)
    new_target_weights_with_history: Float[torch.Tensor, " hist*var"] = 1.0 / loss
    # Reshape from channels * history to channels
    # by averaging along the `hist` dimension
    new_target_weights: Float[torch.Tensor, " var"] = (
        new_target_weights_with_history.reshape(
            -1, self._per_channel_scale.shape[0]
        ).mean(dim=0)
    )

    if get_world_size() > 1:
        # NOTE(review): the return value is ignored, so this assumes
        # all_reduce_mean averages the tensor in place — confirm.
        all_reduce_mean(new_target_weights)

    if self._limit is not None:
        # Cap the ratio between the largest and smallest target weight.
        min_scale = new_target_weights.min()
        max_scale = min_scale * self._limit
        new_target_weights = new_target_weights.clamp(min_scale, max_scale)

    # Exponential moving average with an effective window of ~N_WINDOW steps.
    self._per_channel_scale = (
        self._per_channel_scale * (DynamicLoss.N_WINDOW - 1) + new_target_weights
    ) / DynamicLoss.N_WINDOW

state_dict()

Return state dictionary for checkpointing.

Source code in src/ocean_emulators/utils/loss.py
def state_dict(self) -> dict[str, torch.Tensor]:
    """Return state dictionary for checkpointing.

    The only entry is the per-channel scale, detached from any autograd
    graph and placed on the CPU.
    """
    scale_on_cpu = self._per_channel_scale.detach().cpu()
    return {"per_channel_scale": scale_on_cpu}

load_state_dict(state)

Load state from state_dict.

Source code in src/ocean_emulators/utils/loss.py
def load_state_dict(self, state: dict[str, torch.Tensor]) -> None:
    """Load state from ``state_dict``.

    Restores the per-channel scale (moved to this loss's device) if
    ``state`` contains it. A missing key leaves the current scale
    untouched, so state dicts without the entry can still be loaded.

    Raises:
        ValueError: If the stored scale's shape differs from the current
            scale's shape, e.g. a checkpoint made with a different number
            of channels. Without this check, ``update`` would silently
            reshape losses using the wrong channel count.
    """
    if "per_channel_scale" not in state:
        return
    loaded = state["per_channel_scale"]
    expected_shape = self._per_channel_scale.shape
    if loaded.shape != expected_shape:
        raise ValueError(
            "per_channel_scale shape mismatch: checkpoint has "
            f"{tuple(loaded.shape)}, expected {tuple(expected_shape)}"
        )
    self._per_channel_scale = loaded.to(self._device)

GradientLoss(loss_fn, *, gradient_weight, pad_mode)

Combine a base loss with a gradient matching penalty.

Applies the provided per-channel loss metric, then adds an L1 penalty on spatial gradients scaled by gradient_weight.

Source code in src/ocean_emulators/utils/loss.py
def __init__(
    self,
    loss_fn: LossFnWithContext,
    *,
    gradient_weight: float,
    pad_mode: str,
):
    """Wrap ``loss_fn`` with a spatial-gradient matching penalty.

    Args:
        loss_fn: Base per-channel loss function; stored as ``self.loss_fn``.
        gradient_weight: Multiplier applied to the gradient penalty term.
        pad_mode: Padding mode, presumably forwarded to the spatial-gradient
            computation (the call site is not shown here — confirm).
    """
    # Private configuration first, then the public wrapped callable.
    self._pad_mode = pad_mode
    self._gradient_weight = gradient_weight
    self.loss_fn = loss_fn

decomposed_mse(pred, target)

Standard MSE loss (l2) computed per channel.

Source code in src/ocean_emulators/utils/loss.py
def decomposed_mse(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Standard MSE loss (l2) computed per channel.

    Element-wise squared errors are averaged over dimensions 0, 2 and 3,
    leaving one value per entry of dimension 1.
    """
    reduce_dims = (0, 2, 3)
    squared_error = F.mse_loss(pred, target, reduction="none")
    return squared_error.mean(dim=reduce_dims)

decomposed_mae(pred, target)

Standard MAE loss (l1) computed per channel.

Source code in src/ocean_emulators/utils/loss.py
def decomposed_mae(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Standard MAE loss (l1) computed per channel.

    Element-wise absolute errors are averaged over dimensions 0, 2 and 3,
    leaving one value per entry of dimension 1.
    """
    reduce_dims = (0, 2, 3)
    absolute_error = F.l1_loss(pred, target, reduction="none")
    return absolute_error.mean(dim=reduce_dims)

decomposed_mse_diff_weighted(pred, target)

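MSE loss with weighted differences: the first channel uses plain MSE, and each later channel uses the MSE of its difference from the previous channel, multiplied by a weight factor (2.0).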

Source code in src/ocean_emulators/utils/loss.py
def decomposed_mse_diff_weighted(
    pred: torch.Tensor, target: torch.Tensor, *, diff_weight: float = 2.0
) -> torch.Tensor:
    """MSE loss with weighted differences.

    Entry 0 of dimension 1 contributes its plain squared error. Every later
    entry ``i`` contributes the squared error of the neighbour difference
    ``x[:, i] - x[:, i - 1]`` (taken on both ``pred`` and ``target``),
    multiplied by ``diff_weight``. The result is averaged over dimensions
    0, 2 and 3, giving one value per entry of dimension 1.

    Args:
        pred: Prediction tensor.
        target: Target tensor.
        diff_weight: Multiplier on the difference term. Defaults to 2.0,
            the value that was previously hard-coded.
    """
    # Only entry 0 of dim 1 uses plain MSE, so compute it for that slice
    # alone instead of computing MSE everywhere and discarding the rest.
    first_mse = F.mse_loss(pred[:, :1], target[:, :1], reduction="none")

    # Squared error of neighbour differences along dim 1, weighted.
    diff_mse = (
        F.mse_loss(
            pred[:, 1:] - pred[:, :-1], target[:, 1:] - target[:, :-1], reduction="none"
        )
        * diff_weight
    )

    # Combine losses
    combined_loss = torch.cat([first_mse, diff_mse], dim=1)
    return combined_loss.mean(dim=(0, 2, 3))

decomposed_mse_scaled(pred, target, scaling)

MSE loss with scaled residuals.

Source code in src/ocean_emulators/utils/loss.py
def decomposed_mse_scaled(
    pred: torch.Tensor, target: torch.Tensor, scaling: torch.Tensor
) -> torch.Tensor:
    """MSE loss with scaled residuals.

    ``scaling`` is reshaped to ``(1, -1, 1, 1)`` so that it multiplies
    dimension 1 of both ``pred`` and ``target`` before the squared error is
    taken. The result is averaged over dimensions 0, 2 and 3.
    """
    per_channel = scaling.view(1, -1, 1, 1)
    squared_error = F.mse_loss(
        pred * per_channel, target * per_channel, reduction="none"
    )
    return squared_error.mean(dim=(0, 2, 3))

decomposed_mse_mae(pred, target)

Combined MSE and MAE loss.

Source code in src/ocean_emulators/utils/loss.py
def decomposed_mse_mae(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Combined MSE and MAE loss.

    Each element's squared and absolute errors are averaged, then the
    result is reduced over dimensions 0, 2 and 3 (one value per entry of
    dimension 1).
    """
    squared, absolute = (
        F.mse_loss(pred, target, reduction="none"),
        F.l1_loss(pred, target, reduction="none"),
    )
    averaged = (squared + absolute) / 2
    return averaged.mean(dim=(0, 2, 3))

gradient_l1_loss(pred, target, pad_mode)

L1 loss on spatial gradients, averaged per channel.

Source code in src/ocean_emulators/utils/loss.py
def gradient_l1_loss(
    pred: torch.Tensor, target: torch.Tensor, pad_mode: str
) -> torch.Tensor:
    """L1 loss on spatial gradients, averaged per channel.

    The y- and x-gradients of ``pred`` and ``target`` (from
    ``_spatial_gradients``) are compared with an element-wise absolute
    error. Each is averaged over dimensions 0, 2 and 3, and the two
    per-channel results are averaged together.
    """
    reduce_dims = (0, 2, 3)
    pred_y, pred_x = _spatial_gradients(pred, pad_mode=pad_mode)
    target_y, target_x = _spatial_gradients(target, pad_mode=pad_mode)

    y_term = F.l1_loss(pred_y, target_y, reduction="none").mean(dim=reduce_dims)
    x_term = F.l1_loss(pred_x, target_x, reduction="none").mean(dim=reduce_dims)
    return (y_term + x_term) / 2

decomposed_mae_gradient_weighted(pred, target, gradient_weight, pad_mode='constant')

MAE loss with spatial gradient matching penalty.

Source code in src/ocean_emulators/utils/loss.py
def decomposed_mae_gradient_weighted(
    pred: torch.Tensor,
    target: torch.Tensor,
    gradient_weight: float,
    pad_mode: str = "constant",
) -> torch.Tensor:
    """MAE loss with spatial gradient matching penalty.

    Returns the per-channel MAE plus ``gradient_weight`` times the
    per-channel L1 loss on spatial gradients computed by
    ``gradient_l1_loss`` with the given ``pad_mode``.
    """
    base = decomposed_mae(pred, target)
    penalty = gradient_weight * gradient_l1_loss(pred, target, pad_mode)
    return base + penalty