def __init__(self, cfg: EvalConfig) -> None:
    """Set up everything needed to run evaluation/inference from ``cfg``.

    In order, this:
    - prepares output directories and validates ``cfg.ckpt_path``;
    - initializes the eval backend device and chooses the data-loader worker count;
    - seeds RNGs;
    - builds the dataset spec and data container;
    - constructs the model on the device and loads the checkpoint;
    - configures the WandB run and copies eval settings onto ``self``;
    - finally calls ``self.init_inference_store()``.

    Args:
        cfg: Evaluation configuration.

    Raises:
        ValueError: If ``cfg.ckpt_path`` is ``None``. This is checked before
            any data loading or model construction, so a misconfigured run
            fails immediately instead of after the expensive setup.
    """
    cfg.prepare_output_dirs()
    # Fail fast: evaluation cannot proceed without a checkpoint, so validate
    # it before loading data and building the model. Kept after
    # prepare_output_dirs() in case that call resolves fields on cfg.
    if cfg.ckpt_path is None:
        raise ValueError(
            "ckpt_path must be set; try --ckpt_path=path/to/checkpoint"
        )
    self.device = init_eval_backend(cfg.backend)

    # Adjust data-loader workers based on device
    data_num_workers = cfg.data.loading.num_pytorch_workers()
    if not using_gpu():
        data_num_workers = 0  # Disable multi-processing on CPU
    elif cfg.disk_mode:
        # NOTE(review): scales the configured worker count by the number of
        # visible CUDA devices; presumably one worker pool per GPU for
        # disk-backed loading — confirm this is intended.
        data_num_workers = torch.cuda.device_count() * data_num_workers

    # Set seeds
    set_seed(cfg.experiment.rand_seed)

    # Prognostic and boundary variables come from the dataset spec
    self.dataset_spec = cfg.data.dataset.build()
    self.prognostic_var_names: PrognosticVarNames = (
        self.dataset_spec.prognostic_var_names
    )
    self.boundary_var_names: BoundaryVarNames = self.dataset_spec.boundary_var_names
    self.levels = self.dataset_spec.num_prognostic_depth_levels
    str_prognostics = ", ".join(self.prognostic_var_names)
    str_boundaries = ", ".join(self.boundary_var_names)
    logger.info(f"Prognostic variables: {str_prognostics}")
    logger.info(f"Boundary variables: {str_boundaries}")
    logger.info(f"Levels: {self.levels}")

    # Channel counts: each variable is repeated (hist + 1) times, presumably
    # the current step plus `hist` history steps.
    self.N_bound = len(self.boundary_var_names)
    self.N_prog = len(self.prognostic_var_names)
    self.num_prog_in = int((cfg.data.hist + 1) * self.N_prog)
    self.num_boundary_in = int((cfg.data.hist + 1) * self.N_bound)
    self.num_in = self.num_prog_in + self.num_boundary_in
    # The model outputs the same number of channels as its prognostic input.
    self.num_out = self.num_prog_in
    self.tensor_map = TensorMap(dataset_spec=self.dataset_spec).to(self.device)
    logger.info(f"Number of inputs (prognostic + boundary): {self.num_in}")
    logger.info(f"Number of outputs (prognostic): {self.num_out}")

    # Dataloaders
    logger.info("Loading data")
    self.data_container = cfg.data.build(
        cfg.experiment.resolved_data_root,
    )
    self.src = self.data_container.inference_source
    self.data = self.src.data
    self.static_data = self.data_container.static_data
    self.metadata = self.src.metadata
    self.wet = self.src.masks.prognostic_with_hist(cfg.data.hist)
    self.area_weights: Grid = spherical_area_weights(self.data).to(self.device)
    self.normalize = Normalize(
        self.src,
        prognostic_var_names=self.prognostic_var_names,
        boundary_var_names=self.boundary_var_names,
    )

    # Model
    self.model = cfg.model.build(
        prog_channels=self.num_prog_in,
        boundary_channels=self.num_boundary_in,
        out_channels=self.num_out,
        hist=cfg.data.hist,
        static_data_for_corrector=self.static_data,
        srcs=self.data_container.sources,
        tensor_map=self.tensor_map,
        normalize=self.normalize,
        dataset_spec=self.dataset_spec,
    ).to(self.device)
    get_model_summary(self.model, None, cfg.debug)
    self.load_checkpoint(cfg.ckpt_path)
    self.network = self.model.__class__.__name__

    # Initialize WandB: enabled only for "online" mode, told whether this is
    # the main process.
    self.wandb_logger = WandBLogger.init_instance()
    self.wandb_logger.configure(
        cfg.experiment.wandb.mode == "online", is_main_process()
    )
    # Set up wandb run
    self.wandb_id, self.wandb_name = self.wandb_logger.setup_run(
        None, cfg, data_container=self.data_container, finetune=False
    )

    # Eval settings copied from the config
    self.hist = cfg.data.hist
    self.output_dir = cfg.experiment.output_dir
    self.debug = cfg.debug
    self.num_workers = data_num_workers
    self.inference_time = cfg.inference_time
    self.num_model_steps_forward = cfg.num_model_steps_forward
    self.save_zarr = cfg.save_zarr
    self.model_path = cfg.ckpt_path
    self.normalize_before_mask = cfg.data.normalize_before_mask
    self.masked_fill_value = cfg.data.masked_fill_value
    self.init_inference_store()