Skip to content

Train

ocean_emulators.train

Trainer(cfg)

Orchestrates the full model training loop.

Handles initialization, distributed setup, checkpointing, learning rate scheduling, EMA, and Weights & Biases logging.

Source code in src/ocean_emulators/train.py
def __init__(self, cfg: TrainConfig) -> None:
    """Set up everything needed to train a model from ``cfg``.

    In order, this:

    1. Prepares the output directories and saves the config.
    2. Initializes the training backend (device and optional distributed state) and seeds RNGs.
    3. Builds the dataset spec, data container, model, loss, optimizer and optional LR scheduler.
    4. Configures Weights & Biases.
    5. Optionally resumes from, or finetunes, a checkpoint.
    6. Wraps the model in ``DistributedDataParallel`` when running distributed.
    7. Creates the EMA tracker when no checkpoint was loaded.

    Data loaders are not built here. The sampler/loader attributes are only
    declared (annotated) at the end.

    Args:
        cfg: Training configuration. NOTE: ``cfg`` is mutated in place.
            ``cfg.pin_mem`` may be overridden depending on the device/disk mode.
            ``cfg.resume_ckpt_path`` is set when ``cfg.preemptible`` is true and
            a latest checkpoint already exists.

    Raises:
        ValueError: If the ``"mix"`` schedule is combined with residual
            predictions or with multi-step training, or if the training and
            validation time ranges overlap.
        AssertionError: If ``data_stride``/``steps``/``step_transition`` are not
            lists, if ``step_transition`` does not have ``len(steps) - 1``
            entries, or if ``preemptible`` is combined with ``finetune``.
    """
    cfg.prepare_output_dirs()
    # Persist the exact configuration used for this run next to its outputs.
    cfg.save_yaml(cfg.experiment.output_dir / "config.yaml")

    # Backend: resolve the torch device and the optional distributed state.
    self.device, self.distributed = init_train_backend(cfg.backend)

    # Host-memory pinning is disabled when no GPU is in use. With a GPU,
    # disk mode forces it on; otherwise the configured value is kept.
    if not using_gpu():
        cfg.pin_mem = False
    elif cfg.disk_mode:
        cfg.pin_mem = True

    # Force dask's synchronous scheduler.
    # NOTE(review): the motivation (e.g. avoiding dask thread pools inside
    # DDP ranks / DataLoader workers) is not visible here — confirm.
    dask.config.set(scheduler="synchronous")

    # Set seeds
    set_seed(cfg.experiment.rand_seed)

    # Prognostic and boundary variable names come from the dataset spec.
    self.dataset_spec = cfg.data.dataset.build()
    self.prognostic_var_names: PrognosticVarNames = (
        self.dataset_spec.prognostic_var_names
    )
    self.boundary_var_names: BoundaryVarNames = self.dataset_spec.boundary_var_names
    self.levels = self.dataset_spec.num_prognostic_depth_levels

    str_prognostics = ", ".join(self.prognostic_var_names)
    str_boundaries = ", ".join(self.boundary_var_names)

    logger.info(f"Prognostic variables: {str_prognostics}")
    logger.info(f"Boundary variables: {str_boundaries}")
    logger.info(f"Levels: {self.levels}")

    self.N_bound = len(self.boundary_var_names)
    self.N_prog = len(self.prognostic_var_names)

    self.data_container = cfg.data.build(
        data_root=cfg.experiment.resolved_data_root,
    )
    self.train_schedule: TrainSchedule = cfg.experiment.train_schedule
    # The mixed multiscale schedule supports neither residual predictions nor
    # multi-step training.
    if self.train_schedule == "mix" and cfg.model.pred_residuals:
        raise ValueError(
            "Residual predictions on a mixed multiscale training schedule is not currently supported."
        )
    if self.train_schedule == "mix" and any(step > 1 for step in cfg.steps):
        raise ValueError(
            "Step predictions on a mixed multiscale training schedule is not currently supported."
        )

    data_num_workers = cfg.data.loading.num_pytorch_workers()
    persistent_workers = cfg.data.loading.persistent_pytorch_workers()

    # Data-loading worker processes are forked when the data container allows
    # it and spawned otherwise. Without workers no context is needed (None).
    self.mp_context: BaseContext | None = None
    if data_num_workers > 0:
        if self.data_container.supports_fork:
            self.mp_context = multiprocessing.get_context("fork")
        else:
            self.mp_context = multiprocessing.get_context("spawn")

    # Channel counts: (hist + 1) copies of each prognostic/boundary variable.
    # The model outputs as many channels as it takes prognostic inputs.
    self.num_prog_in = int((cfg.data.hist + 1) * self.N_prog)
    self.num_boundary_in = int((cfg.data.hist + 1) * self.N_bound)
    self.num_in = self.num_prog_in + self.num_boundary_in
    self.num_out = self.num_prog_in

    self.tensor_map = TensorMap(dataset_spec=self.dataset_spec).to(self.device)

    logger.info(f"Number of inputs (prognostic + boundary): {self.num_in}")
    logger.info(f"Number of outputs (prognostic): {self.num_out}")

    assert isinstance(cfg.data_stride, list)
    assert isinstance(cfg.steps, list)
    assert isinstance(cfg.step_transition, list)
    # Exactly one transition epoch between each pair of consecutive step counts.
    assert len(cfg.step_transition) == len(cfg.steps) - 1
    max_steps = str(cfg.steps[-1])
    self.str_video = "steps_" + max_steps + "_" + "_Lateral_Data_025_no_smooth"

    # Dataloaders
    logger.info("Loading data")
    if cfg.train_time.overlaps(cfg.val_time):
        raise ValueError(
            f"Training time range {cfg.train_time} overlaps "
            f"with validation time range {cfg.val_time}"
        )

    self.concurrent_compute = cfg.data.concurrent_compute

    self.primary_src = self.data_container.primary_source

    # We use dask for inference since it has memory issues otherwise.
    # TODO(jder): Could rewrite inference dataset like we did for TorchTrainDataset
    # see https://github.com/suryadheeshjith/Ocean_Emulator/issues/208
    self.inference_src = self.data_container.inference_source

    self.loader_version = self.data_container.loader_version

    # This is used by both the aggregator and corrector. It only works at a single scale.
    self.normalize = Normalize(
        self.primary_src,
        prognostic_var_names=self.prognostic_var_names,
        boundary_var_names=self.boundary_var_names,
    )

    self.model = cfg.model.build(
        prog_channels=self.num_prog_in,
        boundary_channels=self.num_boundary_in,
        out_channels=self.num_out,
        hist=cfg.data.hist,
        # TODO(559): This won't work at multiple scales. Refactor as part of src.
        static_data_for_corrector=self.data_container.static_data,
        srcs=self.data_container.sources,
        tensor_map=self.tensor_map,
        normalize=self.normalize,
        dataset_spec=self.dataset_spec,
    ).to(self.device)

    self.nets_dir = cfg.experiment.nets_dir
    self.network = self.model.__class__.__name__

    # Loss function
    self.loss_fn: LossFnWithContext = build_loss_fn(
        cfg.loss,
        device=self.device,
        num_channels=self.N_prog,
        pad_mode=cfg.model.pad,
    )

    # Optimizer
    self.optimizer = torch.optim.Adam(self.model.parameters(), lr=cfg.learning_rate)

    # Scheduler (optional)
    self.scheduler = None
    if cfg.scheduler:
        self.scheduler = cfg.scheduler.build(self.optimizer, cfg.epochs)

    # Initialize WandB; it is enabled only in online mode on the main process.
    self.wandb_logger = WandBLogger.init_instance()
    self.wandb_logger.configure(
        cfg.experiment.wandb.mode == "online", is_main_process()
    )

    self.ckpt_paths = CheckpointPaths(self.nets_dir)

    # Preemption: if a "latest" checkpoint already exists, treat this run as a
    # restart and resume from it.
    if cfg.preemptible:
        assert not cfg.finetune, "Finetune is not supported with preemptible"
        preempted = os.path.isfile(self.ckpt_paths.latest_checkpoint_path)
        if preempted:
            cfg.resume_ckpt_path = str(self.ckpt_paths.latest_checkpoint_path)

    # Set up wandb run
    self.wandb_id, self.wandb_name = self.wandb_logger.setup_run(
        cfg.resume_ckpt_path,
        cfg,
        data_container=self.data_container,
        finetune=cfg.finetune,
    )

    # Log effective batch size
    effective_batch_size = cfg.batch_size * cfg.gradient_accumulation_steps
    logger.info(
        f"Effective batch size: {effective_batch_size} "
        f"(batch_size={cfg.batch_size} × "
        f"gradient_accumulation_steps={cfg.gradient_accumulation_steps})"
    )
    if self.is_wandb_enabled():
        self.wandb_logger.log(
            {
                "config/effective_batch_size": effective_batch_size,
            },
            step=0,
        )

    self.num_batches_seen = 0
    loaded_checkpoint = False
    if cfg.resume_ckpt_path is not None:
        if cfg.finetune:
            # Finetuning reuses the weights but restarts the epoch count.
            self.load_checkpoint(cfg.resume_ckpt_path, finetune=True)
            self.start_epoch = 1
        else:
            # NOTE(review): start_epoch is not assigned in this branch;
            # presumably load_checkpoint restores it — confirm.
            self.load_checkpoint(cfg.resume_ckpt_path)
            # NOTE(review): only the *current* wandb state is checked here;
            # whether the checkpoint really had wandb enabled is not verified
            # in this method.
            if not self.wandb_logger.enabled and is_main_process():
                warnings.warn(
                    "This checkpoint had wandb enabled, "
                    "but wandb is not enabled now!"
                )
        loaded_checkpoint = True
    else:
        self.start_epoch = 1

    # Wrap in DDP (with synchronized batch norm) when running distributed.
    if self.distributed is not None:
        self.model = nn.parallel.DistributedDataParallel(
            nn.SyncBatchNorm.convert_sync_batchnorm(self.model),
            device_ids=[self.distributed.gpu],
        )

    # EMA (must come after DDP setup so parameter names match final self.model)
    # NOTE(review): when a checkpoint was loaded, self._ema is not created here;
    # presumably load_checkpoint restores it, but that call happens before the
    # DDP wrap above — confirm parameter names still match.
    if not loaded_checkpoint:
        self._ema = EMATracker(
            self.model,
            decay=cfg.ema_decay,
            faster_decay_at_start=cfg.faster_decay_at_start,
        )

    # Training
    self.epochs = cfg.epochs
    self.test_using_ema = cfg.test_using_ema
    self.hist: int = cfg.data.hist
    self.steps = cfg.steps
    self.step_transition = cfg.step_transition
    self.save_freq = cfg.save_freq
    self.validation_image_log_freq = cfg.validation_image_log_freq
    self.output_dir = cfg.experiment.output_dir
    self.debug = cfg.debug
    self.data_stride: list[int] = cfg.data_stride
    self.batch_size: int = cfg.batch_size
    self.gradient_accumulation_steps: int = cfg.gradient_accumulation_steps
    self.num_workers: int = data_num_workers
    self.persistent_workers: bool = persistent_workers
    self.pin_mem: bool = cfg.pin_mem
    self.train_time: config.TimeConfig = cfg.train_time
    self.val_time = cfg.val_time
    self.inference_times = cfg.inference_times
    self.inference_epochs = cfg.inference_epochs
    # The global forward-step budget is shared across the (hist + 1) input frames.
    self.max_train_model_steps_forward = MAX_TRAIN_MODEL_STEPS_FORWARD // (
        self.hist + 1
    )
    self.normalize_before_mask: bool = cfg.data.normalize_before_mask
    self.normalize_fill_value: float = cfg.data.masked_fill_value
    self.delayed_loss_estimate: bool = cfg.delayed_loss_estimate

    self.profiler = cfg.profiler.build(self.output_dir, self.device)
    # Every rank must agree on whether validation images are produced, so the
    # main process's wandb state is broadcast.
    self.validation_images_enabled = self._sync_flag_from_main(
        self.wandb_logger.enabled
    )

    assert self.tensor_map is not None

    if self.inference_epochs:
        self.init_inference_stores()

    # Add type annotations for samplers
    self.train_sampler: (
        EquivalenceGroupBatchSampler | DistributedEquivalenceGroupBatchSampler
    )
    self.val_sampler: (
        EquivalenceGroupBatchSampler | DistributedEquivalenceGroupBatchSampler
    )
    self.inference_sampler: DistributedSampler | RandomSampler

    # Add type annotations for loaders
    self.train_loader: TrainDataLoader
    self.val_loader: TrainDataLoader
    self.inference_loader: DataLoader[TrainData]

get_current_step(epoch)

Determine the current step based on the epoch and transition points.

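Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `epoch` | `int` | Current epoch number | *required* |

Returns:

| Type | Description |
| --- | --- |
| `int` | The current step (`current_step`). The step index is not returned. |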

Source code in src/ocean_emulators/train.py
def get_current_step(self, epoch: int) -> int:
    """Return the step to train with at ``epoch``.

    How the step is chosen depends on the epoch:

    * On the first epoch of this run (``epoch == self.start_epoch``), the step
      is ``self.steps[i]``, where ``i`` is the index of the first entry of
      ``self.step_transition`` that is ``>= epoch``. If ``epoch`` is past every
      transition, the last entry of ``self.steps`` is used.
    * On any other call, ``epoch`` must be one of ``self.step_transition``,
      and the step advances to the entry after that transition.

    Args:
        epoch: Current epoch number.

    Returns:
        int: The current step, an element of ``self.steps``. Only the step is
        returned, not its index.

    Raises:
        ValueError: If ``epoch`` is neither the start epoch nor a step
            transition epoch.
    """
    if epoch == self.start_epoch:
        # NOTE(review): an epoch equal to a transition epoch maps to the step
        # *before* that transition here (``<=``). The branch below maps the
        # same epoch to the step *after* it. Whether both agree depends on
        # when the caller invokes this method (start vs. end of an epoch) —
        # confirm.
        cur_step_idx = next(
            (
                i
                for i, epoch_to_transition in enumerate(self.step_transition)
                if epoch <= epoch_to_transition
            ),
            len(self.steps) - 1,
        )
        cur_step = self.steps[cur_step_idx]
        logger.info(f"Starting training at step {cur_step}")
    else:
        # Transition to next step. Previously a non-transition epoch leaked a
        # bare StopIteration from next(); report it explicitly instead.
        try:
            transition_idx = list(self.step_transition).index(epoch)
        except ValueError:
            raise ValueError(
                f"Epoch {epoch} is neither the start epoch "
                f"({self.start_epoch}) nor a step transition epoch "
                f"({self.step_transition})"
            ) from None
        cur_step = self.steps[transition_idx + 1]
        logger.info(f"Transitioning to step {cur_step}")

    return cur_step

init_data_loaders(cur_step)

Initialize training and validation data loaders.

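Parameters:

| Name | Type | Description | Default |
| --- | --- | --- | --- |
| `cur_step` | `int` | Current number of training steps (passed as `steps` to each training dataset) | *required* |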
Source code in src/ocean_emulators/train.py
def init_data_loaders(self, cur_step: int) -> None:
    """Initialize training and validation data loaders.

    Args:
        cur_step: Current training step size
    """
    scales = self.data_container.sources
    match self.train_schedule:
        case "standard":
            srcs: Iterable[tuple[DataSource, DataSource | None]] = [
                (scales[0], None)
            ]
        case "match":
            srcs = [(s, s) for s in scales]
        case "mix":
            srcs = list(itertools.product(scales, repeat=2))  # type: ignore
        case _:
            assert_never(self.train_schedule)

    train_datasets = [
        TorchTrainDataset(
            src=src.slice(self.train_time),
            dst=dst.slice(self.train_time) if dst else None,
            prognostic_var_names=self.prognostic_var_names,
            boundary_var_names=self.boundary_var_names,
            hist=self.hist,
            steps=cur_step,
            normalize_before_mask=self.normalize_before_mask,
            masked_fill_value=self.normalize_fill_value,
            stride=stride,
            concurrent_compute_=self.concurrent_compute,
        )
        for stride in self.data_stride
        for src, dst in srcs
    ]

    val_datasets = [
        TorchTrainDataset(
            src=src.slice(self.val_time),
            dst=dst.slice(self.val_time) if dst else None,
            prognostic_var_names=self.prognostic_var_names,
            boundary_var_names=self.boundary_var_names,
            hist=self.hist,
            steps=1,  # current_step set to 1 for validation
            normalize_before_mask=self.normalize_before_mask,
            masked_fill_value=self.normalize_fill_value,
            stride=stride,
            concurrent_compute_=self.concurrent_compute,
        )
        for stride in self.data_stride
        for src, dst in srcs
    ]

    # Create datasets
    match self.loader_version:
        case TorchTrainDataset.FLAG:
            train_data: torch.utils.data.Dataset[RawTrainData] = ConcatDataset(
                train_datasets
            )

            val_data: torch.utils.data.Dataset[RawTrainData] = ConcatDataset(
                val_datasets
            )

        case _:
            raise NotImplementedError(
                f"Loader version {self.loader_version} not supported."
            )

    logger.info("Instantiating torch loaders")

    match self.loader_version:
        case TorchTrainDataset.FLAG:
            collate_fn = collate_raw_train_data
        case _:
            raise NotImplementedError(
                f"Collate function not defined for loader version "
                f"{self.loader_version}"
            )

    # Create batch samplers - branch on distributed vs non-distributed
    # Group by input AND label resolution to handle all training schedules
    def group_key(ds):
        return tuple(prog.grid_size for prog in ds.prognostic_srcs)

    if self.distributed is not None:
        # Distributed training
        assert self.distributed.world_size is not None
        assert self.distributed.rank is not None
        train_batch_sampler = DistributedEquivalenceGroupBatchSampler(
            datasets=train_datasets,
            group_key=group_key,
            batch_size=self.batch_size,
            num_replicas=self.distributed.world_size,
            rank=self.distributed.rank,
            shuffle=True,
            drop_last=True,
        )

        val_batch_sampler = DistributedEquivalenceGroupBatchSampler(
            datasets=val_datasets,
            group_key=group_key,
            batch_size=self.batch_size,
            num_replicas=self.distributed.world_size,
            rank=self.distributed.rank,
            shuffle=False,
            drop_last=False,
        )
    else:
        # Non-distributed training
        train_batch_sampler = EquivalenceGroupBatchSampler.from_datasets(  # type: ignore
            datasets=train_datasets,
            group_key=group_key,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=True,
        )

        val_batch_sampler = EquivalenceGroupBatchSampler.from_datasets(  # type: ignore
            datasets=val_datasets,
            group_key=group_key,
            batch_size=self.batch_size,
            shuffle=True,
            drop_last=False,
        )

    # Store samplers for set_epoch calls
    self.train_sampler = train_batch_sampler
    self.val_sampler = val_batch_sampler

    # Create data loaders (same for both distributed and non-distributed)
    # When using batch_sampler, don't specify batch_size or sampler
    train_dataloader = DataLoader(
        train_data,
        batch_sampler=train_batch_sampler,
        num_workers=self.num_workers,
        persistent_workers=self.persistent_workers and self.num_workers > 0,
        pin_memory=self.pin_mem,
        collate_fn=collate_fn,
        multiprocessing_context=self.mp_context,
    )

    val_dataloader = DataLoader(
        val_data,
        batch_sampler=val_batch_sampler,
        num_workers=self.num_workers,
        persistent_workers=self.persistent_workers and self.num_workers > 0,
        pin_memory=self.pin_mem,
        collate_fn=collate_fn,
        multiprocessing_context=self.mp_context,
    )

    # Wrap dataloaders to handle GPU post-processing
    self.train_loader = TrainDataLoader(
        train_dataloader, train_datasets, self.device
    )
    self.val_loader = TrainDataLoader(val_dataloader, val_datasets, self.device)

should_log_validation_images(epoch, frequency)

Return whether to log validation images for a 1-based training epoch.

Source code in src/ocean_emulators/train.py
def should_log_validation_images(epoch: int, frequency: int) -> bool:
    """Decide whether validation images are logged at a given training epoch.

    Epochs are 1-based. Images are logged on epoch 1 and then every
    ``frequency`` epochs after it (1, 1 + frequency, 1 + 2 * frequency, ...).

    Args:
        epoch: The 1-based training epoch.
        frequency: Number of epochs between image logs; must be positive.

    Returns:
        bool: ``True`` if images should be logged at ``epoch``.

    Raises:
        ValueError: If ``epoch`` is below 1 or ``frequency`` is below 1.
            ``epoch`` is validated first.
    """
    if epoch < 1:
        raise ValueError(f"Epoch must be >= 1, got {epoch}")
    if frequency < 1:
        raise ValueError(
            f"Validation image log frequency must be >= 1, got {frequency}"
        )
    epochs_since_first = epoch - 1
    return not epochs_since_first % frequency