Skip to content

Stepper

ocean_emulators.stepper

Time-stepping primitives for training, validation, and inference.

Provides module-level functions that handle single-step forward passes (train_batch, validate_batch) and multi-step autoregressive rollouts (run_rollout).

run_rollout(model, dataset, inf_aggregator, epoch, output_dir=None, model_path=None, num_model_steps_forward=200, save_zarr=False, tensor_map=None, normalize=None)

Performs inference as an autoregressive rollout, optionally writing predictions to a Zarr store.

Source code in src/ocean_emulators/stepper.py
def _split_rollout_steps(
    num_model_steps: int, num_model_steps_forward: int
) -> list[int]:
    """Split a rollout of ``num_model_steps`` steps into consecutive windows.

    Each window holds ``num_model_steps_forward`` steps, and a final shorter
    window holds any remainder. A single full-length window is returned in two
    cases:

    - ``num_model_steps_forward`` is non-positive (e.g. ``-1``, meaning
      "full forward pass").
    - ``num_model_steps_forward`` is at least ``num_model_steps``.

    An empty rollout yields no windows.
    """
    if num_model_steps <= 0:
        return []
    # Non-positive window sizes (including 0, which previously raised
    # ZeroDivisionError) mean "do the whole rollout in one pass".
    if num_model_steps_forward <= 0 or num_model_steps_forward >= num_model_steps:
        return [num_model_steps]
    num_full_windows, remainder = divmod(num_model_steps, num_model_steps_forward)
    windows = [num_model_steps_forward] * num_full_windows
    if remainder:
        windows.append(remainder)
    return windows


@torch.no_grad()
def run_rollout(
    model: BaseModel,
    dataset: InferenceDataset,
    inf_aggregator: InferenceEvaluatorAggregator,
    epoch: int,
    output_dir: str | PathLike | None = None,
    model_path: str | PathLike | None = None,
    num_model_steps_forward: int = 200,
    save_zarr: bool = False,
    tensor_map: TensorMap | None = None,
    normalize: Normalize | None = None,
) -> None:
    """Performs inference, which is an autoregressive rollout.

    The rollout covers ``len(dataset)`` model steps. It is executed in
    windows of ``num_model_steps_forward`` steps, and the last prediction of
    each window seeds the next one. After every window, the model output is
    recorded by ``inf_aggregator`` and logged to wandb. When ``save_zarr`` is
    set, the output is also written to Zarr.

    Args:
        model: Model whose ``inference`` method performs each window.
        dataset: Inference dataset providing ``initial_prognostic`` and the
            number of steps to roll out (``len(dataset)``).
        inf_aggregator: Aggregator that records the initial prognostic and
            each window's output, returning logs to send to wandb.
        epoch: Epoch number, used for logging and passed to the model.
        output_dir: Zarr output directory. Required when ``save_zarr``.
        model_path: Model path recorded by the Zarr writer. Required when
            ``save_zarr``.
        num_model_steps_forward: Steps per window. A non-positive value
            (e.g. ``-1``) performs the whole rollout in a single window. When
            positive, it is also the Zarr time-chunk size; otherwise the
            chunk size is 20.
        save_zarr: Whether to write predictions to Zarr.
        tensor_map: Tensor map for the Zarr writer. Required when
            ``save_zarr``.
        normalize: Normalizer for the Zarr writer. Required when
            ``save_zarr``.

    Raises:
        ValueError: If ``save_zarr`` is True and any of ``output_dir``,
            ``model_path``, ``tensor_map`` or ``normalize`` is missing.
    """
    writer = None
    if save_zarr:
        if output_dir is None or model_path is None:
            raise ValueError(
                "output_dir and model_path must be provided if save_zarr is True"
            )
        if tensor_map is None or normalize is None:
            raise ValueError(
                "tensor_map and normalize must be provided if save_zarr is True"
            )
        chunk_size = num_model_steps_forward if num_model_steps_forward > 0 else 20
        writer = ZarrWriter(
            output_dir,
            coords=dataset.get_coords_dict(),
            hist=inf_aggregator.hist,
            model_path=model_path,
            time_chunk_size=chunk_size,
            normalize=normalize,
            tensor_map=tensor_map,
        )

    device = get_device()
    record_logs = get_record_to_wandb(label="inference")
    logger.info(f"Inference [epoch {epoch}]: processing initial prognostic.")
    logs = inf_aggregator.record_initial_prognostic(
        initial_prognostic=dataset.initial_prognostic.to(device),
    )
    record_logs(logs)

    num_steps_list = _split_rollout_steps(len(dataset), num_model_steps_forward)
    if not num_steps_list:
        # An empty dataset would otherwise call inference with 0 steps and
        # then index the (empty) prediction.
        logger.warning(
            f"Inference [epoch {epoch}]: dataset has no steps; skipping rollout."
        )
        return

    num_loops = len(num_steps_list)
    # Put the seed on the same device as the later seeds, which are model
    # outputs. This matches how the initial prognostic is recorded above.
    initial_prognostic = dataset.initial_prognostic.to(device)
    step = 0
    for loop, num_steps in enumerate(num_steps_list):
        logger.info(
            f"Inference [epoch {epoch}]: loop {loop} of {num_loops - 1}. "
            f"Stepping {num_steps} steps forward."
        )
        dataset.to(device)
        output: ModelInferenceOutput = model.inference(
            dataset,
            initial_prognostic=initial_prognostic,
            steps_completed=step,
            num_steps=num_steps,
            epoch=epoch,
        )
        # The last predicted state seeds the next window.
        initial_prognostic = output.prediction[-1].unsqueeze(0).clone()
        if writer is not None:
            logger.info("Writing to zarr...")
            writer.record_batch(output)
            writer.write()

        logger.info("Recording logs...")
        logs = inf_aggregator.record_batch(output)
        logger.info("Logging to wandb...")
        record_logs(logs)
        step += num_steps