FOMO

ocean_emulators.models.fomo

FOMO(in_channels, out_channels, pred_residuals, last_kernel_size, pad, add_3d_coordinates, encoder, processor, decoder, hist, checkpointing, gradient_detach_interval, use_bfloat16)

Bases: BaseModel

FOMO: A Foundation Model for the Oceans + Observations.

Currently, this model is used only as a physical ocean emulator.
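This page shows only the constructor, so it does not document how FOMO's forward pass combines its parts. The sketch below is a rough mental model built on two assumptions. First, a step applies `encoder`, `processor` and `decoder` in that order. Second, `pred_residuals=True` means the decoded output is an increment that is added to the current state. An emulator then rolls out by feeding each prediction back in as the next input. `emulator_step`, `rollout` and the toy lambda stages are hypothetical and are not part of `ocean_emulators`. Plain lists of floats stand in for tensors.

```python
from typing import Callable, List, Sequence

# Plain lists of floats stand in for tensors to keep the sketch dependency-free.
State = List[float]
Stage = Callable[[Sequence[float]], List[float]]


def emulator_step(
    state: State,
    encoder: Stage,
    processor: Stage,
    decoder: Stage,
    pred_residuals: bool,
) -> State:
    """Advance the state by one step (hypothetical encode -> process -> decode)."""
    latent = encoder(state)      # map the physical fields to a latent representation
    latent = processor(latent)   # evolve the latent representation
    out = decoder(latent)        # map back to the physical fields
    if pred_residuals:
        # Assumption: the network predicts an increment added to the current state.
        out = [s + d for s, d in zip(state, out)]
    return out


def rollout(state: State, n_steps: int, **stages) -> List[State]:
    """Apply emulator_step autoregressively, feeding each prediction back in."""
    trajectory = []
    for _ in range(n_steps):
        state = emulator_step(state, **stages)
        trajectory.append(state)
    return trajectory


# Toy stages: the decoder emits an increment equal to 10% of the input state.
stages = dict(
    encoder=lambda x: [2.0 * v for v in x],
    processor=lambda z: list(z),
    decoder=lambda z: [0.05 * v for v in z],
    pred_residuals=True,
)
traj = rollout([1.0, 2.0], n_steps=3, **stages)
```

With these toy stages, each step scales the state by 1.1. Three steps therefore take `[1.0, 2.0]` to roughly `[1.331, 2.662]`. With `pred_residuals=False`, the same stages would return only the increment.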

Source code in src/ocean_emulators/models/fomo.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    pred_residuals: bool,
    last_kernel_size: int,
    pad: str,
    add_3d_coordinates: nn.Module | None,
    encoder: PerceiverEncoder,
    processor: UNetBackbone,
    decoder: PerceiverDecoder,
    hist: int,
    checkpointing: "Checkpointing | None",
    gradient_detach_interval: int,
    use_bfloat16: bool,
):
    super().__init__(
        in_channels=in_channels,
        out_channels=out_channels,
        hist=hist,
        pred_residuals=pred_residuals,
        last_kernel_size=last_kernel_size,
        pad=pad,
        gradient_detach_interval=gradient_detach_interval,
    )

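    # FOMO-specific submodules and precision flag; the shared options above are
    # passed through to BaseModel.
    self.maybe_add_3d_coordinates = add_3d_coordinates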
    self.encoder = encoder
    self.processor = processor
    self.decoder = decoder
    self.use_bfloat16 = use_bfloat16

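    # Optionally wrap every submodule whose type is in `_checkpoint_types` so that
    # its activations are recomputed during backward instead of being stored.
    if checkpointing == "all":
        apply_activation_checkpointing(
            self,
            check_fn=lambda m: isinstance(m, _checkpoint_types),
        )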
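When `checkpointing == "all"`, the constructor calls PyTorch's `apply_activation_checkpointing`. This function replaces every submodule for which `check_fn` returns `True` with a `CheckpointWrapper`. A wrapped module does not keep its intermediate activations from the forward pass. Instead, it recomputes them during backward, trading extra compute for lower memory use. This page does not show what `_checkpoint_types` contains, so the sketch below uses a hypothetical tuple containing only `nn.Linear` and applies it to a toy `nn.Sequential` model. Note that `apply_activation_checkpointing` lives in a private `torch.distributed` module, so its import path may change between PyTorch releases.

```python
import torch
from torch import nn
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
    CheckpointWrapper,
    apply_activation_checkpointing,
)

# Hypothetical stand-in for FOMO's `_checkpoint_types`: checkpoint every nn.Linear.
_checkpoint_types = (nn.Linear,)

model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))

# Replace each matching child module in place with a CheckpointWrapper.
# The root module itself is never wrapped, only its descendants.
apply_activation_checkpointing(
    model,
    check_fn=lambda m: isinstance(m, _checkpoint_types),
)

# Forward and backward behave as before; only the memory/compute trade-off changes.
x = torch.randn(2, 8, requires_grad=True)
model(x).sum().backward()
```

The same `check_fn` pattern lets a model opt in only its expensive block types, such as attention or convolution layers. Cheap layers like activations stay unwrapped, which avoids paying recomputation cost where it saves little memory.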