Datasets
ocean_emulators.datasets
InferenceDataset(src, prognostic_var_names, boundary_var_names, hist, normalize_before_mask, masked_fill_value, long_rollout)
Bases: Dataset
This class is used for inference rollouts.
It creates rolling indices to keep track of histories/past states. For example:
Hist=0: 0->[0, 1]; 1->[1, 2]; 2->[2, 3]; 3->[3, 4]
Hist=1: 0->[[0, 1], [2, 3]]; 1->[[2, 3], [4, 5]]; 2->[[4, 5], [6, 7]]; 3->[[6, 7], [8, 9]]
Hist=2: 0->[[0, 1, 2], [3, 4, 5]]; 1->[[3, 4, 5], [6, 7, 8]]; 2->[[6, 7, 8], [9, 10, 11]]; 3->[[9, 10, 11], [12, 13, 14]]
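The index pattern in the example above can be reproduced with a short sketch. The function name `rolling_indices` is hypothetical, and the formula is inferred from the listed examples rather than taken from the library's source: each step appears to consume a window of `hist + 1` time indices as input and the following window as the target.

```python
def rolling_indices(step: int, hist: int) -> tuple[list[int], list[int]]:
    """Return (input_indices, target_indices) for one rollout step.

    Hypothetical helper: inferred from the documented examples, not the
    library's implementation.
    """
    window = hist + 1  # number of time indices that make up one state
    start = step * window
    inputs = list(range(start, start + window))
    targets = list(range(start + window, start + 2 * window))
    return inputs, targets


for hist in (0, 1, 2):
    print(f"Hist={hist}:", [rolling_indices(step, hist) for step in range(4)])
```

Note that for Hist=0 the documented form `[0, 1]` is the flattened pair of a one-element input window and a one-element target window.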
Source code in src/ocean_emulators/datasets.py
to(device)
Move the dataset's context tensors to the specified device.
Call this before using the dataset for inference to ensure tensors are on the correct device (GPU).
get_boundary(step)
RawTrainData(dataset_id)
insert(input_, boundary, label)
Add a prognostic input, boundary, and prognostic label as the last step.
TrainData(num_prognostic_channels, num_boundary_channels, ctx)
A single batch of training data.
A single batch contains multiple steps' worth of Example entries, each
of which is a (prognostic_input, boundary_input, label) triple. The
prognostic and boundary tensors are carried separately because the FOMO model
encodes them separately (Samudra just concatenates them later).
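As a rough illustration of that layout, the sketch below keeps each step as a separate triple so the prognostic and boundary inputs never get merged. `SketchTrainData` is a hypothetical stand-in, not the library's class: it omits the `ctx` argument, and plain lists stand in for tensors to keep the example dependency-free.

```python
from dataclasses import dataclass, field
from typing import Any, NamedTuple


class Example(NamedTuple):
    """One step: prognostic input, boundary input, and label."""
    prognostic_input: Any
    boundary_input: Any
    label: Any


@dataclass
class SketchTrainData:
    """Hypothetical stand-in for TrainData; not the library's class."""
    num_prognostic_channels: int
    num_boundary_channels: int
    examples: list[Example] = field(default_factory=list)

    def append(self, prognostic_input, boundary_input, label) -> None:
        # Each call adds one more step to the batch.
        self.examples.append(Example(prognostic_input, boundary_input, label))


batch = SketchTrainData(num_prognostic_channels=2, num_boundary_channels=1)
batch.append([0.1, 0.2], [1.0], [0.3, 0.4])  # step 0
batch.append([0.3, 0.4], [1.1], [0.5, 0.6])  # step 1
print(len(batch.examples))
```

A model that encodes the two inputs separately can read `prognostic_input` and `boundary_input` directly, while one that concatenates them can do so at the point of use.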
append(prognostic_input, boundary_input, label)
Add another Example as a new step.
TorchTrainDataset(src, dst, prognostic_var_names, boundary_var_names, hist, steps, normalize_before_mask, masked_fill_value, stride=1, concurrent_compute_=False)
Bases: Dataset[RawTrainData]
This class is used for training and validation.
It creates rolling indices to keep track of histories/past states. Unlike InferenceDataset, it creates the rolling indices based on a stride. By default, the stride (the step of the sliding window) is 1.
We use the TrainData class to store a single sample.
For example (TD denotes a single TrainData sample):
Hist=0: TD: step=0->[0, 1]; step=1->[1, 2]; step=2->[2, 3]; step=3->[3, 4]
Hist=1: TD: step=0->[[0, 1], [2, 3]]; step=1->[[2, 3], [4, 5]]; step=2->[[4, 5], [6, 7]]; step=3->[[6, 7], [8, 9]]
Hist=2: TD: step=0->[[0, 1, 2], [3, 4, 5]]; step=1->[[3, 4, 5], [6, 7, 8]]; step=2->[[6, 7, 8], [9, 10, 11]]; step=3->[[9, 10, 11], [12, 13, 14]]
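One possible reading of how the stride interacts with the per-step pattern is sketched below. The helper `sample_indices` is hypothetical, and the meaning given to `stride` (consecutive samples start `stride` time indices apart) is an assumption; the documentation above only confirms the per-step pattern for the first sample.

```python
def sample_indices(sample: int, hist: int, steps: int, stride: int = 1):
    """List (input_indices, label_indices) for every step of one sample.

    Hypothetical helper. Assumption: consecutive samples start `stride`
    time indices apart; the per-step pattern follows the documented example.
    """
    window = hist + 1  # time indices per state
    offset = sample * stride  # assumed meaning of `stride`
    pairs = []
    for step in range(steps):
        start = offset + step * window
        inputs = list(range(start, start + window))
        labels = list(range(start + window, start + 2 * window))
        pairs.append((inputs, labels))
    return pairs


# The first sample with Hist=1 and two steps follows the documented pattern.
print(sample_indices(sample=0, hist=1, steps=2))
```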
to_train_data(raw_train_data, device)
Convert RawTrainData to TrainData, moving tensors to the specified device.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| raw_train_data | RawTrainData | CPU data from worker process | required |
| device | device | Target device (typically GPU) to move tensors to | required |
Returns:
| Type | Description |
|---|---|
| TrainData | TrainData with tensors on the target device |
TrainDataLoader(dataloader, datasets, device)
Wrapper around a torch DataLoader that handles GPU post-processing.
This class wraps a DataLoader[RawTrainData] and converts the raw data to TrainData by applying GPU-based normalization and masking. This allows the data loading process to handle I/O while the main process handles GPU operations.
Because data samples flow from a worker process to the main process, each
sample must be tied back to the dataset it came from. That dataset knows how
to perform the second (GPU) half of the processing once the sample reaches the
main process, which has GPU access set up. To do that, each data sample (which
could come from a different dataset) carries a dataset ID; datasets maps those
IDs to the original datasets.
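The routing by dataset ID can be sketched as follows. All class names here (`RawSample`, `SketchDataset`, `SketchLoader`) are hypothetical stand-ins rather than the library's classes; a plain list replaces the torch DataLoader, and a scale factor stands in for the dataset-specific normalization and masking.

```python
from dataclasses import dataclass


@dataclass
class RawSample:
    """Hypothetical stand-in for RawTrainData: CPU data plus a dataset ID."""
    dataset_id: int
    payload: list


class SketchDataset:
    """Hypothetical stand-in for a dataset that knows its own post-processing."""

    def __init__(self, scale: float):
        self.scale = scale  # placeholder for dataset-specific normalization

    def to_train_data(self, raw: RawSample, device: str):
        # A real implementation would move tensors to `device` and
        # normalize/mask them there; here we only scale plain floats.
        return {"device": device, "values": [v * self.scale for v in raw.payload]}


class SketchLoader:
    """Hypothetical wrapper: routes each raw sample to its source dataset."""

    def __init__(self, dataloader, datasets, device):
        self.dataloader = dataloader
        self.datasets = datasets  # maps dataset_id -> dataset
        self.device = device

    def __iter__(self):
        for raw in self.dataloader:
            yield self.datasets[raw.dataset_id].to_train_data(raw, self.device)


raw_stream = [RawSample(0, [1.0, 2.0]), RawSample(1, [1.0, 2.0])]
loader = SketchLoader(raw_stream, {0: SketchDataset(1.0), 1: SketchDataset(10.0)}, "cpu")
results = list(loader)
print(results)
```

Keeping the lookup in the wrapper means the worker processes only need to attach an ID to each sample, not a reference to the dataset object itself.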
__iter__()
Iterate over the dataloader, converting RawTrainData to TrainData.
__getitem__(index)
Access a single item by index, converting RawTrainData to TrainData.
Note: This bypasses the DataLoader's sampling/batching and directly accesses the underlying dataset; it is intended for testing purposes.