Samplers
ocean_emulators.utils.samplers
EquivalenceGroupBatchSampler(groups, batch_size, shuffle=True, drop_last=False)
Bases: Sampler[Batch]
Groups indices into equivalence classes, batches within groups, and optionally shuffles.
This sampler partitions dataset indices into groups. It creates batches within each group, then chains them together. When shuffle=True, batches are globally shuffled each epoch to avoid sequential group processing.
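The behaviour described above can be sketched in plain Python. This is a minimal illustration, not the library's implementation: the function name `equivalence_group_batches` and its `seed` argument are hypothetical and exist only to make the example reproducible.

```python
import random


def equivalence_group_batches(groups, batch_size, shuffle=True, drop_last=False, seed=None):
    """Hypothetical stand-in that mimics the documented batching behaviour."""
    rng = random.Random(seed)
    batches = []
    for group in groups:
        indices = list(group)
        if shuffle:
            rng.shuffle(indices)  # shuffle indices within the group
        # Cut the group into consecutive batches; a batch never spans two groups.
        for start in range(0, len(indices), batch_size):
            batch = indices[start:start + batch_size]
            if drop_last and len(batch) < batch_size:
                continue  # skip the incomplete batch at the end of this group
            batches.append(batch)
    if shuffle:
        rng.shuffle(batches)  # shuffle batch order across all groups
    return batches


groups = [[0, 1, 2, 3, 4], [5, 6, 7]]
batches = equivalence_group_batches(groups, batch_size=2, shuffle=True, seed=0)
for batch in batches:
    print(batch)
```

Every printed batch draws its indices from a single group, which is the property that lets samples with the same shape be stacked together.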
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| groups | list[list[int]] | List of index lists, where each inner list contains indices belonging to the same equivalence group. | required |
| batch_size | int | Number of samples per batch | required |
| shuffle | bool | Whether to shuffle indices within groups and shuffle batches globally | True |
| drop_last | bool | Whether to drop incomplete batches at the end of each group | False |
Source code in src/ocean_emulators/utils/samplers.py
from_dataset_sizes(dataset_sizes, batch_size, shuffle=True, drop_last=False)
classmethod
Create sampler from dataset sizes, treating each as a contiguous group.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dataset_sizes | list[int] | List of individual dataset sizes. Groups are created based on cumulative boundaries, where each dataset forms its own equivalence group. | required |
| batch_size | int | Number of samples per batch | required |
| shuffle | bool | Whether to shuffle indices within groups and shuffle batches globally | True |
| drop_last | bool | Whether to drop incomplete batches at the end of each group | False |
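As an illustration of "cumulative boundaries", the sketch below turns a list of sizes into contiguous index groups. The helper name `groups_from_sizes` is hypothetical, and the real classmethod may build its groups differently.

```python
from itertools import accumulate


def groups_from_sizes(dataset_sizes):
    """Hypothetical helper: one contiguous index range per dataset."""
    ends = list(accumulate(dataset_sizes))  # cumulative end offsets
    starts = [0] + ends[:-1]                # each group starts where the previous one ended
    return [list(range(start, end)) for start, end in zip(starts, ends)]


print(groups_from_sizes([3, 2, 4]))
# → [[0, 1, 2], [3, 4], [5, 6, 7, 8]]
```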
from_datasets(datasets, group_key, batch_size, shuffle, drop_last)
classmethod
Create sampler by grouping datasets using a key function.
This factory method allows grouping datasets by arbitrary criteria (e.g., resolution, regardless of other parameters like stride). Datasets with the same key are batched together.
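A rough sketch of key-based grouping follows. It assumes that global indices follow the order in which the datasets are concatenated, which this page does not state. `ToyDataset` and `group_indices_by_key` are hypothetical stand-ins, not part of `ocean_emulators`.

```python
from collections import defaultdict
from dataclasses import dataclass


@dataclass
class ToyDataset:
    """Hypothetical stand-in for TorchTrainDataset."""
    resolution: tuple
    stride: int
    length: int

    def __len__(self):
        return self.length


def group_indices_by_key(datasets, group_key):
    """Merge the global index ranges of all datasets that share a key."""
    groups = defaultdict(list)
    offset = 0  # assumed: indices run consecutively across the dataset list
    for ds in datasets:
        groups[group_key(ds)].extend(range(offset, offset + len(ds)))
        offset += len(ds)
    return dict(groups)


datasets = [
    ToyDataset(resolution=(180, 360), stride=1, length=3),
    ToyDataset(resolution=(90, 180), stride=1, length=2),
    ToyDataset(resolution=(180, 360), stride=2, length=2),
]
groups = group_indices_by_key(datasets, group_key=lambda ds: ds.resolution)
print(groups)
# → {(180, 360): [0, 1, 2, 5, 6], (90, 180): [3, 4]}
```

The first and third datasets differ in stride but share a resolution, so their indices land in the same group.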
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| datasets | list[TorchTrainDataset] | List of TorchTrainDataset instances to group | required |
| group_key | Callable[[TorchTrainDataset], Hashable] | Callable that extracts grouping key from a dataset. | required |
| batch_size | int | Number of samples per batch | required |
| shuffle | bool | Whether to shuffle indices within groups and shuffle batches globally | required |
| drop_last | bool | Whether to drop incomplete batches at the end of each group | required |
Examples:
- lambda ds: (ds._input_src.data.sizes['lat'], ds._input_src.data.sizes['lon']) # group by resolution
- lambda ds: ds._input_src.data.sizes['lat'] # group by latitude size only
Returns:
| Type | Description |
|---|---|
| Self | EquivalenceGroupBatchSampler configured to group by the provided key |
Example

    # Group datasets by resolution, allowing different strides to be batched together
    sampler = EquivalenceGroupBatchSampler.from_datasets(
        datasets=dataset_list,
        group_key=lambda ds: tuple(prog.grid_size for prog in ds.prognostic_srcs),
        batch_size=32,
        shuffle=True,
        drop_last=True,
    )
__len__()
Calculate total number of batches across all groups.
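Assuming each group contributes its batches independently, the total can be computed as in the sketch below. The helper `count_batches` is hypothetical and only illustrates the arithmetic implied by `drop_last`.

```python
import math


def count_batches(group_sizes, batch_size, drop_last=False):
    """Hypothetical helper: total batches summed over all groups."""
    if drop_last:
        # Only full batches count; the remainder of each group is discarded.
        return sum(size // batch_size for size in group_sizes)
    # Otherwise each group also yields one final, smaller batch if needed.
    return sum(math.ceil(size / batch_size) for size in group_sizes)


print(count_batches([5, 3], batch_size=2))                  # → 5
print(count_batches([5, 3], batch_size=2, drop_last=True))  # → 3
```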
DistributedEquivalenceGroupBatchSampler(datasets, group_key, batch_size, num_replicas, rank, shuffle=True, drop_last=False, seed=0)
Bases: Sampler[Batch]
Distributed version of EquivalenceGroupBatchSampler for multi-GPU training.
Uses composition to delegate batching logic to EquivalenceGroupBatchSampler, handling only the distribution and epoch-based shuffling.
Ensures uniform batch counts across all ranks to prevent DDP hangs at collective sync points. Each equivalence group is chunked into logical DDP steps of num_replicas batches, so ranks process the same group at the same step. For an incomplete per-group step:
- drop_last=True: drops that incomplete group step
- drop_last=False: pads it by duplicating batches from the same group
Note: Unlike the non-distributed sampler, this one does not shuffle within batches when shuffle=True; it shuffles only between batches.
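The per-group chunking described above can be sketched as follows. This is an illustration under stated assumptions, not the library's code: the function name `rank_batches` is hypothetical, padding by cycling from the start of the group is an assumption, and shuffling whole steps with `seed + epoch` is an assumption modelled on how PyTorch's `DistributedSampler` seeds its generator.

```python
import random


def rank_batches(group_batches, num_replicas, rank,
                 shuffle=True, drop_last=False, seed=0, epoch=0):
    """Hypothetical sketch: return the batches one rank would see."""
    steps = []
    for batches in group_batches:  # batches of a single equivalence group
        for start in range(0, len(batches), num_replicas):
            step = list(batches[start:start + num_replicas])
            if len(step) < num_replicas:
                if drop_last:
                    continue  # drop the incomplete group step
                # Pad by reusing batches from the same group (assumed order).
                i = 0
                while len(step) < num_replicas:
                    step.append(batches[i % len(batches)])
                    i += 1
            steps.append(step)
    if shuffle:
        # Same seed on every rank, so all ranks see the same step order.
        random.Random(seed + epoch).shuffle(steps)
    return [step[rank] for step in steps]


group_batches = [[[0, 1], [2, 3], [4]], [[5, 6]]]
for rank in range(2):
    print(rank, rank_batches(group_batches, num_replicas=2, rank=rank, shuffle=False))
# → 0 [[0, 1], [4], [5, 6]]
# → 1 [[2, 3], [0, 1], [5, 6]]
```

Both ranks receive three batches, and at every step their batches come from the same group, which is what keeps collective operations aligned.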
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| datasets | list[TorchTrainDataset] | List of TorchTrainDataset instances to group | required |
| group_key | Callable[[TorchTrainDataset], Hashable] | Callable that extracts grouping key from a dataset | required |
| batch_size | int | Number of samples per batch | required |
| num_replicas | int | Number of distributed workers (world size) | required |
| rank | int | Index of current worker (0 to num_replicas-1) | required |
| shuffle | bool | Whether to shuffle batches | True |
| drop_last | bool | Whether to drop incomplete batches within each group, and whether to trim (vs pad) when distributing batches across ranks | False |
| seed | int | Random seed for shuffling (default: 0) | 0 |
set_epoch(epoch)
__len__()
Number of batches for this worker (same for all ranks).
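If each group is chunked into steps of num_replicas batches, the per-rank count follows directly from the number of batches in each group. The helper `batches_per_rank` is a hypothetical illustration of that arithmetic, not the library's implementation.

```python
import math


def batches_per_rank(batches_per_group, num_replicas, drop_last=False):
    """Hypothetical helper: steps (one batch per rank per step) over all groups."""
    if drop_last:
        # Incomplete group steps are dropped.
        return sum(n // num_replicas for n in batches_per_group)
    # Incomplete group steps are padded up to a full step.
    return sum(math.ceil(n / num_replicas) for n in batches_per_group)


print(batches_per_rank([3, 1], num_replicas=2))                  # → 3
print(batches_per_rank([3, 1], num_replicas=2, drop_last=True))  # → 1
```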