
Samudra

The Samudra model uses a ConvNeXt U-Net backbone for autoregressive ocean emulation. Samudra (v1) and Samudra 2 (v2) are implemented by the same Samudra class; their architectural differences are driven by configuration.
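The backbone's basic unit is a ConvNeXt block. To show roughly where the expansion factor (varied between versions in the table below) enters, here is a minimal generic ConvNeXt-style block in PyTorch. It follows the published ConvNeXt layout (depthwise convolution, normalization, a pointwise MLP whose hidden width is expansion × dim, GELU, residual connection) and is not taken from the Samudra source; the class name `ConvNeXtBlockSketch`, the 7×7 kernel, and the GroupNorm stand-in for LayerNorm are assumptions.

```python
import torch
from torch import nn


class ConvNeXtBlockSketch(nn.Module):
    """Generic ConvNeXt-style block (hypothetical, not the Samudra implementation)."""

    def __init__(self, dim: int, expansion: int = 4, kernel_size: int = 7):
        super().__init__()
        # Depthwise spatial mixing: one filter per channel (groups=dim).
        self.dwconv = nn.Conv2d(dim, dim, kernel_size, padding=kernel_size // 2, groups=dim)
        # Stand-in normalization over channels and space (assumption).
        self.norm = nn.GroupNorm(1, dim)
        # Pointwise MLP: the expansion factor sets the hidden width.
        self.pw1 = nn.Conv2d(dim, expansion * dim, kernel_size=1)
        self.act = nn.GELU()
        self.pw2 = nn.Conv2d(expansion * dim, dim, kernel_size=1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Residual connection around the whole block keeps the input shape.
        return x + self.pw2(self.act(self.pw1(self.norm(self.dwconv(x)))))


# A Samudra 2-like first stage: width 280, expansion factor 2.
block = ConvNeXtBlockSketch(280, expansion=2)
x = torch.randn(1, 280, 16, 32)  # (batch, channels, lat, lon)
y = block(x)
```

Halving the expansion factor halves the hidden width of the pointwise MLP for a given stage width, which is the knob the two configurations trade against stage width.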

Samudra v1 vs Samudra 2

                            Samudra (v1)           Samudra 2 (v2)
Channel widths              [200, 250, 300, 400]   [280, 380, 480, 520]
ConvNeXt expansion factor   4                      2
Upsampling                  Bilinear               Zonally periodic
Loss                        MSE                    Dynamic variance-weighted MSE (limit: 20)

Resolutions: 1°, 1/2°, 1/4°
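The channel widths and expansion factors in the table make the trade-off concrete. The short calculation below compares per-stage widening, the hidden width of each block's pointwise MLP, and that MLP's weight count per block. It assumes a standard ConvNeXt MLP of two pointwise layers (w → e·w → w) and ignores biases; it is arithmetic on the table values, not a parameter count of the actual model.

```python
v1_widths, v1_expansion = [200, 250, 300, 400], 4
v2_widths, v2_expansion = [280, 380, 480, 520], 2

# Per-stage widening factor of Samudra 2 over Samudra v1.
ratios = [round(b / a, 2) for a, b in zip(v1_widths, v2_widths)]

# Hidden width of each block's pointwise MLP = stage width x expansion factor.
v1_hidden = [w * v1_expansion for w in v1_widths]
v2_hidden = [w * v2_expansion for w in v2_widths]

# Weights in the two pointwise layers (w*e*w + e*w*w), biases ignored.
v1_mlp = [2 * v1_expansion * w * w for w in v1_widths]
v2_mlp = [2 * v2_expansion * w * w for w in v2_widths]

print(ratios)     # [1.4, 1.52, 1.6, 1.3]
print(v1_hidden)  # [800, 1000, 1200, 1600]
print(v2_hidden)  # [560, 760, 960, 1040]
print(v1_mlp)     # [320000, 500000, 720000, 1280000]
print(v2_mlp)     # [313600, 577600, 921600, 1081600]
```

Under these assumptions, Samudra 2's stage feature maps are 30% to 60% wider, yet each block's MLP hidden layer is narrower at every stage and its weight count stays within roughly 30% of v1's.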

Samudra 2 widens every U-Net stage by 30% to 60% (about 45% on average) while halving the ConvNeXt expansion factor from 4 to 2, shifting capacity from the block-internal MLPs toward the inter-stage features. The dynamic loss weights each channel's MSE by the inverse of that channel's running prediction error, amplifying the gradient signal from slow-evolving deep-ocean fields.
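One way to realize inverse-error channel weighting is sketched below in PyTorch. Only the idea of weighting per-channel MSE by the inverse of a running error, and the "limit: 20" entry in the table, come from this page. Tracking the running error as an exponential moving average, the `momentum` value, scaling so the noisiest channel gets weight 1, and reading the limit as a cap on each channel's weight are all assumptions; `DynamicWeightedMSE` is a hypothetical name.

```python
import torch


class DynamicWeightedMSE:
    """Hypothetical per-channel MSE weighted by the inverse of a running error."""

    def __init__(self, num_channels: int, momentum: float = 0.99, limit: float = 20.0):
        self.momentum = momentum
        self.limit = limit  # assumption: cap on any single channel's weight
        self.running_error = torch.ones(num_channels)  # assumed initial estimate

    def __call__(self, pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        # Per-channel MSE over batch and grid; tensors are (batch, channel, lat, lon).
        per_channel = ((pred - target) ** 2).mean(dim=(0, 2, 3))
        with torch.no_grad():
            # Match device/dtype, then update the moving average (no gradient flows here).
            self.running_error = self.running_error.to(per_channel)
            self.running_error.mul_(self.momentum).add_((1 - self.momentum) * per_channel)
            weights = 1.0 / self.running_error.clamp_min(1e-12)
            # Noisiest channel gets weight 1; low-error channels are boosted up to `limit`.
            weights = (weights / weights.min()).clamp(max=self.limit)
        return (weights * per_channel).mean()


loss_fn = DynamicWeightedMSE(num_channels=2, momentum=0.0, limit=20.0)
pred = torch.zeros(1, 2, 2, 2, requires_grad=True)
target = torch.ones(1, 2, 2, 2)
target[:, 1] = 0.1  # second channel has a 100x smaller squared error
loss = loss_fn(pred, target)
```

With `momentum=0.0` the running error equals the current per-channel MSE, [1.0, 0.01], so the raw inverse weights [1, 100] are capped to [1, 20] and the loss is (1·1.0 + 20·0.01) / 2 = 0.6.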

Configs:

  • Samudra v1: configs/samudra_om4_v1/
  • Samudra 2: configs/samudra_om4_v2/
  • Samudra 2 (high-res): configs/samudra_om4_v2_highres/

API Reference

ocean_emulators.models.samudra

Samudra(in_channels, out_channels, pred_residuals, last_kernel_size, pad, unet, corrector, pos_channels, add_3d_coordinates, hist, grid_size, gradient_detach_interval, use_bfloat16)

Bases: BaseModel

Samudra ocean emulator using a ConvNeXt U-Net backbone.

Implements the Samudra (and Samudra 2) model architecture for single-scale ocean emulation.
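The constructor's `pred_residuals` and `gradient_detach_interval` arguments relate to how autoregressive rollouts are trained. The loop below is a hypothetical sketch, not the library's code: it assumes `pred_residuals=True` means the network predicts an increment that is added to the current state, and that `gradient_detach_interval` truncates backpropagation through time every N steps. It ignores `hist` and any forcing inputs for brevity, and `rollout` is a hypothetical helper name.

```python
import torch
from torch import nn


def rollout(step: nn.Module, state: torch.Tensor, n_steps: int,
            pred_residuals: bool = True,
            gradient_detach_interval: int = 0) -> list[torch.Tensor]:
    """Autoregressive rollout; `step` maps (B, C, H, W) to (B, C, H, W)."""
    outputs = []
    for t in range(1, n_steps + 1):
        out = step(state)
        # Residual mode: the network output is an increment, not the full next state.
        state = state + out if pred_residuals else out
        outputs.append(state)
        # Truncated BPTT: cut the graph every `gradient_detach_interval` steps.
        if gradient_detach_interval > 0 and t % gradient_detach_interval == 0:
            state = state.detach()
    return outputs


torch.manual_seed(0)
step = nn.Conv2d(3, 3, kernel_size=1)  # stand-in for the full network
x0 = torch.randn(2, 3, 4, 8, requires_grad=True)
outs = rollout(step, x0, n_steps=4, pred_residuals=True, gradient_detach_interval=2)
```

After step 2 the state is detached, so a loss on step 4 trains the network weights but no longer backpropagates into the initial condition, while a loss on step 2 still does.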

Source code in src/ocean_emulators/models/samudra.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    pred_residuals: bool,
    last_kernel_size: int,
    pad: str,
    unet: UNetBackbone,
    corrector: nn.Module | None,
    pos_channels: int,
    add_3d_coordinates: nn.Module | None,
    hist: int,
    grid_size: GridSize,
    gradient_detach_interval: int,
    use_bfloat16: bool,
):
    super().__init__(
        in_channels=in_channels,
        out_channels=out_channels,
        hist=hist,
        pred_residuals=pred_residuals,
        last_kernel_size=last_kernel_size,
        pad=pad,
        gradient_detach_interval=gradient_detach_interval,
    )

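    # Optional learnable per-grid-point channels, shape (pos_channels, *grid_size),
    # initialized near zero (std 1e-5).
    if pos_channels > 0:
        self.positional_params = nn.Parameter(torch.empty(pos_channels, *grid_size))
        nn.init.normal_(self.positional_params, mean=0.0, std=1e-5)
    else:
        # Registered as None so the attribute always exists.
        self.register_parameter("positional_params", None)

    self.add_3d_coordinates = add_3d_coordinates
    self.unet = unet
    # Final projection from the backbone's channels to the output channels
    # (no padding argument is passed to this Conv2d).
    self.decoder = nn.Conv2d(unet.out_channels, out_channels, last_kernel_size)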

    self.corrector = corrector
    self.use_bfloat16 = use_bfloat16
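To see what the positional parameters and the decoder do to tensor shapes, here is a standalone PyTorch sketch with hypothetical sizes. The parameter initialization mirrors the constructor above. Concatenating the positional channels onto the input, and padding periodically in longitude and with zeros in latitude before the unpadded decoder, are assumptions about how `positional_params` and `pad` are used elsewhere in the model.

```python
import torch
import torch.nn.functional as F
from torch import nn

pos_channels, grid_size = 2, (8, 16)           # hypothetical sizes: (lat, lon)
unet_out_channels, out_channels, k = 32, 5, 3  # k plays the role of last_kernel_size

# Same pattern as Samudra.__init__: small learnable per-grid-point channels.
positional_params = nn.Parameter(torch.empty(pos_channels, *grid_size))
nn.init.normal_(positional_params, mean=0.0, std=1e-5)

x = torch.randn(4, 10, *grid_size)  # (batch, channels, lat, lon)
# Assumption: positional channels are broadcast over the batch and appended to the input.
x_with_pos = torch.cat([x, positional_params.expand(x.shape[0], -1, -1, -1)], dim=1)

decoder = nn.Conv2d(unet_out_channels, out_channels, k)  # no padding, as in the source
features = torch.randn(4, unet_out_channels, *grid_size)
p = k // 2
# Assumption: wrap around in longitude (last dim), zero-pad in latitude.
padded = F.pad(F.pad(features, (p, p, 0, 0), mode="circular"), (0, 0, p, p))

print(x_with_pos.shape)         # torch.Size([4, 12, 8, 16])
print(decoder(features).shape)  # torch.Size([4, 5, 6, 14])
print(decoder(padded).shape)    # torch.Size([4, 5, 8, 16])
```

Without padding, a k×k decoder trims k - 1 rows and columns from the grid; padding by k // 2 on each side (for odd k) restores the original grid, and circular padding in longitude keeps the east and west edges continuous.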