Modules
U-Net Backbone
ocean_emulators.models.modules.unet_backbone
UNetBackbone(in_channels, ch_width, dilation, n_layers, pad, create_block, downsampling_block, create_upsampling_block, checkpointing, drop_path_rate=0.0)
Bases: Module
A configurable, convolutional or ConvNeXt[1] U-Net[2] implementation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| ch_width | list[int] | The widths of CNN input channels going down into the U-Net. This module first builds downsampling CNN blocks before reversing the | required |
| dilation | list[int] | List of dilation sizes for CNN blocks. See [3] for a general background. This list must be one less than the length of | required |
| n_layers | list[int] | List of the number of CNN layers to be used in each block section of the U-Net. Typically, this is set to all 1s. This value must match the length of | required |
| pad | str | The type of padding to use in all CNN blocks. Passed into | required |
| create_block | CoreBlockBuilder | A factory method that creates the CoreBlocks for all CNN layers. | required |
| downsampling_block | Module | A block that downsamples during the descent of the U-Net. | required |
| create_upsampling_block | UpsamplingBlockBuilder | A factory method that creates upsampling blocks during the ascent of the U-Net. | required |
| checkpointing | Checkpointing \| None | The current mode for checkpointing (typically "all" or "simple"). None turns checkpointing off. | required |
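The `ch_width` description says the module builds the descending blocks first and then reverses. The sketch below shows one plausible reading of that idea in plain Python. The helper name `plan_channels` is hypothetical, and the real module may size its upsampling blocks differently (for example, to account for skip connections).

```python
def plan_channels(ch_width):
    """Pair consecutive widths for the descent, then mirror them for the ascent.

    Hypothetical helper for illustration only; not part of ocean_emulators.
    """
    # Each descending block maps ch_width[i] -> ch_width[i + 1].
    down = list(zip(ch_width[:-1], ch_width[1:]))
    # The ascent walks the same pairs backwards with input/output swapped.
    up = [(c_out, c_in) for c_in, c_out in reversed(down)]
    return down, up


down, up = plan_channels([16, 32, 64, 128])
print(down)
print(up)
```

Under this reading, a list of four widths yields three descending and three ascending blocks.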
References
Source code in src/ocean_emulators/models/modules/unet_backbone.py
Encoder
ocean_emulators.models.modules.encoder
PerceiverEncoder(in_channels, out_channels, patch_extent, perceiver)
Bases: Module
A perceiver-based encoder for Samudra's flattened data (a whole column of the ocean, with history).
We adopt some of Aurora's positional encodings[1], which use log-spaced Fourier features with geometry-informed wavelengths. These encode 2D positions (the average latitude and longitude of each patch) as well as grid cell area (measured in km^2) for each token before it enters the processor.
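As a rough illustration of log-spaced Fourier features, the sketch below encodes a scalar (a latitude, a longitude, or an area) with sines and cosines whose wavelengths are evenly spaced in log-space. The function names, the wavelength bounds, and the sin/cos layout are assumptions made for this example, not the encoder's actual settings. The cell-area helper uses the standard spherical formula with an assumed Earth radius of 6371 km.

```python
import math


def log_spaced_wavelengths(lower, upper, n):
    """Return n wavelengths spaced evenly in log-space from lower to upper."""
    if n == 1:
        return [lower]
    step = (math.log(upper) - math.log(lower)) / (n - 1)
    return [math.exp(math.log(lower) + i * step) for i in range(n)]


def fourier_features(value, lower, upper, dim):
    """Encode a scalar as dim/2 sines followed by dim/2 cosines."""
    if dim % 2 != 0:
        raise ValueError("dim must be even")
    wavelengths = log_spaced_wavelengths(lower, upper, dim // 2)
    angles = [2.0 * math.pi * value / w for w in wavelengths]
    return [math.sin(a) for a in angles] + [math.cos(a) for a in angles]


def cell_area_km2(lat_south, lat_north, dlon_deg, radius_km=6371.0):
    """Area of a lat/lon cell on a sphere: R^2 * dlon * (sin(lat_n) - sin(lat_s))."""
    band = math.sin(math.radians(lat_north)) - math.sin(math.radians(lat_south))
    return radius_km ** 2 * math.radians(dlon_deg) * band


# Hypothetical bounds (in degrees) chosen only for the example.
lat_code = fourier_features(45.0, lower=0.01, upper=720.0, dim=8)
equator_area = cell_area_km2(0.0, 1.0, 1.0)
print(len(lat_code), round(equator_area))
```

Short wavelengths resolve fine positional differences, while long ones keep distant positions distinguishable.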
Note: We assume that the lat/lon coordinates refer to the center of each grid cell. Please ensure this is the case at data-processing time.
This encoder is designed to make the same number of patches with the same spatial extents across different scales
of input data (input data may vary in the resolution of the lat/lon grid). To accomplish this with a single perceiver model,
our forward call requires supplementary information: the resolution (a pair of lat/lon tensors), which is used to
make consistent positional encodings for patches across different scales. While higher resolution scales will
contain more data per patch, the patch will refer to the same physical area on Earth as all other scales.
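The following sketch illustrates why cell-centered coordinates matter for this design. When coordinates sit at cell centers, the per-patch mean latitude is identical across resolutions, whereas edge-aligned coordinates give resolution-dependent means. The grid sizes, the 18-patch split, and the helper names are assumptions chosen for the example.

```python
def cell_centers(n_cells, start, stop):
    """Center coordinate of each of n_cells equal cells spanning [start, stop]."""
    step = (stop - start) / n_cells
    return [start + (i + 0.5) * step for i in range(n_cells)]


def cell_edges(n_cells, start, stop):
    """Lower-edge coordinate of each cell (the layout the note warns against)."""
    step = (stop - start) / n_cells
    return [start + i * step for i in range(n_cells)]


def patch_means(coords, n_patches):
    """Average the coordinates falling into each of n_patches equal groups."""
    per_patch = len(coords) // n_patches
    return [
        sum(coords[i * per_patch:(i + 1) * per_patch]) / per_patch
        for i in range(n_patches)
    ]


# 1 degree and 0.25 degree latitude grids, both split into 18 patches of 10 degrees.
coarse = patch_means(cell_centers(180, -90.0, 90.0), 18)
fine = patch_means(cell_centers(720, -90.0, 90.0), 18)
coarse_edges = patch_means(cell_edges(180, -90.0, 90.0), 18)
fine_edges = patch_means(cell_edges(720, -90.0, 90.0), 18)
print(coarse[0], fine[0])              # same patch center at both resolutions
print(coarse_edges[0], fine_edges[0])  # edge-aligned means drift apart
```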
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | The number of input channels (roughly: time x variable x (surface + depths)). | required |
| out_channels | int | Size of the latent dimension (aka, the embedding dimension). | required |
| patch_extent | tuple[float, float] | Spatial extent of each patch measured in degrees of lat/lon. | required |
| perceiver | Module | The perceiver module implementation to use. | required |
References
Source code in src/ocean_emulators/models/modules/encoder.py
patch_from(patch_extent, input_height, input_width)
Calculate the patch size in lat/lon pixels (or coords) from the patch spatial extent and input grid size.
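A minimal sketch of this kind of calculation, assuming a global grid that spans 180 degrees of latitude and 360 degrees of longitude. The function name `patch_from_sketch`, the span defaults, and the error on non-integer patch sizes are assumptions for illustration; the actual `patch_from` may behave differently.

```python
def patch_from_sketch(patch_extent, input_height, input_width,
                      lat_span=180.0, lon_span=360.0):
    """Convert a patch extent in degrees into a patch size in grid pixels.

    Assumes (hypothetically) a regular grid covering lat_span x lon_span degrees.
    """
    lat_extent, lon_extent = patch_extent
    patch_h = lat_extent * input_height / lat_span
    patch_w = lon_extent * input_width / lon_span
    # Reject extents that do not map to a whole number of pixels.
    if abs(patch_h - round(patch_h)) > 1e-9 or abs(patch_w - round(patch_w)) > 1e-9:
        raise ValueError("patch extent does not align with the grid")
    return round(patch_h), round(patch_w)


print(patch_from_sketch((10.0, 10.0), 180, 360))  # 1 degree grid
print(patch_from_sketch((10.0, 10.0), 360, 720))  # 0.5 degree grid
```

Doubling the resolution doubles the pixel size of each patch, while the number of patches stays the same.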
Source code in src/ocean_emulators/models/modules/encoder.py
Decoder
ocean_emulators.models.modules.decoder
PerceiverDecoder(in_channels, out_channels, patch_extent, queries_dim, perceiver_io, window_patches, context_patches)
Bases: Module
A PerceiverIO-based decoder that maps a latent patch grid to full-resolution output.
All nh * nw pos/scale-encoded latent tokens are passed as data to
the PerceiverIO[2], and every output pixel position is a query. Each
query cross-attends to the full latent representation, giving it global
spatial context — pixels near patch boundaries can attend to neighboring
patches, and the model can learn smooth inter-patch transitions.
Concretely:
- Add Aurora-style pos/scale encoding to the nh * nw latent tokens (telling the model where on the globe each patch is).
- Pass all encoded latents as data to the PerceiverIO: (B, nh * nw, C).
- Build 3D unit-sphere queries (x, y, z) for every output pixel from its lat/lon, embed them via a learned linear layer, and feed them to the PerceiverIO decoder head.
- Inside the PerceiverIO:
  a. Internal latents cross-attend to the nh * nw data tokens.
  b. The latents refine through several rounds of self-attention.
  c. A final cross-attention maps from queries to the refined latents, producing (B, H * W, out_channels).
- Reshape to (B, out_channels, H, W).
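The unit-sphere query construction in the steps above can be sketched with the standard spherical-to-Cartesian conversion. The axis convention and helper names are assumptions for this example, and the learned linear embedding that follows is omitted.

```python
import math


def unit_sphere_query(lat_deg, lon_deg):
    """Map a lat/lon pair (degrees) to a point (x, y, z) on the unit sphere."""
    lat = math.radians(lat_deg)
    lon = math.radians(lon_deg)
    return (
        math.cos(lat) * math.cos(lon),
        math.cos(lat) * math.sin(lon),
        math.sin(lat),
    )


def build_queries(lats, lons):
    """One (x, y, z) query per output pixel, row-major over an (H, W) grid."""
    return [unit_sphere_query(lat, lon) for lat in lats for lon in lons]


queries = build_queries([-45.0, 0.0, 45.0], [0.0, 90.0, 180.0, 270.0])
print(len(queries))  # H * W queries
```

Unlike raw longitude, these coordinates are continuous across the lon=0/360 seam, so neighbouring pixels on either side of it receive nearby queries.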
Spatial windowing: When window_patches is set, the latent grid
must be evenly divisible by window_patches. The grid is padded —
circular along longitude (so windows near lon=0 see context from
lon≈360) and constant-zero along latitude (poles are true boundaries)
— then Tensor.unfold extracts fixed-size overlapping windows.
Each block's PerceiverIO call receives the local data context plus
the corresponding pixel queries. Setting context_patches=None
gives each window full access to all latent tokens (windowed queries,
global data).
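The padding and window extraction described above can be sketched on a small grid of patch labels. This plain-Python version assumes a window of `window + 2 * context` patches taken at a stride of `window`, which mimics a `Tensor.unfold` call with that size and step. The helper names are hypothetical, and the `context_patches=None` (global data) case is not covered.

```python
def pad_latent_grid(grid, context, fill=0):
    """Circular pad along longitude (columns), constant pad along latitude (rows)."""
    if context == 0:
        return [row[:] for row in grid]
    # Wrap columns so windows near lon=0 see patches from the far end.
    wrapped = [row[-context:] + row + row[:context] for row in grid]
    width = len(wrapped[0])
    top = [[fill] * width for _ in range(context)]
    bottom = [[fill] * width for _ in range(context)]
    return top + wrapped + bottom


def extract_windows(grid, window, context):
    """Return overlapping windows keyed by their (row, col) window index."""
    nh, nw = len(grid), len(grid[0])
    if nh % window or nw % window:
        raise ValueError("latent grid must be evenly divisible by window")
    padded = pad_latent_grid(grid, context)
    size = window + 2 * context
    windows = {}
    for top in range(0, nh, window):
        for left in range(0, nw, window):
            rows = padded[top:top + size]
            windows[(top // window, left // window)] = [r[left:left + size] for r in rows]
    return windows


grid = [[f"r{i}c{j}" for j in range(8)] for i in range(4)]
windows = extract_windows(grid, window=2, context=1)
for row in windows[(0, 0)]:
    print(row)
```

The first window's top row is all padding (the pole boundary), and its leftmost column contains patches from the last longitude column.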
Because pixel queries are unit-sphere coordinates — continuous values determined by lat/lon, not grid indices — the same PerceiverIO generalizes across resolutions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels from the processor. | required |
| out_channels | int | Number of output channels per pixel. | required |
| patch_extent | tuple[float, float] | Spatial extent of each patch in degrees (lat, lon). Used for computing positional and scale encodings on latent tokens. | required |
| queries_dim | int | Embedding dimension for pixel-position queries. | required |
| perceiver_io | Module | A PerceiverIO module. | required |
| window_patches | int \| None | Side length (in patches) of each spatial decode window. If | required |
| context_patches | int \| None | Number of extra patch rings around each window to include as data context. Only used when | required |
References
Source code in src/ocean_emulators/models/modules/decoder.py
Blocks
ocean_emulators.models.modules.blocks
PointwiseLinear(in_channels, out_channels)
Bases: Module
A 1×1 convolution implemented as nn.Linear.
Mathematically equivalent to Conv2d(kernel_size=1), but avoids the non-contiguous gradient strides that 1×1 convs produce, which cause DDP to copy gradients instead of using zero-copy views.
This optimization is used in the official ConvNeXt implementation[0].
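To make the equivalence concrete, the dependency-free sketch below applies the same weights once as a 1×1 convolution over a channels-first input and once as a per-pixel linear map over a channels-last view, then checks that the results match. Nested lists stand in for tensors, and the permute-apply-permute layout is an assumption about how such a layer is typically written, not a copy of the module's code.

```python
def conv1x1(x, weight, bias):
    """1x1 convolution. x: [C_in][H][W], weight: [C_out][C_in] -> [C_out][H][W]."""
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    return [
        [[bias[o] + sum(weight[o][c] * x[c][i][j] for c in range(c_in))
          for j in range(w)] for i in range(h)]
        for o in range(len(weight))
    ]


def linear(vec, weight, bias):
    """Dense layer on one channel vector: y = W @ vec + b."""
    return [bias[o] + sum(wc * v for wc, v in zip(weight[o], vec))
            for o in range(len(weight))]


def pointwise_linear(x, weight, bias):
    """Move channels last, apply the dense layer per pixel, move channels first."""
    c_in, h, w = len(x), len(x[0]), len(x[0][0])
    channels_last = [[[x[c][i][j] for c in range(c_in)] for j in range(w)]
                     for i in range(h)]
    mixed = [[linear(pixel, weight, bias) for pixel in row] for row in channels_last]
    return [[[mixed[i][j][o] for j in range(w)] for i in range(h)]
            for o in range(len(weight))]


x = [[[1, 2], [3, 4]], [[5, 6], [7, 8]]]   # C_in=2, H=2, W=2
weight = [[1, -1], [2, 0], [0, 3]]         # C_out=3
bias = [0, 1, -1]
print(conv1x1(x, weight, bias) == pointwise_linear(x, weight, bias))
```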
Source code in src/ocean_emulators/models/modules/blocks.py
ZonallyPeriodicBilinearUpsample(upsampling=2)
Bases: Module
Bilinear upsampling that enforces periodicity along the x/longitude axis.
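As a one-dimensional illustration of the periodic idea, the sketch below linearly upsamples a single row of longitude values and wraps indices at the edges instead of clamping them. It assumes a half-pixel sampling convention and is not taken from the module itself, which operates on 2D tensors.

```python
import math


def periodic_linear_upsample_1d(row, factor=2):
    """Linearly upsample a periodic row; neighbours wrap around the ends."""
    n = len(row)
    out = []
    for j in range(n * factor):
        # Position of output sample j in input index space (half-pixel centers).
        src = (j + 0.5) / factor - 0.5
        left = math.floor(src)
        t = src - left
        # Modulo indexing makes the last cell a neighbour of the first.
        out.append((1.0 - t) * row[left % n] + t * row[(left + 1) % n])
    return out


print(periodic_linear_upsample_1d([0.0, 4.0], factor=2))
```

With clamped edges the first output would simply repeat the first input value. Here it blends in the last value, so the field stays continuous across the lon=0/360 seam.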
Source code in src/ocean_emulators/models/modules/blocks.py
DropPath(drop_prob=0.0)
Bases: Module
Drop path dropout (for skip connections).
During training, randomly drops entire samples' skip connections
with probability drop_prob, scaling survivors by 1/(1-p) to preserve
expected values. Implemented via nn.Dropout applied to a per-sample
mask of ones.
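The per-sample masking described above can be sketched without any tensor library. Each sample draws one keep/drop decision, and kept samples are scaled by 1/(1-p). The function name and the list-of-lists representation are assumptions for illustration.

```python
import random


def drop_path(batch, drop_prob, training, rng):
    """Zero whole samples with probability drop_prob; rescale the survivors."""
    if not training or drop_prob == 0.0:
        return [sample[:] for sample in batch]
    keep = 1.0 - drop_prob
    out = []
    for sample in batch:
        # One mask value per sample: 0 if dropped, 1/keep if kept.
        mask = 1.0 / keep if rng.random() < keep else 0.0
        out.append([value * mask for value in sample])
    return out


rng = random.Random(0)
batch = [[1.0, 2.0] for _ in range(8)]
print(drop_path(batch, drop_prob=0.5, training=True, rng=rng))
print(drop_path(batch, drop_prob=0.5, training=False, rng=rng))
```

Because survivors are scaled up, the expected value of each element matches the unmasked input, so no rescaling is needed at evaluation time.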
References
[0]: Rethinking U-net Skip Connections for Biomedical Image Segmentation (https://arxiv.org/abs/2402.08276)
[1]: Dropout Reduces Underfitting (https://arxiv.org/abs/2303.01500)
Source code in src/ocean_emulators/models/modules/blocks.py
ConvNeXtBlock(in_channels=300, out_channels=1, kernel_size=3, dilation=1, n_layers=1, activation=CappedGELU, pad='circular', upscale_factor=4, norm='batch', checkpoint_simple=False, pointwise_linear=False)
Bases: CoreBlock
A convolution block as reported in https://github.com/CognitiveModeling/dlwp-hpx/blob/main/src/dlwp-hpx/dlwp/model/modules/blocks.py.
This is the modified version of the original ConvNeXt block that is used in the HealPix paper, here with dilations.
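For intuition on what the `dilation` argument changes, the sketch below computes the effective footprint of a dilated kernel and the per-side padding that keeps the spatial size unchanged at stride 1. These are the standard convolution arithmetic formulas; the helper names are hypothetical, and how the block pads internally is not shown here.

```python
def effective_kernel_size(kernel_size, dilation):
    """Span covered by a dilated kernel: dilation * (kernel_size - 1) + 1."""
    return dilation * (kernel_size - 1) + 1


def same_padding(kernel_size, dilation):
    """Per-side padding that preserves size at stride 1 (odd kernel sizes)."""
    return dilation * (kernel_size - 1) // 2


def output_size(n, kernel_size, dilation, padding):
    """Output length of a stride-1 convolution over n input positions."""
    return n + 2 * padding - effective_kernel_size(kernel_size, dilation) + 1


for d in (1, 2, 4):
    print(d, effective_kernel_size(3, d), same_padding(3, d))
```

A 3×3 kernel with dilation 4 therefore covers a 9×9 area with the same nine weights.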
Source code in src/ocean_emulators/models/modules/blocks.py
Activations
ocean_emulators.models.modules.activations
ReLU(**kwargs)
CappedLeakyReLU(cap_value=10.0, **kwargs)
Bases: Module
Implements a LeakyReLU with a capped maximum value.
:param cap_value: float: value at which to clip activation
:param kwargs: passed to torch.nn.LeakyReLU
Source code in src/ocean_emulators/models/modules/activations.py
CappedGELU(cap_value=10.0, **kwargs)
Bases: Module
Implements a GELU with a capped maximum value.
:param cap_value: float: value at which to clip activation
:param kwargs: passed to torch.nn.GELU
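A scalar sketch of the capped activations, assuming the cap is applied as an upper clamp after the base activation. The GELU uses the exact erf form, and the leaky slope of 0.01 mirrors the `torch.nn.LeakyReLU` default. The helper names are hypothetical, and the real modules operate on tensors.

```python
import math


def gelu(x):
    """Exact GELU: x * Phi(x), with Phi the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))


def leaky_relu(x, negative_slope=0.01):
    """Identity for x >= 0, a small linear slope otherwise."""
    return x if x >= 0 else negative_slope * x


def capped(activation, cap_value=10.0):
    """Wrap an activation so its output never exceeds cap_value."""
    def apply(x):
        return min(activation(x), cap_value)
    return apply


capped_gelu = capped(gelu)
capped_leaky_relu = capped(leaky_relu)
for value in (-2.0, 1.0, 50.0):
    print(value, capped_gelu(value), capped_leaky_relu(value))
```

Only the upper side is clipped, so negative outputs pass through unchanged.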