Modules

U-Net Backbone

ocean_emulators.models.modules.unet_backbone

UNetBackbone(in_channels, ch_width, dilation, n_layers, pad, create_block, downsampling_block, create_upsampling_block, checkpointing, drop_path_rate=0.0)

Bases: Module

A configurable, convolutional or ConvNeXt[1] U-Net[2] implementation.

Parameters:

Name Type Description Default
ch_width list[int]

The widths of CNN input channels going down into the U-Net. This module first builds the downsampling CNN blocks, then reverses ch_width to build the upsampling CNN blocks. Typically, these values should be monotonically non-decreasing.

required
dilation list[int]

List of dilation sizes for CNN blocks. See [3] for a general background. The length of this list must be one less than the length of ch_width.

required
n_layers list[int]

List of the number of CNN layers to be used in each block section of the U-Net. Typically, this is set to all 1s. The length of this list must match the length of dilation.

required
pad str

The type of padding to use in all CNN blocks. Passed into torch.nn.functional.pad's mode argument.

required
create_block CoreBlockBuilder

A factory method that creates the CoreBlocks for all CNN layers.

required
downsampling_block Module

A block that downsamples during the descent of the U-Net.

required
create_upsampling_block UpsamplingBlockBuilder

A factory method that creates upsampling blocks during the ascent of the U-Net.

required
checkpointing Checkpointing | None

The current mode for checkpointing (typically "all" or "simple"). None turns checkpointing off.

required
References
Source code in src/ocean_emulators/models/modules/unet_backbone.py
def __init__(
    self,
    in_channels: int,
    ch_width: list[int],
    dilation: list[int],
    n_layers: list[int],
    pad: str,
    create_block: CoreBlockBuilder,
    downsampling_block: nn.Module,
    create_upsampling_block: UpsamplingBlockBuilder,
    checkpointing: "Checkpointing | None",
    drop_path_rate: float = 0.0,
):
    super().__init__()
    self.in_channels = in_channels
    self.out_channels: int = ch_width[0]

    # Create local copies of config lists that will be reversed
    ch_width = [in_channels] + ch_width.copy()
    dilation = dilation.copy()
    n_layers = n_layers.copy()
    self.pad = pad

    match checkpointing:
        case "all":
            self.checkpoint_all = True
            checkpoint_simple = False
        case "simple":
            self.checkpoint_all = False
            checkpoint_simple = True
        case None:
            self.checkpoint_all = False
            checkpoint_simple = False
        case _:
            assert_never(checkpointing)

    # going down
    layers: list[nn.Module] = []
    for i, (a, b) in enumerate(pairwise(ch_width)):
        # Core block
        layers.append(
            create_block(
                in_channels=a,
                out_channels=b,
                dilation=dilation[i],
                n_layers=n_layers[i],
                pad=pad,
                checkpoint_simple=checkpoint_simple,
            )
        )
        # Down sampling block
        layers.append(downsampling_block)

    # Middle block
    layers.append(
        create_block(
            in_channels=b,
            out_channels=b,
            dilation=dilation[i],
            n_layers=n_layers[i],
            pad=pad,
            checkpoint_simple=checkpoint_simple,
        )
    )

    # First upsampling
    layers.append(create_upsampling_block(in_channels=b, out_channels=b))

    # Reverse for upsampling path
    ch_width.reverse()
    dilation.reverse()
    n_layers.reverse()

    # going up
    for i, (a, b) in enumerate(pairwise(ch_width[:-1])):
        layers.append(
            create_block(
                in_channels=a,
                out_channels=b,
                dilation=dilation[i],
                n_layers=n_layers[i],
                pad=pad,
                checkpoint_simple=checkpoint_simple,
            )
        )
        layers.append(create_upsampling_block(in_channels=b, out_channels=b))

    # Final conv block
    layers.append(
        create_block(
            in_channels=b,
            out_channels=b,  # this is the same as self.out_channels
            dilation=dilation[i],
            n_layers=n_layers[i],
            pad=pad,
            checkpoint_simple=checkpoint_simple,
        )
    )

    first_block = layers[0]
    assert isinstance(first_block, CoreBlock)
    self.N_pad = first_block.N_pad
    self.layers = nn.ModuleList(layers)
    self.num_steps = int(len(ch_width) - 1)
    self.drop_path = DropPath(drop_path_rate)
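
To make the channel bookkeeping in the constructor above concrete, here is a minimal sketch that reproduces only the (in_channels, out_channels) pairs the two loops iterate over. The helper name channel_plan is hypothetical and not part of ocean_emulators. It ignores dilation, n_layers, and whatever the forward pass (not shown here) does with skip connections.

```python
def channel_plan(in_channels: int, ch_width: list[int]):
    """Return the (in, out) channel pairs for the down and up paths.

    Hypothetical helper: it mirrors the two loops in UNetBackbone.__init__
    and nothing else.
    """
    # The constructor prepends in_channels before pairing adjacent widths.
    widths = [in_channels] + list(ch_width)
    # zip(w, w[1:]) yields the same pairs as itertools.pairwise(w).
    down = list(zip(widths, widths[1:]))

    # The up path reverses the widths and drops the trailing in_channels entry.
    reversed_widths = widths[::-1][:-1]
    up = list(zip(reversed_widths, reversed_widths[1:]))
    return down, up


down, up = channel_plan(in_channels=4, ch_width=[8, 16, 32])
print(down)  # [(4, 8), (8, 16), (16, 32)]
print(up)    # [(32, 16), (16, 8)]
```

The last up-path pair ends at ch_width[0], which is the value the constructor stores as self.out_channels.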

Encoder

ocean_emulators.models.modules.encoder

PerceiverEncoder(in_channels, out_channels, patch_extent, perceiver)

Bases: Module

A perceiver-based encoder for Samudra's flattened data (a whole column of the ocean, with history).

We adopt some of Aurora's positional encodings[1], which use log-spaced Fourier features with geometry-informed wavelengths. These encode 2D positions (the average latitude and longitude of each patch) as well as grid cell area (measured in km^2) for each token before it enters the processor.
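
As a rough illustration of log-spaced Fourier features, the sketch below encodes a single scalar (for example a patch-center latitude) with sine and cosine terms at wavelengths spaced evenly in log-space. The function name, the wavelength bounds, and the feature dimension are all assumptions chosen for the example. The encoder's actual values are not shown on this page.

```python
import math


def fourier_features(
    x: float, min_wavelength: float, max_wavelength: float, dim: int
) -> list[float]:
    """Encode a scalar with sin/cos features at log-spaced wavelengths.

    Illustrative only: the bounds and dim are assumptions, not library defaults.
    """
    if dim % 2 != 0 or dim < 4:
        raise ValueError("dim must be even and at least 4")
    n = dim // 2
    log_min, log_max = math.log10(min_wavelength), math.log10(max_wavelength)
    # n wavelengths spaced evenly in log-space between the two bounds.
    wavelengths = [
        10 ** (log_min + k * (log_max - log_min) / (n - 1)) for k in range(n)
    ]
    angles = [2 * math.pi * x / w for w in wavelengths]
    # First half: sine terms. Second half: cosine terms.
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


# Encode a patch-center latitude of 30 degrees with assumed wavelength bounds.
features = fourier_features(30.0, min_wavelength=0.01, max_wavelength=720.0, dim=8)
```

Short wavelengths resolve fine positional differences and long wavelengths capture coarse position, which is why the bounds are usually tied to the geometry of the quantity being encoded.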

Note: We assume that data along the lat/lon coordinates are positioned at the center of each grid point! Please ensure this is the case at data-processing time.

This encoder is designed to make the same number of patches with the same spatial extents across different scales of input data (input data may vary in resolution of lat/lng grid). To accomplish this with a single perceiver model, our forward call requires supplementary information: the resolution (a pair of Lat/Lon tensors), which is used to make consistent positional encodings for patches across different scales. While higher resolution scales will contain more data per patch, the patch will refer to the same physical area on Earth as all other scales.

Parameters:

Name Type Description Default
in_channels int

the number of input channels (roughly: time x variable x (surface + depths)).

required
out_channels int

size of the latent dimension (aka, the embedding dimension).

required
patch_extent tuple[float, float]

spatial extent of each patch measured in degrees of lat/lon.

required
perceiver Module

the perceiver module implementation to use.

required
References
Source code in src/ocean_emulators/models/modules/encoder.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    patch_extent: tuple[float, float],
    perceiver: nn.Module,
) -> None:
    super().__init__()
    self.in_channels = in_channels
    self.out_channels: int = out_channels  # aka, `embed_dim`.
    self.patch_extent = patch_extent
    self.perceiver = perceiver
    # TODO(#451): The input to these position and scale linear units could be a hparam.
    self.pos_embed = nn.Linear(self.out_channels, self.out_channels)
    self.scale_embed = nn.Linear(self.out_channels, self.out_channels)

patch_from(patch_extent, input_height, input_width)

Calculate the patch size in lat/lng pixels (or coords) from the patch spatial extent and input grid size.

Source code in src/ocean_emulators/models/modules/encoder.py
def patch_from(
    patch_extent: tuple[float, float], input_height: int, input_width: int
) -> tuple[int, int]:
    """Calculate the patch size in lat/lng pixels (or coords) from the patch spatial extent and input grid size."""
    lat_spacing = 180.0 / input_height  # Full sphere is 180 degrees (pole to pole)
    lon_spacing = 360.0 / input_width  # Full circle is 360 degrees

    # Calculate patch size to match target extent
    patch_h = int(round(patch_extent[0] / lat_spacing))
    patch_w = int(round(patch_extent[1] / lon_spacing))

    return patch_h, patch_w
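
A short worked example may help show why a fixed patch_extent yields the same number of patches at different resolutions. The function body is repeated from the listing above so the snippet runs on its own. The grid sizes (1 degree and 0.25 degree global grids) are assumptions picked for illustration.

```python
def patch_from(
    patch_extent: tuple[float, float], input_height: int, input_width: int
) -> tuple[int, int]:
    """Copy of the function above, repeated so this example is standalone."""
    lat_spacing = 180.0 / input_height
    lon_spacing = 360.0 / input_width
    patch_h = int(round(patch_extent[0] / lat_spacing))
    patch_w = int(round(patch_extent[1] / lon_spacing))
    return patch_h, patch_w


# Assumed example grids: 1 degree (180 x 360) and 0.25 degree (720 x 1440).
coarse = patch_from((4.0, 4.0), input_height=180, input_width=360)
fine = patch_from((4.0, 4.0), input_height=720, input_width=1440)
print(coarse)  # (4, 4)
print(fine)    # (16, 16)

# Both grids are split into the same number of patches.
coarse_count = (180 // coarse[0], 360 // coarse[1])
fine_count = (720 // fine[0], 1440 // fine[1])
print(coarse_count, fine_count)  # (45, 90) (45, 90)
```

The finer grid holds 16 times as many pixels per patch, but each patch still covers the same 4 by 4 degree area.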

Decoder

ocean_emulators.models.modules.decoder

PerceiverDecoder(in_channels, out_channels, patch_extent, queries_dim, perceiver_io, window_patches, context_patches)

Bases: Module

A PerceiverIO-based decoder that maps a latent patch grid to full-resolution output.

All nh * nw pos/scale-encoded latent tokens are passed as data to the PerceiverIO[2], and every output pixel position is a query. Each query cross-attends to the full latent representation, giving it global spatial context — pixels near patch boundaries can attend to neighboring patches, and the model can learn smooth inter-patch transitions.

Concretely:

  1. Add Aurora-style pos/scale encoding to the nh * nw latent tokens (telling the model where on the globe each patch is).
  2. Pass all encoded latents as data to the PerceiverIO: (B, nh * nw, C).
  3. Build 3D unit-sphere queries (x, y, z) for every output pixel from its lat/lon, embed them via a learned linear layer, and feed them to the PerceiverIO decoder head.
  4. Inside the PerceiverIO:
     a. Internal latents cross-attend to the nh * nw data tokens.
     b. The latents refine through several rounds of self-attention.
     c. A final cross-attention maps from queries to the refined latents, producing (B, H * W, out_channels).
  5. Reshape to (B, out_channels, H, W).

Spatial windowing: When window_patches is set, the latent grid must be evenly divisible by window_patches. The grid is padded — circular along longitude (so windows near lon=0 see context from lon≈360) and constant-zero along latitude (poles are true boundaries) — then Tensor.unfold extracts fixed-size overlapping windows. Each block's PerceiverIO call receives the local data context plus the corresponding pixel queries. Setting context_patches=None gives each window full access to all latent tokens (windowed queries, global data).
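
The padding and window extraction described above can be sketched without tensors. The version below works on a small nested list that stands in for the latent patch grid. It is not the library's implementation, which uses Tensor.unfold. The helper names, and the assumption that each data window spans window_patches + 2 * context_patches patches with a stride of window_patches, are illustrative.

```python
def pad_grid(grid, context, fill=0):
    """Pad circularly along longitude (columns) and with `fill` along latitude (rows)."""
    nw = len(grid[0])
    # Python's modulo wraps negative indices, giving the circular longitude padding.
    rows = [[row[j % nw] for j in range(-context, nw + context)] for row in grid]
    blank = [fill] * (nw + 2 * context)
    pad_rows = [list(blank) for _ in range(context)]
    return pad_rows + rows + [list(blank) for _ in range(context)]


def extract_windows(grid, window, context):
    """Return overlapping data windows, one per window x window block of patches."""
    nh, nw = len(grid), len(grid[0])
    if nh % window or nw % window:
        raise ValueError("grid must be evenly divisible by window")
    padded = pad_grid(grid, context)
    size = window + 2 * context  # assumed window size including context rings
    return [
        [row[j : j + size] for row in padded[i : i + size]]
        for i in range(0, nh, window)
        for j in range(0, nw, window)
    ]


# A 2 x 4 grid of patch ids, split into 2 x 2 blocks with one ring of context.
grid = [[1, 2, 3, 4], [5, 6, 7, 8]]
windows = extract_windows(grid, window=2, context=1)
print(windows[0])  # [[0, 0, 0, 0], [4, 1, 2, 3], [8, 5, 6, 7], [0, 0, 0, 0]]
```

In the first window, the leftmost column holds patches 4 and 8, which wrapped around from the far end of the longitude axis, while the top and bottom rows are zero padding at the poles.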

Because pixel queries are unit-sphere coordinates — continuous values determined by lat/lon, not grid indices — the same PerceiverIO generalizes across resolutions.
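
For illustration, the conversion from lat/lon to a 3D unit-sphere coordinate can be sketched as follows. This uses the common convention with z pointing to the north pole and x pointing to (lat=0, lon=0). The axis convention the decoder actually uses is not shown on this page, so treat it as an assumption, along with the function name.

```python
import math


def unit_sphere_xyz(lat_deg: float, lon_deg: float) -> tuple[float, float, float]:
    """Map latitude/longitude in degrees to a point on the unit sphere.

    Assumed convention: z toward the north pole, x toward lat=0, lon=0.
    """
    lat = math.radians(lat_deg)
    lon = math.radians(lon_deg)
    x = math.cos(lat) * math.cos(lon)
    y = math.cos(lat) * math.sin(lon)
    z = math.sin(lat)
    return x, y, z


origin = unit_sphere_xyz(0.0, 0.0)
wrapped = unit_sphere_xyz(0.0, 360.0)
print(origin)  # (1.0, 0.0, 0.0)
```

Unlike raw longitude, this representation has no seam: lon=0 and lon=360 map to the same point (up to floating-point error), and the coordinates depend only on physical position, not on grid indices.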

Parameters:

Name Type Description Default
in_channels int

Number of input channels from the processor.

required
out_channels int

Number of output channels per pixel.

required
patch_extent tuple[float, float]

Spatial extent of each patch in degrees (lat, lon). Used for computing positional and scale encodings on latent tokens.

required
queries_dim int

Embedding dimension for pixel-position queries.

required
perceiver_io Module

A PerceiverIO module. dim must equal in_channels, queries_dim must match this decoder's queries_dim, and logits_dim must equal out_channels.

required
window_patches int | None

Side length (in patches) of each spatial decode window. If None, all patches are used globally (no windowing). E.g. window_patches=8 means each PerceiverIO call covers an 8x8 block of patches.

required
context_patches int | None

Number of extra patch rings around each window to include as data context. Only used when window_patches is set. For example, context_patches=1 gives each window one ring of neighboring patches beyond its own block. None means full context: every window sees all latent tokens (windowed queries but global data attention).

required
References
Source code in src/ocean_emulators/models/modules/decoder.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    patch_extent: tuple[float, float],
    queries_dim: int,
    perceiver_io: nn.Module,
    window_patches: int | None,
    context_patches: int | None,
) -> None:
    super().__init__()
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.patch_extent = patch_extent
    if window_patches is None and context_patches is not None:
        raise ValueError(
            "window_patches must be set in order for context_patches to be set."
        )
    self.window_patches = window_patches
    self.context_patches = context_patches

    # TODO(#451): The input to these position and scale linear units could be a hparam.
    # Same pos/scale linear layers as the encoder, but applied *before* the
    # perceiver (the encoder applies them after).
    self.pos_embed = nn.Linear(in_channels, in_channels)
    self.scale_embed = nn.Linear(in_channels, in_channels)

    # Embed 3D unit-sphere coordinates into queries_dim for the PerceiverIO decoder head.
    self.query_embed = nn.Linear(3, queries_dim)

    self.perceiver_io = perceiver_io

Blocks

ocean_emulators.models.modules.blocks

PointwiseLinear(in_channels, out_channels)

Bases: Module

A 1×1 convolution implemented as nn.Linear.

Mathematically equivalent to Conv2d(kernel_size=1), but avoids the non-contiguous gradient strides that 1×1 convs produce, which cause DDP to copy gradients instead of using zero-copy views.

This optimization is used in the official ConvNeXt implementation[0].

Source code in src/ocean_emulators/models/modules/blocks.py
def __init__(self, in_channels: int, out_channels: int):
    super().__init__()
    self.linear = torch.nn.Linear(in_channels, out_channels)

ZonallyPeriodicBilinearUpsample(upsampling=2)

Bases: Module

Bilinear upsampling that enforces periodicity along the x/longitude axis.

Source code in src/ocean_emulators/models/modules/blocks.py
def __init__(self, upsampling: int | tuple[int, int] = 2):
    super().__init__()
    if isinstance(upsampling, int):
        upsampling = (upsampling, upsampling)
    if tuple(upsampling) != (2, 2):
        raise ValueError(
            "ZonallyPeriodicBilinearUpsample only supports 2x upsampling"
        )
    self.scale_h, self.scale_w = upsampling

DropPath(drop_prob=0.0)

Bases: Module

Drop path dropout (for skip connections).

During training, randomly drops entire samples' skip connections with probability drop_prob, scaling survivors by 1/(1-p) to preserve expected values. Implemented via nn.Dropout applied to a per-sample mask of ones.

References

[0]: Rethinking U-net Skip Connections for Biomedical Image Segmentation (https://arxiv.org/abs/2402.08276)
[1]: Dropout Reduces Underfitting (https://arxiv.org/abs/2303.01500)

Source code in src/ocean_emulators/models/modules/blocks.py
def __init__(self, drop_prob: float = 0.0):
    super().__init__()
    self.dropout = torch.nn.Dropout(p=drop_prob)
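
The per-sample mask described in the docstring can be sketched without any framework. The function below is a hypothetical stand-in for what applying nn.Dropout to a vector of ones produces: each sample's entry is either zeroed or scaled by 1/(1-p). The name drop_path_mask and the training flag are assumptions made for this example.

```python
import random


def drop_path_mask(
    batch_size: int, drop_prob: float, training: bool = True, rng=random
) -> list[float]:
    """Return one multiplier per sample for that sample's skip connection.

    Hypothetical helper mimicking dropout applied to a mask of ones.
    """
    if not training or drop_prob == 0.0:
        # Evaluation mode, or nothing to drop: every skip connection is kept as-is.
        return [1.0] * batch_size
    keep_prob = 1.0 - drop_prob
    # Survivors are scaled by 1 / keep_prob so the expected value stays 1.
    return [
        1.0 / keep_prob if rng.random() < keep_prob else 0.0
        for _ in range(batch_size)
    ]


mask = drop_path_mask(batch_size=8, drop_prob=0.5, rng=random.Random(0))
print(mask)  # each entry is either 0.0 (dropped) or 2.0 (kept and rescaled)
```

Each multiplier would then be broadcast over that sample's skip tensor, so a dropped sample loses its entire skip connection rather than individual elements.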

ConvNeXtBlock(in_channels=300, out_channels=1, kernel_size=3, dilation=1, n_layers=1, activation=CappedGELU, pad='circular', upscale_factor=4, norm='batch', checkpoint_simple=False, pointwise_linear=False)

Bases: CoreBlock

A convolution block as reported in https://github.com/CognitiveModeling/dlwp-hpx/blob/main/src/dlwp-hpx/dlwp/model/modules/blocks.py.

This is a modified version of the ConvNeXt block used in the HEALPix paper; this version also supports dilations.

Source code in src/ocean_emulators/models/modules/blocks.py
def __init__(
    self,
    in_channels: int = 300,
    out_channels: int = 1,
    kernel_size: int = 3,
    dilation: int = 1,
    n_layers: int = 1,
    activation: Callable[[], torch.nn.Module] = CappedGELU,
    pad="circular",
    upscale_factor: int = 4,
    norm="batch",
    checkpoint_simple: bool = False,
    pointwise_linear: bool = False,
):
    super().__init__(in_channels, out_channels, kernel_size, dilation, pad)
    assert n_layers == 1, "Can only use a single layer here!"

    # Instantiate pointwise linear to increase/decrease channel depth if necessary
    if in_channels == out_channels:
        self.skip_module = lambda x: x  # Identity-function required in forward pass
    else:
        self.skip_module = _pointwise(pointwise_linear, in_channels, out_channels)

    # Convolution block
    convblock: list[torch.nn.Module] = []
    convblock.append(
        torch.nn.Conv2d(
            in_channels=in_channels,
            out_channels=int(in_channels * upscale_factor),
            kernel_size=kernel_size,
            dilation=dilation,
        )
    )
    # BatchNorm
    if norm == "batch":
        convblock.append(torch.nn.BatchNorm2d(in_channels * upscale_factor))
    # Instance Norm
    elif norm == "instance":
        convblock.append(torch.nn.InstanceNorm2d(in_channels * upscale_factor))
    elif norm == "nonorm":
        pass
    else:
        raise NotImplementedError
    if activation is not None:
        convblock.append(activation())
    convblock.append(
        torch.nn.Conv2d(
            in_channels=int(in_channels * upscale_factor),
            out_channels=int(in_channels * upscale_factor),
            kernel_size=kernel_size,
            dilation=dilation,
        )
    )
    # BatchNorm
    if norm == "batch":
        convblock.append(torch.nn.BatchNorm2d(in_channels * upscale_factor))
    # Instance Norm
    elif norm == "instance":
        convblock.append(torch.nn.InstanceNorm2d(in_channels * upscale_factor))
    elif norm == "nonorm":
        pass
    else:
        raise NotImplementedError
    if activation is not None:
        convblock.append(activation())
    # Linear postprocessing
    convblock.append(
        _pointwise(
            pointwise_linear, int(in_channels * upscale_factor), out_channels
        )
    )
    self.convblock = torch.nn.Sequential(*convblock)
    self.checkpoint_simple = checkpoint_simple

Activations

ocean_emulators.models.modules.activations

ReLU(**kwargs)

Bases: Module

Implements a ReLU.

:param kwargs: passed to torch.nn.ReLU

Source code in src/ocean_emulators/models/modules/activations.py
def __init__(self, **kwargs):
    """
    :param kwargs: passed to torch.nn.ReLU
    """
    super().__init__()
    self.relu = torch.nn.ReLU(**kwargs)

CappedLeakyReLU(cap_value=10.0, **kwargs)

Bases: Module

Implements a LeakyReLU with a capped maximum value.

:param cap_value: float: value at which to clip activation
:param kwargs: passed to torch.nn.LeakyReLU

Source code in src/ocean_emulators/models/modules/activations.py
def __init__(self, cap_value=10.0, **kwargs):
    """
    :param cap_value: float: value at which to clip activation
    :param kwargs: passed to torch.nn.LeakyReLU
    """
    super().__init__()
    self.relu = torch.nn.LeakyReLU(**kwargs)
    self.cap = torch.nn.Buffer(torch.tensor(cap_value, dtype=torch.float32))

CappedGELU(cap_value=10.0, **kwargs)

Bases: Module

Implements a GELU with a capped maximum value.

:param cap_value: float: value at which to clip activation
:param kwargs: passed to torch.nn.GELU

Source code in src/ocean_emulators/models/modules/activations.py
def __init__(self, cap_value=10.0, **kwargs):
    """
    :param cap_value: float: value at which to clip activation
    :param kwargs: passed to torch.nn.GELU
    """
    super().__init__()
    self.gelu = torch.nn.GELU(**kwargs)
    self.cap = torch.nn.Buffer(torch.tensor(cap_value, dtype=torch.float32))
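
Since the forward pass is not listed here, the following scalar sketch shows one plausible reading of "capped": apply the exact (erf-based) GELU, then clip the result from above at cap_value. The function names are hypothetical, and the assumption that the cap is an upper clamp comes only from the docstring's description.

```python
import math


def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def capped_gelu(x: float, cap_value: float = 10.0) -> float:
    """GELU clipped from above at cap_value (assumed behavior)."""
    return min(gelu(x), cap_value)


print(capped_gelu(0.0))    # 0.0
print(capped_gelu(100.0))  # 10.0, clipped by the cap
```

Capping keeps activations bounded for very large inputs, while values below the cap pass through unchanged.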