Skip to content

Configuration

Overview

Configuration is defined by config.py and values are stored in YAML files within the configs/ directory. Configuration files can include other configuration files using the !include directive.

Each configuration file is associated with a Pydantic model — you can generate JSON schemas for them with uv run src/ocean_emulators/config_schema.py (which is run automatically in pre-commit). To associate a configuration file with a Pydantic model, generate the JSON schema (if it doesn't already exist) and then add this line to the top of the config file:

# yaml-language-server: $schema=path/to/schema.json

This is what the config_schema.py script uses to determine which model to validate against, and also enables autocomplete/type checking in VS Code via the YAML extension.

Command Line Configuration

The train and eval modules accept the configuration file as a positional argument. You can override arbitrary keys on the command line — see --help for details. When overriding an object (as opposed to a single scalar value) via the command line, you can either supply JSON like --data '{"key": "value"}' or a YAML file with a leading @ symbol: --data @configs/data/file.yaml.

Each training run writes the final configuration it used to a YAML file in the checkpoint directory. To reproduce the run, pass that file to train, e.g. uv run -m ocean_emulators.train path/to/config.yaml.

API Reference

ocean_emulators.config

JulianDate(s)

Represents a Julian date as a cftime.datetime at noon on the relevant day.

This is the format the OM4 data uses, so we match that here. TODO(jder): probably worth asserting the date format when opening the data.

Source code in src/ocean_emulators/config.py
def __init__(self, s: str):
    """Parse ``s`` as a Julian-calendar date and pin it to noon.

    Args:
        s: Date string in ``%Y-%m-%d`` form, e.g. ``"1990-01-31"``.
    """
    # Noon matches the timestamp format used by the OM4 data.
    parsed_day = cftime.datetime.strptime(s, "%Y-%m-%d", calendar="julian")
    self.datetime = parsed_day.replace(hour=12)

TimeConfig

Bases: BaseConfig

Represents a time slice of the data.

Endpoints are Julian dates (not times), but cftime stores them as datetimes. The final endpoint is exclusive.

overlaps(other)

Check if this time range overlaps with another time range.

Parameters:

Name Type Description Default
other Self

Another TimeConfig to check for overlap

required

Returns:

Type Description
bool

True if the time ranges overlap, False otherwise

Source code in src/ocean_emulators/config.py
def overlaps(self, other: Self) -> bool:
    """Check if this time range overlaps with another time range.

    Args:
        other: Another TimeConfig to check for overlap

    Returns:
        True if the time ranges overlap, False otherwise
    """
    return (
        self.start.datetime < other.end.datetime
        and self.end.datetime > other.start.datetime
    )

PerceiverConfig

Bases: BaseConfig

A standard config interface to various perceiver implementations.

Builds either a regular Perceiver (for the encoder, via build) or a PerceiverIO (for the decoder, via build_io). Both respect the shared implementation setting from FOMOConfig.perceiver_implementation.

build(in_channels, out_channels, max_patch_size, implementation)

Build a regular Perceiver (used by the encoder).

Source code in src/ocean_emulators/config.py
def build(
    self,
    in_channels: int,
    out_channels: int,
    max_patch_size: tuple[int, int],
    implementation: PerceiverImpl,
) -> nn.Module:
    """Build a regular Perceiver (used by the encoder).

    Args:
        in_channels: Channels per input position fed to the Perceiver.
        out_channels: Width of the Perceiver output (``output_dim`` for the
            flash backend, ``num_classes`` for the naive backend).
        max_patch_size: Patch extent; its largest side is used as
            ``max_freq`` (and, for flash, as ``latent_rotary_emb_dim``).
        implementation: Backend selector, checked with ``_use_flash`` first
            and then ``_use_naive``.

    Returns:
        For flash, an ``nn.Sequential`` of 2D Fourier-feature augmentation,
        patch flattening and the flash Perceiver. For naive, a single
        ``NaivePerceiver``.

    Raises:
        ValueError: If neither backend matches ``implementation``.
        The error built by ``_flash_import_error()`` is raised when flash is
        selected but ``flash_perceiver`` cannot be imported.
    """
    # Not really a "frequency": capping at the widest patch side appears to be
    # reasonable from looking at the perceiver code.
    max_freq = max(*max_patch_size)
    num_freq_bands = 4
    # Settings shared by both backends.
    shared = dict(
        depth=self.depth,
        latent_dim=self.latent_dim,
        num_latents=self.num_latents,
        weight_tie_layers=True,
        self_per_cross_attn=2,
    )

    if _use_flash(implementation):
        try:
            from flash_perceiver import Perceiver as FlashPerceiver  # type: ignore
        except ImportError as e:
            raise _flash_import_error() from e
        from einops.layers.torch import Rearrange

        # flash_perceiver consumes (batch, seq_len, dim) and only applies rotary
        # positions to its latents, so on its own it has no intra-patch
        # positional signal. The naive Perceiver takes (batch, ph, pw, dim) and
        # adds 2D Fourier features (`input_axis=2, fourier_encode_data=True`),
        # so prepend an explicit FourierFeatures2D stage to match it.
        augmented_channels = in_channels + fourier_features_2d_dim(num_freq_bands)
        # Construct the stages in pipeline order (features, flatten, perceiver),
        # which keeps any random parameter initialization in the same order.
        add_position = FourierFeatures2D(
            num_freq_bands=num_freq_bands, max_freq=max_freq
        )
        flatten_patch = Rearrange("b ph pw v -> b (ph pw) v")
        flash_core = FlashPerceiver(
            latent_rotary_emb_dim=max_freq,
            input_dim=augmented_channels,
            output_dim=out_channels,
            output_mode="average",
            use_flash_attn=True,
            **shared,
        )
        return nn.Sequential(add_position, flatten_patch, flash_core)

    if _use_naive(implementation):
        return NaivePerceiver(
            num_freq_bands=num_freq_bands,
            max_freq=max_freq,
            input_axis=2,
            input_channels=in_channels,
            num_classes=out_channels,
            **shared,
        )

    raise ValueError(f"Unknown perceiver implementation: {implementation}.")

build_io(in_channels, queries_dim, out_channels, implementation)

Build a PerceiverIO (used by the decoder).

Source code in src/ocean_emulators/config.py
def build_io(
    self,
    in_channels: int,
    queries_dim: int,
    out_channels: int,
    implementation: PerceiverImpl,
) -> nn.Module:
    """Build a PerceiverIO (used by the decoder).

    Args:
        in_channels: Feature width of the input tokens the latents attend to
            (``input_dim`` for flash, ``dim`` for naive).
        queries_dim: Feature width of the output queries.
        out_channels: Output width per query (``proj_dim`` for flash,
            ``logits_dim`` for naive).
        implementation: Backend selector, checked with ``_use_flash`` first
            and then ``_use_naive``.

    Returns:
        The PerceiverIO module for the selected backend.

    Raises:
        ValueError: If neither backend matches ``implementation``.
        The error built by ``_flash_import_error()`` is raised when flash is
        selected but ``flash_perceiver`` cannot be imported.
    """
    # Settings shared by both backends.
    shared = dict(
        depth=self.depth,
        num_latents=self.num_latents,
        latent_dim=self.latent_dim,
        weight_tie_layers=True,
    )

    if _use_flash(implementation):
        try:
            from flash_perceiver.perceiver import (  # type: ignore
                PerceiverIO as FlashPerceiverIO,  # type: ignore
            )
        except ImportError as e:
            raise _flash_import_error() from e
        return FlashPerceiverIO(
            input_dim=in_channels,
            query_dim=queries_dim,
            proj_dim=out_channels,
            use_flash_attn=True,
            **shared,
        )

    if _use_naive(implementation):
        # Unlike flash_perceiver, this import is not guarded.
        from perceiver_pytorch.perceiver_io import PerceiverIO as NaivePerceiverIO

        return NaivePerceiverIO(
            dim=in_channels,
            queries_dim=queries_dim,
            logits_dim=out_channels,
            decoder_ff=True,
            **shared,
        )

    raise ValueError(f"Unknown perceiver implementation: {implementation}.")

DecoderConfig

Bases: BaseConfig

A PerceiverIO-based decoder configuration.

Uses PerceiverIO (with an explicit query mechanism) rather than a regular Perceiver. Output pixel positions are encoded as queries, so the output size is determined by the query count — not by num_latents.

When window_patches is set, the decoder tiles the output grid into spatial blocks of that many patches per side. Each block's PerceiverIO call receives only the latent tokens that overlap the block, plus context_patches extra rings of neighbors. This keeps the cost bounded even when the latent grid is large (e.g. when patch_extent is fine).

ocean_emulators.config_base

BaseConfig

Bases: BaseModel

Base class for all configs.

TopLevelConfig(*args, **kwargs)

Bases: BaseSettings

Base class for top-level configs (i.e. tasks like train or eval).

Source code in src/ocean_emulators/config_base.py
def __init__(self, *init_args, **init_kwargs):
    """Forward every positional and keyword argument unchanged to the parent initializer."""
    super().__init__(*init_args, **init_kwargs)

from_yaml_and_cli(args_to_parse=None) classmethod

Load config from YAML & CLI with validation.

Source code in src/ocean_emulators/config_base.py
@classmethod
def from_yaml_and_cli(
    cls,
    args_to_parse: list[str] | None = None,
) -> Self:
    """Load config from YAML & CLI with validation.

    The positional ``config`` argument names a YAML file. Its values are
    passed to the constructor as keyword arguments. The CLI settings source
    built here is handed to pydantic-settings via ``_cli_settings_source``,
    so command-line overrides are also applied.

    NOTE(review): the epilog's include example
    (``!include configs/data/something.yaml``) reads as repo-root-relative.
    However, ``register_include_constructor`` resolves ``!include`` paths
    relative to the including file's directory. Confirm the example.

    Args:
        args_to_parse: Arguments to parse instead of the real command line.
            ``None`` (the default) means ``sys.argv`` is used.

    Returns:
        A validated instance of ``cls``.

    Raises:
        FileNotFoundError: If the YAML file named by ``config`` does not exist.
        SystemExit: From argparse on invalid arguments or ``--help``.
    """
    arg_parser = argparse.ArgumentParser(
        description=cls.__doc__,
        epilog=textwrap.dedent(
            """
            YAML files can include other YAML files using the !include tag,
            as in `data: !include configs/data/something.yaml`
            You can also replace any JSON argument listed above with a YAML file by
            specifying it with an @ symbol,
            eg `--some_param=@configs/data/something.yaml`.
            """
        ),
    )
    arg_parser.add_argument("config", type=str, help="Path to config YAML file")

    # `True` tells the CLI source to read sys.argv on its own.
    cli_parse_args = True if args_to_parse is None else args_to_parse
    cli_source = IncludeYamlCliSettingsSource(
        cls,
        root_parser=arg_parser,
        cli_parse_args=cli_parse_args,
    )

    # Parse only after the CLI source has registered its options on the parser,
    # so error output and --help list every flag.
    namespace = arg_parser.parse_args(args_to_parse)

    # YamlConfigSettingsSource ignores missing files by default, so fail loudly here.
    requested_path = namespace.config
    resolved_path = Path(requested_path).expanduser().resolve()
    if not resolved_path.exists():
        raise FileNotFoundError(
            f"Config file `{requested_path}` (full path: {resolved_path}) not found"
        )
    file_values = YamlConfigSettingsSource(cls, yaml_file=resolved_path)()

    return cls(_cli_settings_source=cli_source, **file_values)

save_yaml(save_path)

Save config to YAML file.

Source code in src/ocean_emulators/config_base.py
def save_yaml(self, save_path: Path) -> None:
    """Save config to a YAML file that ``yaml.safe_load`` can read back.

    The config is dumped in pydantic's JSON mode, so every value is a plain
    YAML type (str/int/float/bool/None/list/dict). A Python-mode dump keeps
    objects such as ``tuple``, ``pathlib.Path`` or enum members. ``yaml.dump``
    writes those with ``!!python/...`` tags, which ``yaml.safe_load``
    rejects, and that would make the saved file unusable for reproducing a run.

    NOTE(review): JSON mode encodes non-finite floats according to the model's
    ``ser_json_inf_nan`` setting (``null`` by default). Confirm that no config
    field relies on inf/NaN values.

    Args:
        save_path: Destination file; it is created or overwritten.
    """
    # Serialize before opening, so a serialization error cannot truncate an existing file.
    data = self.model_dump(mode="json")
    with open(save_path, "w", encoding="utf-8") as f:
        yaml.dump(data, f)

IncludeYamlCliSettingsSource(*args, **kwargs)

Bases: CliSettingsSource

CliSettingsSource which permits @filename.yaml for JSON arguments.

Source code in src/ocean_emulators/config_base.py
def __init__(self, *source_args, **source_kwargs):
    """Forward every argument unchanged to the parent ``CliSettingsSource`` initializer."""
    super().__init__(*source_args, **source_kwargs)

register_include_constructor()

Set up yaml.safe_load to include other yaml files via !include.

Source code in src/ocean_emulators/config_base.py
def register_include_constructor():
    """Set up yaml.safe_load to include other yaml files via !include.

    After this runs, a scalar tagged ``!include`` is replaced by the parsed
    contents of the referenced file. The path is resolved relative to the
    directory of the file containing the tag, so the outer document must be
    loaded from a named file object rather than a string. Included files may
    themselves use ``!include``.

    The constructor is registered on ``yaml.SafeLoader`` itself, so it affects
    every ``yaml.safe_load`` call in the process.
    """

    def include_constructor(loader: yaml.SafeLoader, node: yaml.Node) -> Any:
        # When given str/bytes, PyYAML leaves `loader.stream` as None, and
        # file-like objects such as StringIO have no `.name`. In both cases
        # there is no directory to resolve the include against.
        if hasattr(loader.stream, "name"):
            name = loader.stream.name  # type: ignore
        else:
            raise ValueError(
                "To support includes, you must load a file object, not a string"
            )
        # construct_scalar rejects non-scalar nodes (e.g. `!include [a, b]`)
        # with a ConstructorError that points at the offending position. The
        # alternative is an opaque TypeError from os.path.join.
        relative_path = loader.construct_scalar(node)
        filename = os.path.normpath(os.path.join(os.path.dirname(name), relative_path))
        with open(filename) as f:
            return yaml.safe_load(f)

    # This is arguably unsafe, but we don't parse untrusted YAML
    yaml.loader.SafeLoader.add_constructor("!include", include_constructor)