Skip to content

Samplers

ocean_emulators.utils.samplers

EquivalenceGroupBatchSampler(groups, batch_size, shuffle=True, drop_last=False)

Bases: Sampler[Batch]

Groups indices into equivalence classes, batches within groups, and optionally shuffles.

This sampler partitions dataset indices into groups. It creates batches within each group, then chains them together. When shuffle=True, batches are globally shuffled each epoch to avoid sequential group processing.

Parameters:

Name Type Description Default
groups list[list[int]]

List of index lists, where each inner list contains indices belonging to the same equivalence group.

required
batch_size int

Number of samples per batch

required
shuffle bool

Whether to shuffle indices within groups and shuffle batches globally

True
drop_last bool

Whether to drop incomplete batches at the end of each group

False
Source code in src/ocean_emulators/utils/samplers.py
def __init__(
    self,
    groups: list[list[int]],
    batch_size: int,
    shuffle: bool = True,
    drop_last: bool = False,
):
    """Store the configuration and build one batch sampler per group.

    Args:
        groups: Index lists; each inner list holds the indices of one
            equivalence group.
        batch_size: Number of samples per batch.
        shuffle: Selects the per-group index sampler: ``SubsetRandomSampler``
            when true, ``_SimpleSubsetSampler`` otherwise.
        drop_last: Forwarded to every per-group ``BatchSampler``.
    """
    super().__init__()
    self.groups = groups
    self.batch_size = batch_size
    self.shuffle = shuffle
    self.drop_last = drop_last

    # The shuffle flag decides which index sampler wraps each group.
    if self.shuffle:
        index_sampler_cls = SubsetRandomSampler
    else:
        index_sampler_cls = _SimpleSubsetSampler

    # One BatchSampler per group, so a batch never mixes groups.
    per_group_samplers = []
    for members in self.groups:
        per_group_samplers.append(
            BatchSampler(
                index_sampler_cls(members),
                batch_size=self.batch_size,
                drop_last=self.drop_last,
            )
        )
    self._samplers = per_group_samplers

from_dataset_sizes(dataset_sizes, batch_size, shuffle=True, drop_last=False) classmethod

Create sampler from dataset sizes, treating each as a contiguous group.

Parameters:

Name Type Description Default
dataset_sizes list[int]

List of individual dataset sizes. Groups are created based on cumulative boundaries, where each dataset forms its own equivalence group.

required
batch_size int

Number of samples per batch

required
shuffle bool

Whether to shuffle indices within groups and shuffle batches globally

True
drop_last bool

Whether to drop incomplete batches at the end of each group

False
Source code in src/ocean_emulators/utils/samplers.py
@classmethod
def from_dataset_sizes(
    cls,
    dataset_sizes: list[int],
    batch_size: int,
    shuffle: bool = True,
    drop_last: bool = False,
) -> Self:
    """Create sampler from dataset sizes, treating each as a contiguous group.

    Args:
        dataset_sizes: List of individual dataset sizes. Groups are created based on
            cumulative boundaries, where each dataset forms its own equivalence group.
        batch_size: Number of samples per batch
        shuffle: Whether to shuffle indices within groups and shuffle batches globally
        drop_last: Whether to drop incomplete batches at the end of each group
    """
    cumsum = 0
    groups = []
    for size in dataset_sizes:
        groups.append(list(range(cumsum, cumsum + size)))
        cumsum += size
    return cls(groups, batch_size, shuffle, drop_last)

from_datasets(datasets, group_key, batch_size, shuffle, drop_last) classmethod

Create sampler by grouping datasets using a key function.

This factory method allows grouping datasets by arbitrary criteria (e.g., resolution, regardless of other parameters like stride). Datasets with the same key are batched together.

Parameters:

Name Type Description Default
datasets list[TorchTrainDataset]

List of TorchTrainDataset instances to group

required
group_key Callable[[TorchTrainDataset], Hashable]

Callable that extracts grouping key from a dataset.

required
batch_size int

Number of samples per batch

required
shuffle bool

Whether to shuffle indices within groups and shuffle batches globally

required
drop_last bool

Whether to drop incomplete batches at the end of each group

required

Examples:

  • lambda ds: (ds._input_src.data.sizes['lat'], ds._input_src.data.sizes['lon']) # group by resolution
  • lambda ds: ds._input_src.data.sizes['lat'] # group by latitude size only

Returns:

Type Description
Self

EquivalenceGroupBatchSampler configured to group by the provided key

Example
Group datasets by resolution, allowing different strides to be batched together

sampler = EquivalenceGroupBatchSampler.from_datasets(
    datasets=dataset_list,
    group_key=lambda ds: tuple(prog.grid_size for prog in ds.prognostic_srcs),
    batch_size=32,
    shuffle=True,
    drop_last=True,
)

Source code in src/ocean_emulators/utils/samplers.py
@classmethod
def from_datasets(
    cls,
    datasets: list["TorchTrainDataset"],
    group_key: Callable[["TorchTrainDataset"], Hashable],
    batch_size: int,
    shuffle: bool,
    drop_last: bool,
) -> Self:
    """Create sampler by grouping datasets using a key function.

    This factory method allows grouping datasets by arbitrary criteria (e.g., resolution,
    regardless of other parameters like stride). Datasets with the same key are batched together.

    Args:
        datasets: List of TorchTrainDataset instances to group
        group_key: Callable that extracts grouping key from a dataset.
        batch_size: Number of samples per batch
        shuffle: Whether to shuffle indices within groups and shuffle batches globally
        drop_last: Whether to drop incomplete batches at the end of each group

    Examples:
            - lambda ds: (ds._input_src.data.sizes['lat'], ds._input_src.data.sizes['lon'])  # group by resolution
            - lambda ds: ds._input_src.data.sizes['lat']  # group by latitude size only

    Returns:
        EquivalenceGroupBatchSampler configured to group by the provided key

    Example:
        >>> # Group datasets by resolution, allowing different strides to be batched together
        >>> sampler = EquivalenceGroupBatchSampler.from_datasets(
        ...     datasets=dataset_list,
        ...     group_key=lambda ds: tuple(prog.grid_size for prog in ds.prognostic_srcs),
        ...     batch_size=32,
        ...     shuffle=True,
        ...     drop_last=True,
        ... )
    """
    from collections import defaultdict

    # Group indices by their key
    groups: dict[Hashable, list[int]] = defaultdict(list)

    cumsum = 0
    for ds in datasets:
        key = group_key(ds)
        assert isinstance(key, Hashable), "`group_key` must be hashable."
        groups[key].extend(range(cumsum, cumsum + len(ds)))
        cumsum += len(ds)

    # Sort by key for deterministic ordering across runs
    sorted_groups = sorted(groups.items(), key=lambda x: x[0])  # type: ignore
    group_indices = [indices for _, indices in sorted_groups]

    return cls(group_indices, batch_size, shuffle, drop_last)

__len__()

Calculate total number of batches across all groups.

Source code in src/ocean_emulators/utils/samplers.py
def __len__(self):
    """Return the number of batches summed over every per-group sampler."""
    return sum(len(group_sampler) for group_sampler in self._samplers)

DistributedEquivalenceGroupBatchSampler(datasets, group_key, batch_size, num_replicas, rank, shuffle=True, drop_last=False, seed=0)

Bases: Sampler[Batch]

Distributed version of EquivalenceGroupBatchSampler for multi-GPU training.

Uses composition to delegate batching logic to EquivalenceGroupBatchSampler, handling only the distribution and epoch-based shuffling.

Ensures uniform batch counts across all ranks to prevent DDP hangs at collective sync points. Each equivalence group is chunked into logical DDP steps of num_replicas batches, so that all ranks process the same group at the same step. An incomplete per-group step is handled according to drop_last: with drop_last=True the incomplete step is dropped; with drop_last=False it is padded by duplicating batches from the same group.

Note: Unlike the non-distributed sampler, when shuffle=True this sampler shuffles only between batches; it does not shuffle within batches.

Parameters:

Name Type Description Default
datasets list[TorchTrainDataset]

List of TorchTrainDataset instances to group

required
group_key Callable[[TorchTrainDataset], Hashable]

Callable that extracts grouping key from a dataset

required
batch_size int

Number of samples per batch

required
num_replicas int

Number of distributed workers (world size)

required
rank int

Index of current worker (0 to num_replicas-1)

required
shuffle bool

Whether to shuffle batches

True
drop_last bool

Whether to drop incomplete batches within each group, and whether to trim (vs pad) when distributing batches across ranks

False
seed int

Random seed for shuffling (default: 0)

0
Source code in src/ocean_emulators/utils/samplers.py
def __init__(
    self,
    datasets: list["TorchTrainDataset"],
    group_key: Callable[["TorchTrainDataset"], Hashable],
    batch_size: int,
    num_replicas: int,
    rank: int,
    shuffle: bool = True,
    drop_last: bool = False,
    seed: int = 0,
):
    """Validate the distributed settings and build the inner sampler.

    Args:
        datasets: Datasets to group; forwarded to the inner sampler.
        group_key: Callable that extracts the grouping key from a dataset.
        batch_size: Number of samples per batch.
        num_replicas: World size; must be positive.
        rank: Index of this worker; must lie in ``[0, num_replicas)``.
        shuffle: Stored on the instance; the inner sampler is always
            built with ``shuffle=False``.
        drop_last: Stored on the instance and forwarded to the inner sampler.
        seed: Random seed stored for shuffling.

    Raises:
        ValueError: If ``num_replicas`` is not positive or ``rank`` is
            outside ``[0, num_replicas)``.
    """
    super().__init__()
    if num_replicas <= 0:
        raise ValueError(f"num_replicas must be positive, got {num_replicas}.")
    if rank < 0 or rank >= num_replicas:
        raise ValueError(
            f"Invalid rank {rank}, must be in range [0, {num_replicas})"
        )

    # Shuffling state; the epoch counter starts at zero.
    self.shuffle = shuffle
    self.seed = seed
    self.epoch = 0

    # Distribution settings.
    self.num_replicas = num_replicas
    self.rank = rank
    self.drop_last = drop_last

    # Batching is delegated to the non-distributed sampler. Its own shuffle
    # is disabled because this class handles shuffling with a seeded RNG.
    inner_sampler = EquivalenceGroupBatchSampler.from_datasets(
        datasets=datasets,
        group_key=group_key,
        batch_size=batch_size,
        shuffle=False,
        drop_last=drop_last,
    )
    self._inner = inner_sampler

set_epoch(epoch)

Set the epoch for deterministic shuffling across workers.

Source code in src/ocean_emulators/utils/samplers.py
def set_epoch(self, epoch: int) -> None:
    """Record the current epoch number on the sampler.

    Args:
        epoch: Epoch index to store in ``self.epoch``.
    """
    self.epoch = epoch

__len__()

Number of batches for this worker (same for all ranks).

Source code in src/ocean_emulators/utils/samplers.py
def __len__(self):
    """Return how many batches this rank yields (the same for every rank).

    Each group contributes whole steps of ``num_replicas`` batches: an
    incomplete trailing step is discarded when ``drop_last`` is true and
    counted as one full step otherwise. Every step gives one batch per rank,
    so the per-rank length is the number of steps summed over all groups.
    """
    world_size = self.num_replicas
    batches_per_group = [len(group_sampler) for group_sampler in self._inner._samplers]

    if self.drop_last:
        # Only complete steps are kept.
        return sum(count // world_size for count in batches_per_group)

    # A partial step is rounded up to one full step.
    return sum(math.ceil(count / world_size) for count in batches_per_group)