Skip to content

Visualization

Core

ocean_emulators.viz.core

Viz(output_path, dataset_name, runs, basins, groundtruth_rollout, time_range)

Generates maps, time series, and probability density plots from evaluation outputs.

Source code in src/ocean_emulators/viz/core.py
def __init__(
    self,
    output_path: str,
    dataset_name: str,
    runs: list["VizRun"],
    basins: xr.Dataset,
    groundtruth_rollout: xr.Dataset,
    time_range: slice,
):
    """Prepare ground truth and prediction data for plotting.

    Parameters:
    output_path (str): Root directory; one sub-folder per plot category is
        created beneath it.
    dataset_name (str): Label for the ground truth in logs and tables.
    runs (list[VizRun]): Prediction runs; each provides ``name``, ``data``
        and ``variables``. The first run's name is stored as ``self.key1``.
    basins (xarray.Dataset): Basin masks with ``basin_atlantic``,
        ``basin_pacific``, ``basin_indian``, ``basin_southern`` and
        ``basin_arctic`` variables.
    groundtruth_rollout (xarray.Dataset): Ground-truth dataset; it is
        restricted to ``time_range`` before processing.
    time_range (slice): Time selection applied to the ground truth.
    """
    # One entry per run; process_data later adds "ds_prediction" to each.
    pred_dict: dict[str, dict[str, Any]] = {
        run.name: {"name": run.name, "data": run.data, "ls": run.variables}
        for run in runs
    }

    key1 = runs[0].name
    # TODO: Support non-OM4 dataset specs in visualization.
    self.dataset_spec = build_om4_spec()
    levels = len(self.dataset_spec.depth_levels)

    groundtruth_rollout = groundtruth_rollout.sel(time=time_range)

    # When the data uses y/x dimensions, drop any existing lat/lon variables
    # so that renaming y/x to lat/lon does not collide with them.
    if "y" in groundtruth_rollout.coords:
        groundtruth_rollout = groundtruth_rollout.drop_vars(
            ["lat", "lon"], errors="ignore"
        )
        groundtruth_rollout = groundtruth_rollout.rename({"y": "lat", "x": "lon"})

    # Stored as "areacello"; downstream code uses it as area weights.
    groundtruth_rollout = groundtruth_rollout.assign(
        areacello=(["lat", "lon"], spherical_area_weights(groundtruth_rollout))
    )

    # Compute real grid cell areas for physical calculations
    groundtruth_rollout["areacello_spherical"] = (
        ["lat", "lon"],
        spherical_area(groundtruth_rollout),
    )

    # Process the ground truth and load predictions into pred_dict.
    data, pred_dict = process_data(
        groundtruth_rollout, pred_dict, dataset_spec=self.dataset_spec
    )

    # First, middle and last time index of the processed data.
    last_index = len(data.time) - 1
    self.time_indices = [0, last_index // 2, last_index]

    var_list = {
        "vo": r"$v$ $( m/s )$",
        "uo": r"$u$ $( m/s )$",
        "thetao": r"$T$ $( ^\circ C )$",
        "tos": r"$T$ $( ^\circ C )$",
        "so": r"$so$ $( psu )$",
        "zos": r"$zos$ $( m )$",
        "KE": r"$KE$ $( J/m^2 )$",
        "OHC": r"$OHC$ $Anomaly$ $( ZJ )$",
    }

    # One output folder per plot category. makedirs(exist_ok=True) replaces
    # a separate isdir() check, which could race with a concurrent creator.
    self.timeseries_path = os.path.join(output_path, "Timeseries")
    self.ohc_path = os.path.join(output_path, "OHC")
    self.temp_path = os.path.join(output_path, "Temperature")
    self.salinity_path = os.path.join(output_path, "Salinity")
    self.pdfs_path = os.path.join(output_path, "PDFs")
    self.enso_path = os.path.join(output_path, "ENSO")
    self.metrics_path = os.path.join(output_path, "Metrics")
    self.movie_path = os.path.join(output_path, "Movies")
    for folder in (
        self.timeseries_path,
        self.ohc_path,
        self.temp_path,
        self.salinity_path,
        self.pdfs_path,
        self.enso_path,
        self.metrics_path,
        self.movie_path,
    ):
        os.makedirs(folder, exist_ok=True)

    clist = ["#ff807a", "#1e8685", "#ffb579", "#63c8ab"]

    def _north_of_32s(mask):
        # Keep only cells at or north of 32 degrees S (NaN elsewhere).
        # TODO: include this -32 masking in the basin data
        return mask.where(mask["lat"] >= -32)

    atlantic_mask = process_mask(data, _north_of_32s(basins["basin_atlantic"]))
    pacific_mask = process_mask(data, _north_of_32s(basins["basin_pacific"]))
    indian_ocean_mask = process_mask(
        data, _north_of_32s(basins["basin_indian"])
    )
    southern_ocean_mask = process_mask(data, basins["basin_southern"])
    arctic_ocean_mask = process_mask(data, basins["basin_arctic"])

    self.basin_masks = xr.Dataset(
        {
            "Atlantic": atlantic_mask,
            "Pacific": pacific_mask,
            "Southern": southern_ocean_mask,
            "Indian": indian_ocean_mask,
            "Arctic": arctic_ocean_mask,
        }
    )

    # Compute profile means
    with ProgressBar():
        logger.info("Computing profile for ground truth %s", dataset_name)
        profile_groundtruth = profile_mean(data).load()

        for k, entry in pred_dict.items():
            logger.info("Computing profile for prediction %s", k)
            entry["profile_prediction"] = profile_mean(
                entry["ds_prediction"]
            ).load()

    self.data: xr.Dataset = data
    self.profile_groundtruth: xr.Dataset = profile_groundtruth
    self.pred_dict: dict[str, dict[str, Any]] = pred_dict
    self.dataset_name: str = dataset_name
    self.clist: list[str] = clist
    self.var_list: dict[str, str] = var_list
    self.levels: int = levels
    self.key1: str = key1
    self.output_path: str = output_path

step_create_ohc_salinity_slopes_table()

Create a CSV table with OHC and salinity slopes.

Source code in src/ocean_emulators/viz/core.py
def step_create_ohc_salinity_slopes_table(self):
    """Write ``ohc_salinity_slopes_table.csv`` comparing trend slopes.

    The first row holds the ground-truth OHC-anomaly and salinity slopes
    (element 0 of ``self.linear_fit``'s result). Each subsequent row holds a
    prediction's slopes, read from ``self.pred_dict``, plus their ratios to
    the ground-truth slopes. The ground-truth row has no ratio columns, so
    those cells are empty in the CSV.
    """
    fit = self.linear_fit
    gt_ohc = fit(self.ohc_anomaly_global(self.data))[0]
    gt_sal = fit(self.salinity_global(self.data))[0]

    reference_row = {
        "Model": self.dataset_name,
        "OHC": gt_ohc,
        "Salinity": gt_sal,
    }
    prediction_rows = [
        {
            "Model": entry["name"],
            "OHC": entry["OHC_slope"],
            "OHC Slope Ratio": entry["OHC_slope"] / gt_ohc,
            "Salinity": entry["salinity_slope"],
            "Salinity Slope Ratio": (entry["salinity_slope"] / gt_sal),
        }
        for entry in self.pred_dict.values()
    ]

    table = pd.DataFrame([reference_row, *prediction_rows])
    csv_path = os.path.join(self.output_path, "ohc_salinity_slopes_table.csv")
    table.to_csv(csv_path, index=False)

isnan(x)

Wrapper around np.isnan whose signature correctly reflects the type we use it on.

Source code in src/ocean_emulators/viz/core.py
def isnan(x: xr.DataArray) -> xr.DataArray:
    """Typed wrapper around ``np.isnan`` for xarray DataArrays.

    Exists so call sites see a DataArray-in / DataArray-out signature
    instead of NumPy's own annotations.
    """
    nan_mask = np.isnan(x)  # type: ignore
    return nan_mask

combine_variables_by_level(ds_groundtruth, pred_dict, dataset_spec)

Combine variables by level for ground truth and predictions.

Parameters: ds_groundtruth (xarray.Dataset): The ground truth dataset. pred_dict (dict): Dictionary containing prediction datasets. dataset_spec (DatasetSpec): Dataset specification passed to the per-level combination step.

Returns: xarray.Dataset, dict: Updated ground truth and prediction datasets.

Source code in src/ocean_emulators/viz/core.py
def combine_variables_by_level(
    ds_groundtruth: xr.Dataset,
    pred_dict: dict[str, dict[str, Any]],
    dataset_spec: DatasetSpec,
) -> tuple[xr.Dataset, dict[str, dict[str, Any]]]:
    """
    Combine per-level variables for the ground truth and every prediction.

    Parameters:
    ds_groundtruth (xarray.Dataset): The ground truth dataset; its
        thetao, so, uo, vo and mask variables are combined.
    pred_dict (dict): Prediction entries. Each entry's "ds_prediction" is
        replaced in place, combining the variables listed under its "ls" key.
    dataset_spec (DatasetSpec): Dataset specification forwarded to
        ``_combine_variables_by_level``.

    Returns:
    xarray.Dataset, dict: Updated ground truth and prediction datasets.
    """
    groundtruth_vars = ["thetao", "so", "uo", "vo", "mask"]
    combined_groundtruth = _combine_variables_by_level(
        ds_groundtruth, groundtruth_vars, dataset_spec
    )

    for entry in pred_dict.values():
        entry["ds_prediction"] = _combine_variables_by_level(
            entry["ds_prediction"], entry["ls"], dataset_spec
        )

    return combined_groundtruth, pred_dict

postprocess_for_plot(ds_groundtruth, areacello, dz, pred_dict)

Postprocess for plotting.

Parameters: ds_groundtruth (xarray.Dataset): The ground truth dataset. areacello (xarray.DataArray): Grid-cell area weights (areacello). dz (numpy.ndarray): Depth-level thicknesses. pred_dict (dict): Dictionary containing prediction datasets.

Returns: xarray.Dataset, dict: Postprocessed ground truth and prediction datasets.

Source code in src/ocean_emulators/viz/core.py
def postprocess_for_plot(
    ds_groundtruth, areacello: xr.DataArray, dz: np.ndarray, pred_dict
):
    """
    Postprocess the ground truth and all predictions for plotting.

    Parameters:
    ds_groundtruth (xarray.Dataset): The ground truth dataset; must contain
        "areacello_spherical" and either a "mask" variable or ``wetmask``.
    areacello (xarray.DataArray): Grid-cell area weights.
    dz (numpy.ndarray): Depth-level thicknesses.
    pred_dict (dict): Prediction entries; each entry's "ds_prediction" is
        replaced in place.

    Returns:
    xarray.Dataset, dict: Postprocessed ground truth and prediction datasets,
    with lat/lon renamed to y/x.
    """
    # Plain arrays shared by every dataset that gets postprocessed.
    area_weights = areacello.values
    times = ds_groundtruth.time
    spherical_areas = ds_groundtruth["areacello_spherical"].values

    # Land mask used to set land points to NaN. The "mask" variable does not
    # always carry a time dimension, hence missing_dims="ignore".
    has_mask_var = "mask" in ds_groundtruth.data_vars
    wetmask = (
        ds_groundtruth["mask"].isel(time=0, missing_dims="ignore")
        if has_mask_var
        else ds_groundtruth.wetmask
    )

    processed_groundtruth = _postprocess_for_plot(
        ds_groundtruth, area_weights, spherical_areas, dz, times, wetmask
    )

    # Predictions reuse the ground-truth coordinates, taken while they are
    # still named lat/lon.
    groundtruth_coords = processed_groundtruth.coords

    for entry in pred_dict.values():
        entry["ds_prediction"] = _postprocess_for_plot(
            entry["ds_prediction"],
            area_weights,
            spherical_areas,
            dz,
            times,
            wetmask,
            coords=groundtruth_coords,
        )
        entry["ds_prediction"] = entry["ds_prediction"].rename(
            {"lat": "y", "lon": "x"}
        )

    # The ground-truth rename must happen last, after its coords were used.
    return processed_groundtruth.rename({"lat": "y", "lon": "x"}), pred_dict

process_data(data, pred_dict, dataset_spec)

Get plot-ready OM4 data.

Source code in src/ocean_emulators/viz/core.py
def process_data(
    data: xr.Dataset,
    pred_dict: dict[str, dict[str, Any]],
    dataset_spec: DatasetSpec,
) -> tuple[xr.Dataset, dict[str, dict[str, Any]]]:
    """
    Get plot-ready OM4 data.

    Adds level-index variables to the ground truth and attaches each run's
    dataset to its entry under "ds_prediction". When the run's dataset has a
    "model_path" attribute, that is also stored under "model_path". It then
    combines per-level variables and postprocesses everything for plotting.

    Parameters:
    data (xarray.Dataset): Ground-truth dataset.
    pred_dict (dict): Prediction entries keyed by run name; each holds the
        run's dataset under "data". Entries are updated in place.
    dataset_spec (DatasetSpec): Dataset specification (provides
        ``depth_thickness`` and is forwarded to the helpers).

    Returns:
    xarray.Dataset, dict: Plot-ready ground truth and prediction datasets.

    Raises:
    ValueError: If a prediction's time axis length differs from the ground
        truth's.
    """
    ds_groundtruth = with_level_index_vars(data, dataset_spec=dataset_spec)

    def _time_span(ds) -> str:
        # Guard against empty time axes so the error message itself can't fail.
        values = ds.time.values
        return f"{values[0]} to {values[-1]}" if values.size else "empty"

    for key, entry in pred_dict.items():
        ds_prediction = entry["data"]

        # Explicit check instead of `assert`, which is stripped under -O.
        if ds_prediction.time.size != ds_groundtruth.time.size:
            raise ValueError(
                f"Sizes different for {key}: {ds_prediction.time.size}!="
                f"{ds_groundtruth.time.size}; prediction range is "
                f"{_time_span(ds_prediction)}\n"
                f"groundtruth range is {_time_span(ds_groundtruth)}"
            )

        # Store model_path on the live entry. The previous code wrote it into
        # a deep copy (of every dataset) that was then discarded.
        if "model_path" in ds_prediction.attrs:
            entry["model_path"] = ds_prediction.attrs["model_path"]

        entry["ds_prediction"] = ds_prediction

    ### Combine Variables by level
    ds_groundtruth, pred_dict = combine_variables_by_level(
        ds_groundtruth, pred_dict, dataset_spec
    )

    ### Postprocess predictions for plotting
    ds_groundtruth, pred_dict = postprocess_for_plot(
        ds_groundtruth,
        ds_groundtruth.areacello,
        np.array(dataset_spec.depth_thickness),
        pred_dict,
    )

    return ds_groundtruth, pred_dict

profile_mean(ds)

Compute the area-weighted (areacello) mean of each variable over the horizontal y/x dimensions, keeping all other dimensions such as time.

Source code in src/ocean_emulators/viz/core.py
def profile_mean(ds: xr.Dataset) -> xr.Dataset:
    """
    Area-weighted mean of every variable over the horizontal dimensions.

    Weights come from the dataset's ``areacello`` variable; the "y" and "x"
    dimensions are reduced and all others (e.g. time) are kept.
    """
    horizontal_dims = ["y", "x"]
    area_weighted = ds.weighted(ds.areacello)
    return area_weighted.mean(horizontal_dims)

Config

ocean_emulators.viz.config