Data Utilities
ocean_emulators.utils.data
Masks(prognostic, boundary)
dataclass
A collection of masks to expose the ocean and mask land.
DataSource(name, data, means, stds, masks, dataset_spec)
dataclass
Data source for the model.
is_compact
cached
property
Check if the data source is compact.
filter(var_names, *, prefix)
Filter the data source to only include the specified variables (and levels).
If the dataset is compact, it will also filter the levels based on the variable names (which encode the level in the name).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| var_names | PrognosticVarNames \| BoundaryVarNames | Variable names to filter. | required |
| prefix | str | Prefix for the new data source name. | required |
Returns:
| Type | Description |
|---|---|
| Self | A new |
Source code in src/ocean_emulators/utils/data.py
map(func, *, suffix=None)
Map the function over the data source.
map_data(func, *, suffix=None)
Map the function over just data in DataSource.
slice(time)
Slice the data source to only include the specified time slice.
normalize(fill_nan=True, fill_value=0.0)
normalize_with(data, variable_axis=0, fill_nan=True, fill_value=0.0)
Normalize input data treated as torch Tensors.
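As a rough illustration of what per-variable normalization with NaN filling involves, the sketch below standardizes each variable and then replaces NaNs. It is a plain-Python stand-in, not the library's implementation: the helper `normalize_rows` is hypothetical, it works on nested lists rather than torch Tensors, and it assumes NaNs are filled after the (x - mean) / std step.

```python
import math


def normalize_rows(data, means, stds, fill_nan=True, fill_value=0.0):
    """Standardize each variable (row) and optionally replace NaNs.

    Hypothetical helper for illustration; not part of ocean_emulators.
    """
    out = []
    for row, mean, std in zip(data, means, stds):
        normalized = [(x - mean) / std for x in row]
        if fill_nan:
            # NaN propagates through the arithmetic, so fill after normalizing.
            normalized = [fill_value if math.isnan(x) else x for x in normalized]
        out.append(normalized)
    return out


# Two variables (rows); NaN marks a land cell in the first one.
data = [[10.0, 14.0, float("nan")], [1.0, 3.0, 5.0]]
normalized = normalize_rows(data, means=[12.0, 3.0], stds=[2.0, 2.0])
```

Filling after normalization means the fill value lands directly in normalized units, so `fill_value=0.0` corresponds to the per-variable mean.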
OceanData(data, means, stds, mask)
dataclass
A slice of ocean data (boundary or prognostic) with normalization statistics.
This dataclass bundles raw tensor data with the statistics needed to normalize it
and the mask needed to handle land/invalid values. It serves as an intermediary
representation used when constructing training Examples from raw xarray data.
The typical workflow is:
- Load raw data from a DataSource via from_data_source()
- Slice to the desired time range with with_time()
- Apply normalization and masking with normalize_and_mask()
- Flatten time/variable dims to create the final Input or Prognostic tensor
Attributes:
| Name | Type | Description |
|---|---|---|
| data | Float[Tensor, 'batch time variable lat lon'] | Raw ocean variable values with shape (batch, time, variable, lat, lon). |
| means | Float[Tensor, ' variable'] | Per-variable means for normalization, shape (variable,). |
| stds | Float[Tensor, ' variable'] | Per-variable standard deviations for normalization, shape (variable,). |
| mask | Bool[Tensor, ' variable'] | Boolean mask indicating valid ocean points (True) vs land (False), broadcast-compatible with the variable dimension. |
with_time(time_range)
normalize_and_mask(normalize_before_mask, masked_fill_value)
Normalize and mask tensors.
Normalize(src, prognostic_var_names, boundary_var_names)
Store normalization parameters and pre-compute numpy arrays.
normalize_tensor_prognostic(data, fill_nan=True, fill_value=0.0)
Normalize prognostic tensor.
unnormalize_tensor_prognostic(data, fill_value=float('nan'))
Unnormalize prognostic tensor and apply fill value to land cells.
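The inverse step can be pictured as undoing the standardization and then overwriting land cells. The following is a minimal sketch under stated assumptions: `unnormalize_row` is a hypothetical helper, it operates on a flat list instead of a tensor, and it assumes the wet mask uses True for ocean and False for land.

```python
import math


def unnormalize_row(values, mean, std, wet_mask, fill_value=float("nan")):
    """Invert z-score normalization, then overwrite land cells.

    Hypothetical helper for illustration; not part of ocean_emulators.
    """
    restored = [v * std + mean for v in values]
    # Assumption: True in wet_mask marks ocean, False marks land.
    return [x if wet else fill_value for x, wet in zip(restored, wet_mask)]


row = unnormalize_row(
    [-1.0, 1.0, 0.0], mean=12.0, std=2.0, wet_mask=[True, True, False]
)
```

Using NaN as the default fill keeps land cells from being mistaken for physical values in later statistics or plots.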
unnormalize_tensor_boundary(data, fill_value=float('nan'))
Unnormalize boundary tensor.
LoadStats(load_time_seconds)
dataclass
Captures stats about loading a single TrainData object.
accumulated(stats)
classmethod
Accumulate the stats across multiple LoadStats objects in a batch.
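One way to picture a batch-level accumulator is a classmethod that folds per-sample stats into a single object. The sketch below uses a stand-in class, `LoadStatsSketch`, and assumes that accumulating means summing the load times; the real reduction may differ.

```python
from dataclasses import dataclass
from typing import Iterable


@dataclass
class LoadStatsSketch:
    """Stand-in for LoadStats; only the field from the signature is modeled."""

    load_time_seconds: float

    @classmethod
    def accumulated(cls, stats: Iterable["LoadStatsSketch"]) -> "LoadStatsSketch":
        # Assumption: accumulation means summing the per-sample load times.
        return cls(load_time_seconds=sum(s.load_time_seconds for s in stats))


batch_stats = LoadStatsSketch.accumulated(
    [LoadStatsSketch(0.25), LoadStatsSketch(0.5), LoadStatsSketch(0.25)]
)
```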
conditional_rearrange(data, pattern, except_dim='lev', concat_dim='variable')
Rearrange a Dataset using an einsum notation with and without a dimension.
When a dataset has variables with a mixture of dimensions and an einsum-like rearrange is applied to that dataset, the pattern will often combine one too many variables. Sometimes it is desirable to apply the rearrange pattern to two versions of the data, one including the variables with that dimension and one without, and then concatenate them along a new dimension.
For example, surface level boundary variables, which only occur at t0, should not be
combinatorially rearranged with depth variables that have multiple time steps. In
such a situation, this function can be used to apply a standard einsum rearrangement
to depth and surface variables, including and excluding variables that have a time
dimension, respectively.
This method is stable: even if it creates a new number of dimensions, it will preserve the order of the variables in the original dataset.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Dataset | The dataset to rearrange. | required |
| pattern | str | The einsum pattern to use for rearranging. | required |
| except_dim | | The dimension to exclude from the pattern. | 'lev' |
| concat_dim | | The dimension to concatenate along. | 'variable' |
Returns:
| Type | Description |
|---|---|
| DataArray | The combined, rearranged dataset as a |
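The bookkeeping behind the "with and without a dimension" idea can be sketched without xarray. The example below only models the layout of the concatenated dimension: variables that have `except_dim` expand into one channel per level, the rest contribute one channel, and dataset order is preserved. The helper `flattened_channel_names`, the variable names, and the `name_level` naming scheme are all illustrative assumptions, not the library's behavior.

```python
def flattened_channel_names(var_dims, n_lev, except_dim="lev"):
    """List the channels produced when `except_dim` is folded into variables.

    var_dims maps each variable name to its dimension names, in dataset order.
    Hypothetical helper for illustration; not part of ocean_emulators.
    """
    channels = []
    for name, dims in var_dims.items():
        if except_dim in dims:
            # Variables with the dimension expand into one channel per level.
            channels.extend(f"{name}_{lev}" for lev in range(n_lev))
        else:
            # Variables without it contribute a single channel.
            channels.append(name)
    return channels


# Illustrative dataset: two depth-resolved variables and one surface variable.
var_dims = {
    "thetao": ("time", "lev", "lat", "lon"),
    "zos": ("time", "lat", "lon"),
    "so": ("time", "lev", "lat", "lon"),
}
channels = flattened_channel_names(var_dims, n_lev=2)
```

Here the surface variable stays between the two depth-resolved ones, which is the sense in which the ordering is stable.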
extract_wet_mask(data, prognostic_var_names, *, dataset_spec)
A mask for where the oceans are. Water is wet.
flatten_masks(data, dataset_spec)
Adds level-wise mask variables from the stacked wet mask.
unflatten_masks(data, dataset_spec)
Adds a stacked wet mask xarray.DataArray from level-wise mask variables.
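The relationship between a stacked mask and level-wise masks can be sketched as a round trip. This is a simplified illustration with nested lists in place of xarray objects; the helpers `flatten_mask` and `unflatten_mask` and the `mask_<index>` naming are assumptions made for the example.

```python
def flatten_mask(stacked, prefix="mask"):
    """Split a [lev][lat][lon] mask into one named 2D mask per level."""
    return {f"{prefix}_{lev}": level for lev, level in enumerate(stacked)}


def unflatten_mask(levelwise, prefix="mask"):
    """Rebuild the stacked mask, ordering levels by their index suffix."""
    keys = sorted(
        (k for k in levelwise if k.startswith(f"{prefix}_")),
        key=lambda k: int(k.rsplit("_", 1)[1]),
    )
    return [levelwise[k] for k in keys]


stacked = [
    [[True, True], [True, False]],    # level 0
    [[True, False], [False, False]],  # level 1: fewer wet cells at depth
]
levelwise = flatten_mask(stacked)
restored = unflatten_mask(levelwise)
```

Sorting by the integer suffix rather than by string keeps level 10 after level 9 when there are many levels.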
spherical_area(data)
Compute real grid cell areas on a spherical Earth.
Uses the spherical geometry formula: A = R² × Δλ × (sin(φ₂) - sin(φ₁))
where:
- R is Earth's radius (6371 km)
- Δλ is the longitude spacing in radians
- φ₁, φ₂ are the latitude bounds of the cell in radians
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data | Dataset | Dataset containing lat/lon coordinates | required |
Returns:
| Type | Description |
|---|---|
| Grid | Grid cell areas in m² |
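The area formula can be checked by hand for a single cell. The sketch below evaluates it for scalar cell bounds; `cell_area` is a hypothetical helper that takes explicit latitude bounds and a longitude spacing in degrees, whereas the documented function derives these from a dataset's coordinates.

```python
import math

EARTH_RADIUS_M = 6_371_000.0  # 6371 km expressed in meters


def cell_area(lat_south_deg, lat_north_deg, dlon_deg, radius=EARTH_RADIUS_M):
    """Area in m² of a cell bounded by two latitudes and spanning dlon_deg."""
    phi1 = math.radians(lat_south_deg)
    phi2 = math.radians(lat_north_deg)
    dlam = math.radians(dlon_deg)
    return radius**2 * dlam * (math.sin(phi2) - math.sin(phi1))


# One-degree cells shrink toward the poles.
equator = cell_area(0.0, 1.0, 1.0)
arctic = cell_area(79.0, 80.0, 1.0)

# Sanity check: one-degree latitude bands around the globe sum to 4πR².
total = sum(cell_area(lat, lat + 1.0, 360.0) for lat in range(-90, 90))
```

Because the formula integrates cos(φ) exactly over each band, the band areas telescope to the full sphere area regardless of grid resolution.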
get_inference_steps(data_source, hist=1)
Get the number of inference/rollout steps for the given time configuration.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| data_source | DataSource | The data source sliced to the inference time range | required |
| hist | int | How many additional history samples we get per step | 1 |
Returns:
| Name | Type | Description |
|---|---|---|
| num_steps | | Total number of rolled-out inferences which fit into the time range |
get_anomalies_vars(var_names)
Get the variables that need to be computed for anomalies.
compute_anomalies(data, means, stds, anomalies_vars)
Compute anomalies for the given variables.
with_level_index_vars(data, dataset_spec)
Ensure variable names use a depth level index, not depth level value.
with_lat_lon_coords(data)
Standardize dataset coordinates; prefer "lat"/"lon" over "y"/"x".
validate_data(data, means, stds, dataset_spec, boundary_var_names, static_data_vars=None)
Validate the data such that we have the correct format for training.