Time-stepping primitives for training, validation, and inference.
Provides module-level functions that handle single-step forward passes
(train_batch, validate_batch) and multi-step autoregressive
rollouts (run_rollout).
run_rollout(model, dataset, inf_aggregator, epoch, output_dir=None, model_path=None, num_model_steps_forward=200, save_zarr=False, tensor_map=None, normalize=None)
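Performs inference as an autoregressive rollout: the model is stepped forward in windows of at most num_model_steps_forward steps, and each window starts from the last prediction of the previous one.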
Source code in src/ocean_emulators/stepper.py
| @torch.no_grad()
def run_rollout(
    model: BaseModel,
    dataset: InferenceDataset,
    inf_aggregator: InferenceEvaluatorAggregator,
    epoch: int,
    output_dir: str | PathLike | None = None,
    model_path: str | PathLike | None = None,
    num_model_steps_forward: int = 200,
    save_zarr: bool = False,
    tensor_map: TensorMap | None = None,
    normalize: Normalize | None = None,
) -> None:
    """Perform inference as an auto-regressive rollout over ``dataset``.

    The rollout covers ``len(dataset)`` model steps, split into consecutive
    windows. Each window is produced by one ``model.inference`` call, and the
    last prediction of a window becomes the initial prognostic state of the
    next one. After every window the output is recorded by ``inf_aggregator``
    and the resulting logs are sent to the wandb recorder; when ``save_zarr``
    is set the output is also written through a ``ZarrWriter``.

    Args:
        model: Model whose ``inference`` method produces each window.
        dataset: Inference dataset; supplies the initial prognostic state,
            the total number of steps (``len(dataset)``) and, when writing
            zarr, the coordinates (``get_coords_dict()``).
        inf_aggregator: Aggregator that records the initial prognostic state
            and every window's output.
        epoch: Epoch number, used in log messages and forwarded to
            ``model.inference``.
        output_dir: Passed to ``ZarrWriter``. Required if ``save_zarr``.
        model_path: Passed to ``ZarrWriter``. Required if ``save_zarr``.
        num_model_steps_forward: Maximum number of steps per window. Any
            non-positive value (conventionally ``-1``) runs the whole rollout
            as a single window.
        save_zarr: Whether to write each window's output with ``ZarrWriter``.
        tensor_map: Passed to ``ZarrWriter``. Required if ``save_zarr``.
        normalize: Passed to ``ZarrWriter``. Required if ``save_zarr``.

    Raises:
        ValueError: If ``save_zarr`` is True and any of ``output_dir``,
            ``model_path``, ``tensor_map`` or ``normalize`` is None.
    """
    writer = None
    if save_zarr:
        if output_dir is None or model_path is None:
            raise ValueError(
                "output_dir and model_path must be provided if save_zarr is True"
            )
        if tensor_map is None or normalize is None:
            raise ValueError(
                "tensor_map and normalize must be provided if save_zarr is True"
            )
        # With no fixed window length (non-positive value), fall back to a
        # time chunk size of 20.
        chunk_size = num_model_steps_forward if num_model_steps_forward > 0 else 20
        writer = ZarrWriter(
            output_dir,
            coords=dataset.get_coords_dict(),
            hist=inf_aggregator.hist,
            model_path=model_path,
            time_chunk_size=chunk_size,
            normalize=normalize,
            tensor_map=tensor_map,
        )

    record_logs = get_record_to_wandb(label="inference")

    logger.info(f"Inference [epoch {epoch}]: processing initial prognostic.")
    logs = inf_aggregator.record_initial_prognostic(
        initial_prognostic=dataset.initial_prognostic.to(get_device()),
    )
    record_logs(logs)

    # Plan the rollout windows.
    num_model_steps = len(dataset)
    if num_model_steps_forward <= 0 or num_model_steps < num_model_steps_forward:
        # Full forward pass in one window. Treating every non-positive value
        # this way (not only -1) avoids a ZeroDivisionError for 0 and matches
        # the chunk-size fallback above.
        num_steps_list = [num_model_steps]
    else:
        # Full-size windows followed by one shorter window for the remainder.
        full_windows, remainder = divmod(num_model_steps, num_model_steps_forward)
        num_steps_list = [num_model_steps_forward] * full_windows
        if remainder > 0:
            num_steps_list.append(remainder)
    num_loops = len(num_steps_list)

    # NOTE(review): the state handed to the aggregator above is moved to the
    # device, while this one is passed as-is for the first window; presumably
    # model.inference handles placement -- confirm.
    initial_prognostic = dataset.initial_prognostic
    step = 0
    for loop, num_steps in enumerate(num_steps_list):
        logger.info(
            f"Inference [epoch {epoch}]: loop {loop} of {num_loops - 1}. "
            f"Stepping {num_steps} steps forward."
        )
        dataset.to(get_device())
        inference_output: ModelInferenceOutput = model.inference(
            dataset,
            initial_prognostic=initial_prognostic,
            steps_completed=step,
            num_steps=num_steps,
            epoch=epoch,
        )
        # Seed the next window with the last prediction of this one; the
        # unsqueeze restores a leading dimension and the clone detaches the
        # seed from the output's storage.
        initial_prognostic = inference_output.prediction[-1].unsqueeze(0).clone()

        # Explicit None check: do not rely on the writer object's truthiness.
        if writer is not None:
            logger.info("Writing to zarr...")
            writer.record_batch(inference_output)
            writer.write()

        logger.info("Recording logs...")
        logs = inf_aggregator.record_batch(inference_output)
        logger.info("Logging to wandb...")
        record_logs(logs)
        step += num_steps
|