Writes model prediction outputs to Zarr format for downstream analysis.
Source code in src/ocean_emulators/utils/writer.py
def __init__(
    self,
    output_dir: str | os.PathLike,
    coords: dict[str, xr.DataArray],
    hist: int,
    model_path: str | os.PathLike,
    time_chunk_size: int,
    normalize: Normalize,
    tensor_map: TensorMap,
):
    """Set up a writer whose prediction store lives under ``output_dir``.

    The destination is ``<output_dir>/predictions.zarr``. If anything
    already exists at that path, construction is aborted so that a previous
    run's predictions are not reused or overwritten by this writer.

    Args:
        output_dir: Directory in which ``predictions.zarr`` will be placed.
        coords: Mapping from coordinate name to ``xr.DataArray``; kept
            as-is on ``self.coords``.
        hist: Integer kept on ``self.hist`` (presumably the number of
            history steps the model consumes -- TODO confirm with callers).
        model_path: Model location; kept as-is on ``self.model_path``.
        time_chunk_size: Integer kept on ``self.time_chunk_size``
            (presumably the chunk length along time when writing -- TODO
            confirm against the write logic).
        normalize: ``Normalize`` object; kept as-is on ``self.normalize``.
        tensor_map: ``TensorMap`` object; kept as-is on ``self.tensor_map``.

    Raises:
        FileExistsError: If ``<output_dir>/predictions.zarr`` already exists.
    """
    # Work out where predictions will go before anything else, because the
    # collision check below depends on it.
    self.pred_path = os.path.join(output_dir, "predictions.zarr")
    if os.path.exists(self.pred_path):
        # Fail fast rather than mixing new output into an existing store.
        raise FileExistsError(
            f"Predictions already exist at {self.pred_path}. "
            "Please choose a unique experiment name, output directory, "
            "or delete the existing predictions."
        )

    # Configuration handed in by the caller, stored unchanged.
    self.hist = hist
    self.coords = coords
    self.model_path = model_path
    self.time_chunk_size = time_chunk_size
    self.normalize = normalize
    self.tensor_map = tensor_map

    # Accumulators start empty; presumably populated by later write calls
    # that are not visible here -- NOTE(review): confirm.
    self.buffer: torch.Tensor | None = None
    self.time_buffer: xr.DataArray | None = None
|