Skip to content

Datasets

ocean_emulators.datasets

InferenceDataset(src, prognostic_var_names, boundary_var_names, hist, normalize_before_mask, masked_fill_value, long_rollout)

Bases: Dataset

This class is used for inference rollouts.

It creates rolling indices to keep track of histories/past states. Each entry below maps a sample index to the time indices of its input and output states. For example:

Hist=0 ; 0->[0, 1]; 1->[1, 2]; 2->[2, 3]; 3->[3, 4]

Hist=1 ; 0->[[0, 1], [2, 3]]; 1->[[2, 3], [4, 5]]; 2->[[4, 5], [6, 7]]; 3->[[6, 7], [8, 9]]

Hist=2 ; 0->[[0, 1, 2], [3, 4, 5]]; 1->[[3, 4, 5], [6, 7, 8]]; 2->[[6, 7, 8], [9, 10, 11]]; 3->[[9, 10, 11], [12, 13, 14]]

Source code in src/ocean_emulators/datasets.py
@elapsed
def __init__(
    self,
    src: DataSource,
    prognostic_var_names,
    boundary_var_names,
    hist: int,
    normalize_before_mask: bool,
    masked_fill_value: float,
    long_rollout: bool,
):
    """Prepare rolling time indices, masks, and grid context for inference.

    Args:
        src: Data source; its ``data``, ``resolution``, ``masks`` and
            ``filter(...)`` members are read here.
        prognostic_var_names: Variable names handed to ``src.filter`` to
            select the prognostic variables.
        boundary_var_names: Variable names handed to ``src.filter`` to
            select the boundary variables.
        hist: Number of history steps. Each state spans ``hist + 1``
            consecutive time positions.
        normalize_before_mask: Stored on the instance unchanged; not used
            inside this constructor.
        masked_fill_value: Stored on the instance unchanged; not used
            inside this constructor.
        long_rollout: When truthy, log which input time and which output
            time the rollout starts from.
    """
    super().__init__()
    # NOTE: Keep tensors on CPU during initialization. This allows the dataset
    # to be passed between DataLoader worker processes. Call to(device) before
    # using the dataset for inference.

    self.hist = hist

    # One channel per prognostic variable per time position in a state.
    self.num_prognostic_channels = (hist + 1) * len(prognostic_var_names)
    data = src.data
    self.input_res = src.resolution
    self._prognostic_src = src.filter(prognostic_var_names, prefix="prognostic")
    self._boundary_src = src.filter(boundary_var_names, prefix="boundary")
    self._times = data.time
    self.normalize_before_mask = normalize_before_mask
    self.masked_fill_value = masked_fill_value

    # Work with integer positions 0..T-1 along the time axis rather than
    # with the timestamps themselves.
    time_indices = np.arange(data.time.size)
    indices = xr.DataArray(
        time_indices,
        dims=["time"],
        coords={"time": time_indices},
    )
    total_steps = 2 * self.hist + 1
    # Rolling window of length T - total_steps over the positions; the
    # window contents are exposed along the new "window_dim" dimension.
    rolling_indices = indices.rolling(
        time=len(time_indices) - total_steps, center=False
    ).construct("window_dim")
    # After the transpose, "window_dim" indexes the samples. Only the
    # trailing total_steps + 1 columns along "time" are kept, so row w
    # holds positions w .. w + total_steps (2 * hist + 2 values: one
    # input state followed by one output state).
    rolling_indices = rolling_indices.transpose("window_dim", "time").isel(
        time=slice(len(time_indices) - total_steps - 1, None)
    )  # Remove first few null indices
    # Keep every (hist + 1)-th row so consecutive samples advance by a
    # whole state instead of by a single time position.
    self.rolling_indices = rolling_indices.isel(
        window_dim=slice(0, None, self.hist + 1)
    )  # Skip indices based on history
    # NOTE(review): presumably the null padding above leaves a non-integer
    # dtype; cast so the values can serve as positions.
    self.rolling_indices = self.rolling_indices.astype(int)

    if long_rollout:
        logger.info(
            f"Long rollout will use input at time {data.time.values[0]} and produce"
            f" output at {data.time.values[self.hist + 1]}"
        )

    self.wet: PrognosticMask = src.masks.prognostic
    self.wet_surface: GridMask = src.masks.boundary
    self.wet_label = src.masks.prognostic_with_hist(self.hist)
    # Number of samples = number of rows kept in rolling_indices.
    self.size = len(self.rolling_indices)

    # Pin the masks only when a GPU is in use (presumably to make later
    # host-to-device copies cheaper — confirm against using_gpu()).
    if using_gpu():
        self.wet = self.wet.pin_memory()
        self.wet_surface = self.wet_surface.pin_memory()
        self.wet_label = self.wet_label.pin_memory()

    # Inference only currently supports the same output resolution as the input
    # resolution.
    self.ctx = GridContext(self.wet_label, self.input_res, self.input_res)

to(device)

Move the dataset's context tensors to the specified device.

Call this before using the dataset for inference to ensure tensors are on the correct device (GPU).

Source code in src/ocean_emulators/datasets.py
def to(self, device: torch.device) -> "InferenceDataset":
    """Relocate the grid context and the label mask onto ``device``.

    Invoke this ahead of inference so the dataset's context tensors live
    on the intended device (typically the GPU).

    Args:
        device: Destination device.

    Returns:
        This same dataset instance, which allows call chaining.
    """
    context_on_device = self.ctx.to(device)
    self.ctx = context_on_device

    # The label mask transfer is requested as non-blocking.
    label_mask_on_device = self.wet_label.to(device, non_blocking=True)
    self.wet_label = label_mask_on_device

    return self

get_boundary(step)

Return boundary at the requested step.

Source code in src/ocean_emulators/datasets.py
def get_boundary(self, step: int) -> Boundary:
    """Fetch the boundary data belonging to rollout ``step``.

    The step is first translated into an input index via
    ``_get_x_index`` and that index is then resolved by ``_get_boundary``.
    """
    return self._get_boundary(self._get_x_index(step))

RawTrainData(dataset_id)

Source code in src/ocean_emulators/datasets.py
def __init__(self, dataset_id: "TorchTrainDataset.Id"):
    """Create an empty container tagged with the id of its source dataset.

    Args:
        dataset_id: Identifier of the dataset that produces this sample.
    """
    # Per-step (input, boundary, label) tensor triples; starts out empty.
    self.raw_data: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor]] = []
    # No load statistics are attached at construction time.
    self.load_stats: LoadStats | None = None
    self.dataset_id: TorchTrainDataset.Id = dataset_id

insert(input_, boundary, label)

Add a prognostic input, boundary, and prognostic label as the last step.

Source code in src/ocean_emulators/datasets.py
def insert(
    self,
    input_: torch.Tensor,
    boundary: torch.Tensor,
    label: torch.Tensor,
):
    """Record one more step at the end of the stored sequence.

    Args:
        input_: Prognostic input tensor for the step.
        boundary: Boundary tensor for the step.
        label: Prognostic label tensor for the step.
    """
    step_entry = (input_, boundary, label)
    self.raw_data.append(step_entry)

TrainData(num_prognostic_channels, num_boundary_channels, ctx)

A single batch of training data.

A single batch contains multiple steps' worth of Example entries, each of which is a (prognostic_input, boundary_input, label) triple. The prognostic and boundary tensors are carried separately because the FOMO model encodes them separately (Samudra just concatenates them later).

Source code in src/ocean_emulators/datasets.py
def __init__(
    self, num_prognostic_channels: int, num_boundary_channels: int, ctx: GridContext
):
    """Create an empty batch that knows its channel counts and grid context.

    Args:
        num_prognostic_channels: Channel count of the prognostic tensors.
        num_boundary_channels: Channel count of the boundary tensors.
        ctx: Grid context associated with this batch.
    """
    # Examples are added one step at a time; nothing is stored yet.
    self.example_by_step: list[Example] = []
    self.load_stats: LoadStats | None = None

    self.ctx = ctx
    self.num_boundary_channels = num_boundary_channels
    self.num_prognostic_channels = num_prognostic_channels

append(prognostic_input, boundary_input, label)

Add another Example as a new step.

Source code in src/ocean_emulators/datasets.py
def append(
    self, prognostic_input: Prognostic, boundary_input: Boundary, label: Prognostic
) -> None:
    """Store a (prognostic_input, boundary_input, label) triple as the next step."""
    example = (prognostic_input, boundary_input, label)
    self.example_by_step.append(example)

__getitem__(step)

Converts index (step) into (prognostic, boundary, label) triple.

Source code in src/ocean_emulators/datasets.py
def __getitem__(self, step: int) -> Example:
    """Look up the (prognostic, boundary, label) triple stored for ``step``."""
    stored_examples = self.example_by_step
    return stored_examples[step]

TorchTrainDataset(src, dst, prognostic_var_names, boundary_var_names, hist, steps, normalize_before_mask, masked_fill_value, stride=1, concurrent_compute_=False)

Bases: Dataset[RawTrainData]

This class is used for training and validation.

It creates rolling indices to keep track of histories/past states. Unlike InferenceDataset, it builds the rolling indices using a configurable stride; by default, the stride (the sliding-window step) is 1.

We make use of TrainData class to store a single sample.

For example (TD = TrainData):

Hist=0 ; TD: step=0->[0, 1]; step=1->[1, 2]; step=2->[2, 3]; step=3->[3, 4]

Hist=1 ; TD: step=0->[[0, 1], [2, 3]]; step=1->[[2, 3], [4, 5]]; step=2->[[4, 5], [6, 7]]; step=3->[[6, 7], [8, 9]]

Hist=2 ; TD: step=0->[[0, 1, 2], [3, 4, 5]]; step=1->[[3, 4, 5], [6, 7, 8]]; step=2->[[6, 7, 8], [9, 10, 11]]; step=3->[[9, 10, 11], [12, 13, 14]]

Source code in src/ocean_emulators/datasets.py
@elapsed
def __init__(
    self,
    src: DataSource,
    dst: DataSource | None,
    prognostic_var_names: PrognosticVarNames,
    boundary_var_names: BoundaryVarNames,
    hist: int,
    steps: int,
    normalize_before_mask: bool,
    masked_fill_value: float,
    stride: int = 1,
    concurrent_compute_: bool = False,
):
    """Build window indices, masks, and grid context for training/validation.

    Args:
        src: Source of the prognostic inputs and of the boundary variables.
        dst: Optional second source used for the labels. When it is falsy,
            ``src`` doubles as the label source.
        prognostic_var_names: Names used to filter the prognostic variables.
        boundary_var_names: Names used to filter the boundary variables.
        hist: Number of history steps; a state spans ``hist + 1`` positions.
        steps: Number of steps per sample; enters the ``size`` computation.
        normalize_before_mask: Stored on the instance unchanged.
        masked_fill_value: Stored on the instance unchanged.
        stride: Spacing between consecutive positions inside a window.
        concurrent_compute_: Stored on the instance as ``_concurrent_compute``.
    """
    super().__init__()
    self.id = f"{self.__class__.__name__}_{str(id(self))}"

    # With no separate dst, a single source serves as both input and label
    # source: index 0 is the input side and index -1 is the label side.
    sources = [src, dst] if dst else [src]

    # Plain configuration, kept as given.
    self.hist: int = hist
    self.steps: int = steps
    self.stride: int = stride
    self.normalize_before_mask: bool = normalize_before_mask
    self.masked_fill_value: float = masked_fill_value
    self._concurrent_compute = concurrent_compute_

    # One channel per variable per time position in a state.
    positions_per_state = hist + 1
    self.num_prognostic_channels: int = positions_per_state * len(
        prognostic_var_names
    )
    self.num_boundary_channels: int = positions_per_state * len(boundary_var_names)

    input_time, label_time = sources[0].data.time, sources[-1].data.time
    assert np.array_equal(input_time, label_time), (
        "src and dst DataSource have different time slices!"
    )
    time_axis = src.data.time

    self.prognostic_srcs = [
        source.filter(prognostic_var_names, prefix="prog") for source in sources
    ]
    self.boundary_src = src.filter(boundary_var_names, prefix="boundary")

    # Training/validation windows cover 2 * hist + 2 positions each.
    window_length: int = 2 * self.hist + 2
    num_windows = time_axis.size - (window_length - 1) * self.stride

    # Broadcasting the window starts ("window") against the strided offsets
    # ("time") yields a 2-D table whose entry [w, t] is w + stride * t.
    window_starts = xr.DataArray(np.arange(num_windows), dims=["window"])
    step_offsets = xr.DataArray(np.arange(window_length), dims=["time"])
    self.rolling_indices: Float[xr.DataArray, "window time"] = (
        window_starts + stride * step_offsets
    )

    # Masks are only collected here; to_train_data() is where they are
    # combined with the tensors and moved to the target device.
    self.wet_prognostic: list[PrognosticMask] = [
        source.masks.prognostic for source in sources
    ]
    self.wet_surface: GridMask = src.masks.boundary

    input_side, label_side = self.prognostic_srcs[0], self.prognostic_srcs[-1]
    self.ctx = GridContext(
        label_mask=label_side.masks.prognostic_with_hist(self.hist),
        input_resolution_cpu=input_side.resolution,
        output_resolution_cpu=label_side.resolution,
    )

    # Number of usable samples once the rollout steps and the history are
    # subtracted from the time axis length.
    rollout_span = self.steps * (self.hist + 1) * self.stride
    history_span = self.hist * self.stride
    self.size: int = time_axis.size - rollout_span - history_span

to_train_data(raw_train_data, device)

Convert RawTrainData to TrainData, moving tensors to the specified device.

Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| raw_train_data | RawTrainData | CPU data from worker process | required |
| device | device | Target device (typically GPU) to move tensors to | required |

Returns:

| Type | Description |
| --- | --- |
| TrainData | TrainData with tensors on the target device |

Source code in src/ocean_emulators/datasets.py
def to_train_data(
    self, raw_train_data: RawTrainData, device: torch.device
) -> TrainData:
    """Turn a RawTrainData sample into a TrainData batch living on ``device``.

    Args:
        raw_train_data: CPU data from worker process
        device: Target device (typically GPU) to move tensors to

    Returns:
        TrainData with tensors on the target device
    """

    def _on_device(tensor, mask, source):
        # Pair the raw tensor with its mask and source, then request a
        # non-blocking transfer to the target device.
        wrapped = OceanData.from_data_source(tensor, mask, source)
        return wrapped.to(device=device, non_blocking=True)

    batch = TrainData(
        self.num_prognostic_channels,
        self.num_boundary_channels,
        self.ctx.to(device),
    )

    for raw_input, raw_boundary, raw_label in raw_train_data.raw_data:
        # Inputs use the first prognostic source/mask, labels the last one.
        prognostic, boundary_input, label_out = self._to_example(
            _on_device(raw_input, self.wet_prognostic[0], self.prognostic_srcs[0]),
            _on_device(raw_boundary, self.wet_surface, self.boundary_src),
            _on_device(raw_label, self.wet_prognostic[-1], self.prognostic_srcs[-1]),
        )
        batch.append(prognostic, boundary_input, label_out)

    # Carry the loading statistics over unchanged.
    batch.load_stats = raw_train_data.load_stats
    return batch

TrainDataLoader(dataloader, datasets, device)

Wrapper around a torch DataLoader that handles GPU post-processing.

This class wraps a DataLoader[RawTrainData] and converts the raw data to TrainData by applying GPU-based normalization and masking. This allows the data loading process to handle I/O while the main process handles GPU operations.

Because data samples flow from the worker processes to the main process, each sample must be tied back to the dataset it came from: that dataset knows how to perform the second half of the processing once the sample reaches the main process, which has GPU access set up. To make this possible, each data sample (which could come from a different dataset) carries a dataset ID, and `datasets` is used to map those IDs back to the original datasets.

Source code in src/ocean_emulators/datasets.py
def __init__(
    self,
    dataloader: torch.utils.data.DataLoader[RawTrainData],
    datasets: list[TorchTrainDataset],
    device: torch.device,
):
    """Wrap ``dataloader`` and remember which dataset owns which id.

    Args:
        dataloader: Loader that yields RawTrainData samples.
        datasets: Datasets the samples may originate from; indexed by
            their ``id`` attribute for later lookup.
        device: Device handed to ``to_train_data`` during conversion.
    """
    self._dataloader = dataloader
    self._device = device

    # Lookup table: dataset id -> dataset (a later duplicate id wins).
    datasets_by_id = {}
    for candidate in datasets:
        datasets_by_id[candidate.id] = candidate
    self._datasets = datasets_by_id

__iter__()

Iterate over the dataloader, converting RawTrainData to TrainData.

Source code in src/ocean_emulators/datasets.py
def __iter__(self):
    """Yield one TrainData per RawTrainData produced by the wrapped loader.

    Each raw sample is routed to the dataset registered under its
    ``dataset_id``, which performs the conversion for the stored device.
    """
    for raw_sample in self._dataloader:
        owner = self._datasets[raw_sample.dataset_id]
        yield owner.to_train_data(raw_sample, self._device)

__getitem__(index)

Access a single item by index, converting RawTrainData to TrainData.

Note: This bypasses the DataLoader's sampling/batching and directly accesses the underlying dataset for test purposes.

Source code in src/ocean_emulators/datasets.py
def __getitem__(self, index: int) -> TrainData:
    """Fetch the sample at ``index`` and convert it to TrainData.

    Note: The DataLoader's sampling/batching is skipped entirely; the
    underlying dataset is indexed directly. Intended for test purposes.
    """
    loader = self._dataloader

    # Pull the raw sample straight from the wrapped dataset.
    sample = loader.dataset[index]

    # The collate function expects a list and adds the batch dimension.
    collate = loader.collate_fn
    if collate is not None:
        sample = collate([sample])

    # Hand the sample to the dataset it came from for the final conversion.
    owner = self._datasets[sample.dataset_id]
    return owner.to_train_data(sample, self._device)