Skip to content

Eval

ocean_emulators.eval

Eval(cfg)

Evaluation pipeline for ocean emulator models.

Runs long autoregressive rollouts and computes metrics against ground-truth ocean states.

Source code in src/ocean_emulators/eval.py
def __init__(self, cfg: EvalConfig) -> None:
    """Set up everything needed to evaluate a trained checkpoint.

    Prepares output directories, initializes the evaluation backend, builds
    the dataset spec, data container, normalization and model, loads the
    checkpoint at ``cfg.ckpt_path``, configures the WandB run, and finally
    initializes the inference store.

    Args:
        cfg: Evaluation configuration.

    Raises:
        ValueError: If ``cfg.ckpt_path`` is ``None``. This is checked right
            after the output directories are prepared, before the backend,
            data and model are set up. A missing checkpoint therefore fails
            fast instead of after expensive data loading.
    """
    cfg.prepare_output_dirs()

    # Evaluation cannot proceed without a checkpoint, so validate it before
    # initializing the backend, loading data, or building the model.
    if cfg.ckpt_path is None:
        raise ValueError(
            "ckpt_path must be set; try --ckpt_path=path/to/checkpoint"
        )

    self.device = init_eval_backend(cfg.backend)

    # Choose the number of DataLoader workers based on the device.
    data_num_workers = cfg.data.loading.num_pytorch_workers()
    if not using_gpu():
        data_num_workers = 0  # Disable multi-processing on CPU
    elif cfg.disk_mode:
        # In disk mode, scale the configured worker count by the number of
        # visible CUDA devices.
        data_num_workers = torch.cuda.device_count() * data_num_workers

    # Set seeds
    set_seed(cfg.experiment.rand_seed)

    # Getting prognostic and boundary variables
    self.dataset_spec = cfg.data.dataset.build()
    self.prognostic_var_names: PrognosticVarNames = (
        self.dataset_spec.prognostic_var_names
    )
    self.boundary_var_names: BoundaryVarNames = self.dataset_spec.boundary_var_names
    self.levels = self.dataset_spec.num_prognostic_depth_levels

    str_prognostics = ", ".join(self.prognostic_var_names)
    str_boundaries = ", ".join(self.boundary_var_names)

    logger.info(f"Prognostic variables: {str_prognostics}")
    logger.info(f"Boundary variables: {str_boundaries}")
    logger.info(f"Levels: {self.levels}")

    self.N_bound = len(self.boundary_var_names)
    self.N_prog = len(self.prognostic_var_names)

    # Channel counts: each variable contributes one channel per step in the
    # history window (current step + `hist` previous steps).
    self.num_prog_in = int((cfg.data.hist + 1) * self.N_prog)
    self.num_boundary_in = int((cfg.data.hist + 1) * self.N_bound)
    self.num_in = self.num_prog_in + self.num_boundary_in
    # The model predicts prognostic channels only, with the same count as
    # its prognostic inputs.
    self.num_out = self.num_prog_in

    self.tensor_map = TensorMap(dataset_spec=self.dataset_spec).to(self.device)

    logger.info(f"Number of inputs (prognostic + boundary): {self.num_in}")
    logger.info(f"Number of outputs (prognostic): {self.num_out}")

    # Dataloaders
    logger.info("Loading data")
    self.data_container = cfg.data.build(
        cfg.experiment.resolved_data_root,
    )

    self.src = self.data_container.inference_source
    self.data = self.src.data
    self.static_data = self.data_container.static_data
    self.metadata = self.src.metadata
    self.wet = self.src.masks.prognostic_with_hist(cfg.data.hist)
    self.area_weights: Grid = spherical_area_weights(self.data)
    self.area_weights = self.area_weights.to(self.device)

    self.normalize = Normalize(
        self.src,
        prognostic_var_names=self.prognostic_var_names,
        boundary_var_names=self.boundary_var_names,
    )

    # Model
    self.model = cfg.model.build(
        prog_channels=self.num_prog_in,
        boundary_channels=self.num_boundary_in,
        out_channels=self.num_out,
        hist=cfg.data.hist,
        static_data_for_corrector=self.static_data,
        srcs=self.data_container.sources,
        tensor_map=self.tensor_map,
        normalize=self.normalize,
        dataset_spec=self.dataset_spec,
    ).to(self.device)

    get_model_summary(self.model, None, cfg.debug)

    self.load_checkpoint(cfg.ckpt_path)

    self.network = self.model.__class__.__name__

    # Initialize WandB
    self.wandb_logger = WandBLogger.init_instance()
    self.wandb_logger.configure(
        cfg.experiment.wandb.mode == "online", is_main_process()
    )

    # Set up wandb run
    self.wandb_id, self.wandb_name = self.wandb_logger.setup_run(
        None, cfg, data_container=self.data_container, finetune=False
    )

    # Eval
    self.hist = cfg.data.hist
    self.output_dir = cfg.experiment.output_dir
    self.debug = cfg.debug
    self.num_workers = data_num_workers
    self.inference_time = cfg.inference_time
    self.num_model_steps_forward = cfg.num_model_steps_forward
    self.save_zarr = cfg.save_zarr
    self.model_path = cfg.ckpt_path
    self.normalize_before_mask = cfg.data.normalize_before_mask
    self.masked_fill_value = cfg.data.masked_fill_value
    self.init_inference_store()