# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Wrappers for statistics that transform the inputs.
Example to compute binary metrics from a continuous ensemble prediction and
a continuous ground truth:
wrappers.WrappedMetric(
categorical.CSI(),
[
wrappers.ContinuousToBinary(
which='both',
threshold_value=[0, 50],
threshold_dim='threshold_value'
),
wrappers.EnsembleMean(
which='predictions', ensemble_dim='realization'
),
wrappers.ContinuousToBinary(
which='predictions',
threshold_value=[0.25, 0.75],
threshold_dim='threshold_probability'
),
],
)
"""
import abc
from collections.abc import Sequence
from typing import Any, Callable, Hashable, Iterable, Mapping, Union
import numpy as np
from weatherbenchX import xarray_tree
from weatherbenchX.metrics import base
import xarray as xr
import xarray.ufuncs as xu
def binarize_thresholds(
    x: xr.DataArray,
    thresholds: Union[Iterable[float], xr.DataArray, xr.Dataset],
    threshold_dim: str,
) -> xr.DataArray:
  """Binarizes a continuous array using a threshold value or a list of values.

  Computes `x > threshold` for every threshold. NaNs in the input array are
  retained, which is why the output is always cast to float32 (1.0 where the
  threshold is exceeded, 0.0 where it is not, NaN where `x` is NaN).

  Args:
    x: Input DataArray.
    thresholds: A single value, an iterable of values (list, tuple, array,
      generator, ...), an xarray.DataArray or an xarray.Dataset of threshold
      values. A DataArray / Dataset must contain `threshold_dim`; a Dataset
      must additionally contain a variable named like `x`.
    threshold_dim: Name of dimension to use for threshold values.

  Returns:
    binary_x: Binarized float32 DataArray.
  """
  if isinstance(thresholds, xr.Dataset):
    assert threshold_dim in thresholds.dims, (
        f'threshold_dim ({threshold_dim}) not found in thresholds'
        f' ({thresholds.dims})'
    )
    assert x.name in thresholds.data_vars, (
        f'Input DataArray name ({x.name}) not found in thresholds'
        f' ({thresholds.data_vars})'
    )
    threshold = thresholds[x.name]
  elif isinstance(thresholds, xr.DataArray):
    assert threshold_dim in thresholds.dims, (
        f'threshold_dim ({threshold_dim}) not found in thresholds'
        f' ({thresholds.dims})'
    )
    threshold = thresholds
  else:
    if not isinstance(thresholds, Iterable):
      thresholds = [thresholds]
    # Materialize exactly once into a numpy array: the values are needed both
    # as data and as coordinate (a one-shot iterable would be exhausted after
    # the first use), and xarray does not interpret tuples as 1D data.
    threshold_values = np.asarray(list(thresholds))
    threshold = xr.DataArray(
        threshold_values,
        dims=[threshold_dim],
        coords={threshold_dim: threshold_values},
    )
  # `where` re-inserts NaNs, which requires a float dtype.
  return (x > threshold).where(~xu.isnan(x)).astype(np.float32)
# Transforms
[docs]
class EnsembleMean(InputTransform):
  """Input transform that averages over the ensemble dimension."""

  def __init__(
      self,
      which: str,
      ensemble_dim='number',
      skipna=False,
      skip_if_ensemble_dim_missing: bool = False,
  ):
    """Init.

    Args:
      which: Which input to apply the wrapper to. Must be one of 'predictions',
        'targets', or 'both'.
      ensemble_dim: Name of the dimension to average over. Default: 'number'.
      skipna: Passed on to the mean reduction; if True, NaNs are ignored.
        Default: False.
      skip_if_ensemble_dim_missing: If True, inputs without `ensemble_dim` are
        returned unchanged instead of being reduced. Default: False.
    """
    super().__init__(which)
    self._skip_if_ensemble_dim_missing = skip_if_ensemble_dim_missing
    self._skipna = skipna
    self._ensemble_dim = ensemble_dim

  @property
  def unique_name_suffix(self) -> str:
    # Encodes the ensemble dimension and the skipna setting.
    return f'ensemble_mean_{self._ensemble_dim=}_{self._skipna=}'

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Returns the mean of `da` over the ensemble dimension."""
    has_ensemble_dim = self._ensemble_dim in da.dims
    if self._skip_if_ensemble_dim_missing and not has_ensemble_dim:
      # Nothing to reduce; pass the input through untouched.
      return da
    return da.mean(self._ensemble_dim, skipna=self._skipna)
class EnsembleQuantiles(InputTransform):
  """Compute ensemble quantiles."""

  def __init__(
      self,
      which: str,
      quantiles: Union[float, Iterable[float]],
      quantile_dim: str = 'quantile',
      ensemble_dim: str = 'number',
      skipna: bool = False,
      skip_if_ensemble_dim_missing: bool = False,
  ):
    """Init.

    Args:
      which: Which input to apply the wrapper to. Must be one of 'predictions',
        'targets', or 'both'.
      quantiles: Quantile or iterable of quantiles to compute, which must be
        between 0 and 1 inclusive.
      quantile_dim: Name of the new dimension and coordinate that will store the
        quantiles. Default: 'quantile'.
      ensemble_dim: Name of ensemble dimension. Default: 'number'.
      skipna: If True, skip NaNs in the ensemble quantiles. Note that setting
        this to True will SIGNIFICANTLY slow down the computation! Default:
        False.
      skip_if_ensemble_dim_missing: If True, skip the ensemble quantiles if the
        ensemble dimension is missing. Default: False.
    """
    super().__init__(which)
    # Always store a list: a list makes xr.DataArray.quantile create a
    # dimension and coordinate for the quantiles, and materializing here keeps
    # one-shot iterables (e.g. generators) from being exhausted by the first
    # of the repeated reads in `unique_name_suffix` / `transform_fn`.
    self._quantiles = (
        list(quantiles) if isinstance(quantiles, Iterable) else [quantiles]
    )
    self._quantile_dim = quantile_dim
    self._ensemble_dim = ensemble_dim
    self._skipna = skipna
    self._skip_if_ensemble_dim_missing = skip_if_ensemble_dim_missing

  @property
  def unique_name_suffix(self) -> str:
    quantiles_str = ','.join(str(q) for q in self._quantiles)
    return (
        f'ensemble_quantiles_{self._ensemble_dim=}_{self._quantile_dim=}_'
        f'{self._skipna=}_{quantiles_str}'
    )

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Reduces the ensemble dimension of `da` to the requested quantiles.

    Args:
      da: Input DataArray.

    Returns:
      DataArray in which the ensemble dimension is replaced by `quantile_dim`,
      or `da` unchanged if the ensemble dimension is missing and
      `skip_if_ensemble_dim_missing` is set.

    Raises:
      ValueError: If `da` already has a dimension called `quantile`.
    """
    if self._ensemble_dim not in da.dims and self._skip_if_ensemble_dim_missing:
      return da
    # The quantile reduction below names its new dimension `quantile`, so an
    # existing dimension of that name is rejected up front.
    if 'quantile' in da.dims:
      raise ValueError(
          'Input DataArray already has a `quantile` dimension. Please rename it'
          ' before applying the EnsembleQuantiles wrapper.'
      )
    result = da.quantile(
        self._quantiles, dim=self._ensemble_dim, skipna=self._skipna
    )
    if self._quantile_dim != 'quantile':
      result = result.rename({'quantile': self._quantile_dim})
    return result
[docs]
class ContinuousToBinary(InputTransform):
  """Converts a continuous input to a binary one.

  Applies x > threshold for all thresholds and concatenates along a new
  dimension of name `threshold_dim`.
  """

  def __init__(
      self,
      which: str,
      threshold_value: Union[float, Iterable[float], xr.DataArray, xr.Dataset],
      threshold_dim: str,
      unique_name_suffix: str | None = None,
  ):
    """Init.

    Args:
      which: Which input to apply the wrapper to. Must be one of 'predictions',
        'targets', or 'both'.
      threshold_value: Threshold value, iterable of values, xarray.DataArray or
        xarray.Dataset.
      threshold_dim: Name of dimension to use for threshold values.
      unique_name_suffix: Suffix to add to the unique name. If
        `threshold_value` is an xarray.DataArray or xarray.Dataset, this must
        be provided, and must be unique over all the threshold_value that you
        intend to use within a set of Metrics that are computed together.

    Raises:
      ValueError: If `threshold_value` is an xarray object and no
        `unique_name_suffix` was given.
    """
    super().__init__(which)
    if isinstance(threshold_value, (xr.DataArray, xr.Dataset)):
      if unique_name_suffix is None:
        raise ValueError(
            'unique_name_suffix must be provided if threshold_value is an'
            ' xarray.DataArray or xarray.Dataset.'
        )
      self._threshold_value = threshold_value
    elif isinstance(threshold_value, Iterable):
      # Materialize into a list: the values are read again on every call of
      # `transform_fn` and by `unique_name_suffix`, which would exhaust a
      # one-shot iterable, and a list (unlike a tuple) is interpreted as 1D
      # data by xarray.
      self._threshold_value = list(threshold_value)
    else:
      # Single scalar threshold.
      self._threshold_value = [threshold_value]
    self._threshold_dim = threshold_dim
    self._unique_name_suffix = unique_name_suffix

  @property
  def unique_name_suffix(self) -> str:
    if self._unique_name_suffix is None:
      # Only reachable for plain value lists (xarray inputs require a suffix).
      unique_name_suffix = ','.join(str(t) for t in self._threshold_value)
    else:
      unique_name_suffix = self._unique_name_suffix
    return f'{self._threshold_dim}={unique_name_suffix}'

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Binarizes `da` against all thresholds along `threshold_dim`."""
    return binarize_thresholds(da, self._threshold_value, self._threshold_dim)
def select_bin_thresholds_by_time_from_chunk(
    bin_thresholds: xr.DataArray,
    chunk: xr.DataArray,
) -> xr.DataArray:
  """Picks the bin thresholds matching the time coordinates of a data chunk.

  The chunk (targets or predictions) must be indexed either by (`init_time`,
  `lead_time`), by `valid_time`, or carry no time dimension at all. Depending
  on the time dimensions of `bin_thresholds` the selection works as follows:

  - `valid_time`: selected with the chunk's valid time (computed as
    `init_time + lead_time` for init/lead_time chunks).
  - `init_time` and `lead_time`: selected directly (init/lead_time chunks
    only).
  - `dayofyear`: selected with the day of year of the chunk's `valid_time`, or,
    for init/lead_time chunks, with the day of year of `init_time` together
    with `lead_time` (the thresholds then need a `lead_time` dimension, too).
  - No matching time dimension (or a chunk without time dimensions): the
    thresholds are returned unchanged.

  Args:
    bin_thresholds: Data array containing bin thresholds.
    chunk: Data array of targets or predictions whose time coordinates drive
      the selection.

  Returns:
    Data array containing bin thresholds selected by time. If a selection took
    place, `.compute()` has been called on the result.
  """
  chunk_dims = set(chunk.dims)
  threshold_dims = set(bin_thresholds.dims)
  init_lead = {'init_time', 'lead_time'}

  if init_lead <= chunk_dims:
    if 'valid_time' in threshold_dims:
      indexers = {'valid_time': chunk.init_time + chunk.lead_time}
    elif init_lead <= threshold_dims:
      indexers = {'init_time': chunk.init_time, 'lead_time': chunk.lead_time}
    elif {'dayofyear', 'lead_time'} <= threshold_dims:
      indexers = {
          'dayofyear': chunk.init_time.dt.dayofyear,
          'lead_time': chunk.lead_time,
      }
    else:
      # Thresholds carry no usable time dimension: nothing to select.
      return bin_thresholds
  elif 'valid_time' in chunk_dims:
    if 'valid_time' in threshold_dims:
      indexers = {'valid_time': chunk.valid_time}
    elif 'dayofyear' in threshold_dims:
      indexers = {'dayofyear': chunk.valid_time.dt.dayofyear}
    else:
      # Thresholds carry no usable time dimension: nothing to select.
      return bin_thresholds
  else:
    # The chunk itself has no time dimension: nothing to select.
    return bin_thresholds

  return bin_thresholds.sel(indexers).compute()
def compute_cdf(
    threshold_values: Union[Iterable[float], xr.DataArray, xr.Dataset],
    da: xr.DataArray,
    threshold_dim: str,
    enforce_monotonicity: bool,
    right_inclusive: bool = True,
) -> xr.DataArray:
  """Evaluates the empirical CDF indicator of `da` at the given thresholds.

  Args:
    threshold_values: Iterable of thresholds, or an xarray.DataArray /
      xarray.Dataset of thresholds. For a Dataset the variable named like `da`
      is used. xarray thresholds are first matched to the time coordinates of
      `da`.
    da: Continuous input data.
    threshold_dim: Name of the threshold dimension.
    enforce_monotonicity: If True, the thresholds must be strictly increasing
      along `threshold_dim`.
    right_inclusive: If True computes `da <= threshold`, otherwise
      `da < threshold`.

  Returns:
    Float DataArray of 1.0 / 0.0, with NaN wherever `da` or the threshold is
    NaN.

  Raises:
    ValueError: If `threshold_values` has an unsupported type, or if the
      monotonicity check fails.
  """
  if isinstance(threshold_values, xr.Dataset):
    # Pick the variable that belongs to this DataArray before time matching.
    thresholds = select_bin_thresholds_by_time_from_chunk(
        threshold_values[da.name], da
    )
  elif isinstance(threshold_values, xr.DataArray):
    thresholds = select_bin_thresholds_by_time_from_chunk(threshold_values, da)
  elif isinstance(threshold_values, Iterable):
    # Go through numpy first because xarray does not interpret tuples
    # correctly.
    edges = np.array(threshold_values)
    thresholds = xr.DataArray(
        edges, dims=[threshold_dim], coords={threshold_dim: edges}
    )
  else:
    raise ValueError(
        'Bin values must be an Iterable, xr.DataArray, or xr.Dataset.'
    )

  if enforce_monotonicity and not np.all(thresholds.diff(threshold_dim) > 0):
    raise ValueError(
        'Bin values must be monotonically increasing. To turn off this'
        ' check, set `enforce_monotonicity=False`.'
    )

  below_threshold = da <= thresholds if right_inclusive else da < thresholds
  cdf = below_threshold.astype('float')
  # Comparisons with NaN yield False, so re-insert NaNs from both operands.
  return cdf.where(~xu.isnan(da)).where(~xu.isnan(thresholds))
class ContinuousToBins(InputTransform):
  """Converts a continuous input to a binned one.

  The bins are right-inclusive, i.e., `threshold[i-1] < x <= threshold[i]`. For
  example, if `bin_values` is [0.5, 1.0, 1.5], then the bins will be
  (0.5, 1.0] and (1.0, 1.5]. To have open-ended bins, the `bin_values` have to
  specify the `-np.inf` and `np.inf` values at the edges. That means that the
  size of the binning dimension is len(bin_values) - 1.

  The output DataArray will have two new coordinates: `{bin_dim}_left` and
  `{bin_dim}_right`, which will contain the left and right bin edges,
  respectively in the case where the bin_values are a 1D array.

  If the bin_values are a n-dimensional DataArray or Dataset (the dataset
  variables must match those of the DataArray), they have to have a bin_dim
  coordinate. The new coordinates in this case will not be the actual values
  used for binning (since they will be n-dimensional) but rather the values
  of the bin_dim coordinate.
  """

  def __init__(
      self,
      which: str,
      bin_values: Union[Iterable[float], xr.DataArray, xr.Dataset],
      bin_dim: str,
      unique_name_suffix: str | None = None,
      enforce_monotonicity: bool = True,
  ):
    """Initialize the transform.

    Args:
      which: Which input to apply the wrapper to. Must be one of 'predictions',
        'targets', or 'both'.
      bin_values: Iterable of bin edges, xarray.DataArray or xarray.Dataset.
        Must be monotonically increasing.
      bin_dim: Name of dimension to use for threshold values.
      unique_name_suffix: Suffix to add to the unique name. If `bin_values` is
        an xarray.DataArray or xarray.Dataset, this must be provided.
      enforce_monotonicity: If True, raise an error if the bin values are not
        monotonically increasing. It might be necessary to turn this off if the
        bin_values contain NaNs, which would raise an error here. Default: True.

    Raises:
      ValueError: If `bin_values` is an xarray object and no
        `unique_name_suffix` was given.
    """
    super().__init__(which)
    if isinstance(bin_values, (xr.DataArray, xr.Dataset)):
      if unique_name_suffix is None:
        raise ValueError(
            'unique_name_suffix must be provided if bin_values is an'
            ' xarray.DataArray or xarray.Dataset.'
        )
    elif isinstance(bin_values, Iterable):
      # The edges are read by `unique_name_suffix` and again on every call of
      # `transform_fn`; materialize them so that a one-shot iterable (e.g. a
      # generator) is not exhausted after the first read.
      bin_values = list(bin_values)
    self._bin_values = bin_values
    self._bin_dim = bin_dim
    self._unique_name_suffix = unique_name_suffix
    self._enforce_monotonicity = enforce_monotonicity

  @property
  def unique_name_suffix(self) -> str:
    if self._unique_name_suffix is None:
      unique_name_suffix = ','.join(str(t) for t in self._bin_values)
    else:
      unique_name_suffix = self._unique_name_suffix
    return f'ContinuousToBins_{self._bin_dim}_{unique_name_suffix}'

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Bins `da`: 1.0 where a value falls into a bin, 0.0 elsewhere."""
    cdf = compute_cdf(
        threshold_values=self._bin_values,
        da=da,
        threshold_dim=self._bin_dim,
        enforce_monotonicity=self._enforce_monotonicity,
    )
    edges = cdf[self._bin_dim].values
    left_edges, right_edges = edges[:-1], edges[1:]
    # The difference of consecutive CDF values is the bin membership:
    # (x <= right) - (x <= left).
    result = cdf.diff(self._bin_dim)
    bin_names = [
        f'{left:.2f} < p <= {right:.2f}'
        for left, right in zip(left_edges, right_edges)
    ]
    return result.assign_coords({
        self._bin_dim: bin_names,
        f'{self._bin_dim}_left': (self._bin_dim, left_edges),
        f'{self._bin_dim}_right': (self._bin_dim, right_edges),
    })
class ContinuousToCDF(InputTransform):
  """Converts a continuous input to a CDF.

  Example: If `threshold_values` are [0.5, 1.0, 1.5], the resulting output
  will have values for `[x <= 0.5, x <= 1.0, x <= 1.5]` if `right_inclusive` is
  `True`, or `[x < 0.5, x < 1.0, x < 1.5]` if `right_inclusive` is `False`.

  The resulting DataArray will have a coordinate `threshold_dim` with the
  threshold values in the case of a 1D array, or the values of the
  `threshold_dim` coordinate in the case of a n-dimensional DataArray or
  Dataset.
  """

  def __init__(
      self,
      which: str,
      threshold_values: Union[Iterable[float], xr.DataArray, xr.Dataset],
      threshold_dim: str,
      unique_name_suffix: str | None = None,
      enforce_monotonicity: bool = True,
      right_inclusive: bool = True,
  ):
    """Initialize the transform.

    Args:
      which: Which input to apply the wrapper to. Must be one of 'predictions',
        'targets', or 'both'.
      threshold_values: Iterable of threshold values, xarray.DataArray or
        xarray.Dataset. Must be monotonically increasing.
      threshold_dim: Name of dimension to use for threshold values.
      unique_name_suffix: Suffix to add to the unique name. If
        `threshold_values` is an xarray.DataArray or xarray.Dataset, this must
        be provided.
      enforce_monotonicity: If True, raise an error if the threshold values are
        not monotonically increasing. It might be necessary to turn this off if
        the threshold values contain NaNs, which would raise an error here.
        Default: True.
      right_inclusive: If True, values equal to a threshold count as below it
        (`x <= threshold`). If False, the comparison is strict
        (`x < threshold`). Default: True.

    Raises:
      ValueError: If `threshold_values` is an xarray object and no
        `unique_name_suffix` was given.
    """
    super().__init__(which)
    if isinstance(threshold_values, (xr.DataArray, xr.Dataset)):
      if unique_name_suffix is None:
        raise ValueError(
            'unique_name_suffix must be provided if threshold_values is an'
            ' xarray.DataArray or xarray.Dataset.'
        )
    elif isinstance(threshold_values, Iterable):
      # The thresholds are read by `unique_name_suffix` and again on every call
      # of `transform_fn`; materialize them so that a one-shot iterable (e.g. a
      # generator) is not exhausted after the first read.
      threshold_values = list(threshold_values)
    self._threshold_values = threshold_values
    self._threshold_dim = threshold_dim
    self._unique_name_suffix = unique_name_suffix
    self._enforce_monotonicity = enforce_monotonicity
    self._right_inclusive = right_inclusive

  @property
  def unique_name_suffix(self) -> str:
    if self._unique_name_suffix is None:
      unique_name_suffix = ','.join(str(t) for t in self._threshold_values)
    else:
      unique_name_suffix = self._unique_name_suffix
    return f'ContinuousToCDF_{self._threshold_dim}_{unique_name_suffix}_right_inclusive_{self._right_inclusive}'

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Evaluates the CDF indicator of `da` at all thresholds."""
    return compute_cdf(
        threshold_values=self._threshold_values,
        da=da,
        threshold_dim=self._threshold_dim,
        enforce_monotonicity=self._enforce_monotonicity,
        right_inclusive=self._right_inclusive,
    )
class WeibullEnsembleToProbabilistic(InputTransform):
  """Turns ensemble forecasts into probabilities via Weibull plotting position.

  The probability is the sum over the ensemble dimension divided by (number of
  members + 1), following Makkonen (2006). The forecasts should already be
  binarized when this wrapper is applied, e.g. by wrapping the metric with
  ContinuousToBinary first.

  Makkonen, L.: Plotting Positions in Extreme Value Analysis, Journal of Applied
  Meteorology and Climatology,
  45, 334–340, https://doi.org/10.1175/JAM2349.1, 2006.
  """

  def __init__(self, which, ensemble_dim='number', skipna=False):
    """Init.

    Args:
      which: Which input to apply the wrapper to. Must be 'predictions'.
      ensemble_dim: Name of ensemble dimension. Default: 'number'.
      skipna: Passed on to the sum over the ensemble dimension. Default: False.
    """
    assert (
        which == 'predictions'
    ), 'Only predictions can be converted to probabilities'
    super().__init__(which)
    self._skipna = skipna
    self._ensemble_dim = ensemble_dim

  @property
  def unique_name_suffix(self) -> str:
    return 'ensemble_to_probabilistic_by_weibull_plotting_position'

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Returns sum over members / (number of members + 1)."""
    member_count = da.sizes[self._ensemble_dim]
    member_sum = da.sum(self._ensemble_dim, skipna=self._skipna)
    return member_sum / (member_count + 1)
class Inline(InputTransform):
  """Input transform defined by a user-provided function.

  Example:
    # Negate values
    x = xr.DataArray(...)
    y = Inline('both', lambda da: -da, 'negate').transform_fn(x)
    xr.testing.assert_equal(y, -x)
  """

  def __init__(
      self,
      which: str,
      transform_fn: Callable[[xr.DataArray], xr.DataArray],
      unique_name_suffix: str,
  ):
    """Initializes an Inline transform.

    Args:
      which: Which input to apply the wrapper to. Must be one of 'predictions',
        'targets', or 'both'.
      transform_fn: Function mapping a DataArray to the transformed DataArray.
      unique_name_suffix: Name to give this transform. Should be different than
        any other transform used. Should uniquely identify this class and all
        init args (except `which`).
    """
    super().__init__(which)
    self._unique_name_suffix = unique_name_suffix
    self._transform_fn = transform_fn

  @property
  def unique_name_suffix(self) -> str:
    # Formatting (rather than returning the attribute) always yields a str.
    return '{}'.format(self._unique_name_suffix)

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Applies the user-provided function to `da`."""
    user_fn = self._transform_fn
    return user_fn(da)
class ReLU(InputTransform):
  """Input transform applying a rectified linear unit, i.e. max(x, 0)."""

  def __init__(  # pylint: disable=useless-parent-delegation
      self,
      which: str,
  ):
    """Initializes a ReLU transform.

    Args:
      which: Which input to apply the wrapper to. Must be one of 'predictions',
        'targets', or 'both'.
    """
    super().__init__(which)

  @property
  def unique_name_suffix(self) -> str:
    return 'relu'

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Replaces non-positive values of `da` by 0 while keeping NaNs."""
    rectified = xr.where(da > 0, da, 0)
    # `da > 0` is False for NaN, which would turn NaNs into 0; restore them.
    is_valid = ~xu.isnan(da)
    return rectified.where(is_valid)
class ShiftAlongNewDim(InputTransform):
  """Transforms by shifting along a new dimension, possibly more than once.

  This transform adds a new dimension `shift_dim` to the input DataArray, and
  shifts (adds to) the values along this dimension. The `shift_value` can be a
  constant, a list of constants, or an xarray.Dataset.

  Examples:
    x = xr.DataArray(...)
    x.sizes
    ==> Frozen({'time': 366, 'latitude': 19, 'longitude': 36})

    # Shift every value by a constant (=0.5), add a singleton dim.
    y = ShiftAlongNewDim(
        'both', shift_value=0.5, shift_dim='threshold',
        unique_name_suffix='shift_0.5',
    ).transform_fn(x)
    y.sizes
    ==> Frozen({'time': 366, 'latitude': 19, 'longitude': 36, 'threshold': 1})
    y.threshold
    ==> <xarray.DataArray 'threshold' (threshold: 1)> Size: 8B
    array([0.5])
    Coordinates:
    * threshold (threshold) float64 8B 0.5

    # Shift by two different values, indexing the shift by dim 'threshold'.
    y = ShiftAlongNewDim(
        'both', shift_value=[10., 20.], shift_dim='threshold',
        unique_name_suffix='shift_10_20',
    ).transform_fn(x)
    y.sizes
    ==> Frozen({'time': 366, 'latitude': 19, 'longitude': 36, 'threshold': 2})
    y.threshold
    ==> <xarray.DataArray 'threshold' (threshold: 2)> Size: 16B
    array([10., 20.])
    Coordinates:
    * threshold (threshold) float64 16B 10.0 20.0

    # Shift by adding a Dataset, in this case the quantiles of another Dataset.
    # The Dataset must hold a variable named like `x` that has the `shift_dim`
    # dimension (here 'quantile').
    quantiles = xr.Dataset(...).quantile([0.2, 0.8])
    y = ShiftAlongNewDim(
        'both', shift_value=quantiles, shift_dim='quantile',
        unique_name_suffix='shift_quantiles',
    ).transform_fn(x)
    y.sizes
    ==> Frozen({'time': 366, 'latitude': 19, 'longitude': 36, 'quantile': 2})
    y['quantile']
    ==> <xarray.DataArray 'quantile' (quantile: 2)> Size: 16B
    array([0.2, 0.8])
    Coordinates:
    * quantile (quantile) float64 16B 0.2 0.8
  """

  def __init__(
      self,
      which: str,
      shift_value: Union[float, Iterable[float], xr.Dataset],
      shift_dim: str,
      unique_name_suffix: str,
  ):
    """Initializes a ShiftAlongNewDim transform.

    Args:
      which: Which input to apply the wrapper to. Must be one of 'predictions',
        'targets', or 'both'.
      shift_value: Constant value, iterable of values or xarray.Dataset.
      shift_dim: Name of dimension to use for the shift. The output will have a
        new dimension, `shift_dim`.
      unique_name_suffix: Name to give this transform. Should be different than
        any other transform used. Should uniquely identify this class and all
        init args (except `which`).
    """
    super().__init__(which)
    if isinstance(shift_value, (xr.Dataset, xr.DataArray)):
      # xarray objects are stored untouched.
      self._shift_value = shift_value
    elif isinstance(shift_value, Iterable):
      # Materialize into a list: the values are re-read on every call of
      # `transform_fn` (which would exhaust a one-shot iterable), and a list
      # (unlike a tuple) is interpreted as 1D data by xarray.
      self._shift_value = list(shift_value)
    else:
      # Single scalar shift.
      self._shift_value = [shift_value]
    self._shift_dim = shift_dim
    self._unique_name_suffix = unique_name_suffix

  @property
  def unique_name_suffix(self) -> str:
    return self._unique_name_suffix

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Adds the shifts to `da`, broadcasting along `shift_dim`.

    Args:
      da: Input DataArray. If the shifts are a Dataset, `da.name` selects the
        variable holding the shifts.

    Returns:
      `da + shifts`.

    Raises:
      RuntimeError: If the shifts are a Dataset and the selected variable does
        not have the `shift_dim` dimension.
    """
    # Convert the shifts to a DataArray in all cases.
    if isinstance(self._shift_value, xr.Dataset):
      shifts = self._shift_value[da.name]
      if self._shift_dim not in shifts.dims:
        raise RuntimeError(
            f'Expected to find {self._shift_dim=} in {shifts.dims=} but did'
            ' not. This is probably an error.'
        )
    else:
      shifts = xr.DataArray(
          self._shift_value,
          dims=[self._shift_dim],
          coords={self._shift_dim: self._shift_value},
      )
    return da + shifts
class Rename(InputTransform):
  """Input transform that renames via the DataArray's `rename` method."""

  def __init__(
      self,
      which: str,
      renames: Mapping[Hashable, Hashable],
  ):
    """Initializes a Rename transform.

    Args:
      which: Which input to apply the wrapper to. Must be one of 'predictions',
        'targets', or 'both'.
      renames: Mapping from the OLD name to the NEW name.
    """
    super().__init__(which)
    self._renames = renames

  @property
  def unique_name_suffix(self) -> str:
    # The mapping itself is rendered into the suffix.
    return f'rename_{self._renames}'

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Returns `da` with the configured renames applied."""
    name_map = self._renames
    return da.rename(name_map)
class Select(InputTransform):
  """Input transform that subsets data with `sel` and/or `isel`."""

  def __init__(
      self,
      which: str,
      sel: Mapping[Hashable, Any] | None = None,
      isel: Mapping[Hashable, Any] | None = None,
      sel_kwargs: Mapping[Hashable, Any] | None = None,
      isel_kwargs: Mapping[Hashable, Any] | None = None,
  ):
    """Initializes a Select transform.

    Args:
      which: Which input to apply the wrapper to. Must be one of 'predictions',
        'targets', or 'both'.
      sel: Indexers passed to `sel`. If None, `sel` is not called.
      isel: Indexers passed to `isel`. If None, `isel` is not called.
      sel_kwargs: Additional keyword arguments for the `sel` call.
      isel_kwargs: Additional keyword arguments for the `isel` call.
    """
    super().__init__(which)
    self._sel = sel
    self._sel_kwargs = sel_kwargs if sel_kwargs else {}
    self._isel = isel
    self._isel_kwargs = isel_kwargs if isel_kwargs else {}

  @property
  def unique_name_suffix(self) -> str:
    return (
        f'select_{self._isel=}_{self._isel_kwargs=}_'
        f'{self._sel=}_{self._sel_kwargs=}'
    )

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Applies `sel` first, then `isel`, to a copy of `da`."""
    selected = da.copy()
    if self._sel is not None:
      selected = selected.sel(self._sel, **self._sel_kwargs)
    if self._isel is not None:
      selected = selected.isel(self._isel, **self._isel_kwargs)
    return selected
class StackToNewDimension(InputTransform):
  """Stacks any number of existing dimensions into a "range" new dimension.

  Similar to xr.stack, except StackToNewDimension...
  * allows the new dimension name to be the same as one of the stacked dims.
  * results in a dimension with values [0, 1, ...,]
  """

  def __init__(
      self,
      which: str,
      dims_to_stack: Sequence[Hashable],
      new_dim_name: Hashable,
  ):
    """Initializes a StackToNewDimension Transform.

    Args:
      which: Which input to apply the wrapper to. Must be one of 'predictions',
        'targets', or 'both'.
      dims_to_stack: Sequence of dimensions to stack.
      new_dim_name: Name of stacked dimension.
    """
    super().__init__(which)
    self._dims_to_stack = dims_to_stack
    self._new_dim_name = new_dim_name
    # Stacking happens under an intermediate name first, so that the final
    # name may coincide with one of the stacked dimensions.
    joined_dims = '_AND_'.join(map(str, self._dims_to_stack))
    self._temporary_new_dim_name = 'STACK_OF_' + joined_dims

  @property
  def unique_name_suffix(self) -> str:
    return f'stack_{self._dims_to_stack}_to_{self._new_dim_name}'

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Stacks the configured dimensions of `da` into the new dimension."""
    temp_name = self._temporary_new_dim_name
    stacked = da.stack({temp_name: self._dims_to_stack})
    # Drop the coordinates of the original dimensions, then move the stacked
    # dimension to its final name.
    without_old_coords = stacked.drop_vars(self._dims_to_stack)
    return without_old_coords.rename({temp_name: self._new_dim_name})
def construct_tiles(
    da: xr.DataArray,
    window_size: int = 3,
    window_dim: str = 'window',
    wrap_longitude: bool = False,
) -> xr.DataArray:
  """Constructs local neighborhood tiles (windows) across the spatial grid.

  At each latitude/longitude pixel, a `window_size` x `window_size` patch of
  adjacent pixels is collected along a new `window_dim` dimension by rolling the
  data. To prevent rolling from wrapping around the poles, pixels where a
  full spatial window cannot be formed without wrapping latitude boundaries are
  removed. If `wrap_longitude` is False, pixels near the longitude edges are
  also removed.

  For even `window_size` the window is not symmetric: it reaches
  `window_size // 2` pixels towards higher indices and
  `window_size - 1 - window_size // 2` pixels towards lower indices.

  Args:
    da: The input DataArray containing `latitude` and `longitude` dimensions.
      Note that wrappers apply this function per DataArray, not per Dataset.
    window_size: The length of the side of the square tile window (in pixels).
      Must be at least 1.
    window_dim: The name of the new dimension along which the window pixels are
      concatenated.
    wrap_longitude: If True, assume longitude wraps around (e.g. global grids).
      If False (default), also clip values near longitude edges (e.g. for
      regional datasets).

  Returns:
    An `xr.DataArray` with the additional `window_dim` dimension representing
    the local spatial patch for each point on the grid.

  Raises:
    ValueError: If `window_size` is smaller than 1.
  """
  if window_size < 1:
    raise ValueError(f'window_size must be at least 1, got {window_size}.')

  half_window_size = window_size // 2
  # Roll offsets from -half_window_size to window_size - 1 - half_window_size.
  offsets = range(-half_window_size, window_size - half_window_size)
  shifts = [
      da.roll(latitude=d_lat, longitude=d_lon, roll_coords=False)
      for d_lat in offsets
      for d_lon in offsets
  ]
  windowed = xr.concat(shifts, dim=window_dim)

  # Rolling by d places the input value of index k - d at index k. Hence the
  # largest positive offset determines how far a window reaches towards lower
  # indices, and the most negative offset how far it reaches towards higher
  # indices. (Both are equal for odd window sizes.)
  reach_towards_lower = window_size - 1 - half_window_size
  reach_towards_upper = half_window_size

  # Remove any window that was rolled across the latitude edges.
  windowed = windowed.isel(
      latitude=slice(
          reach_towards_lower, da.sizes['latitude'] - reach_towards_upper
      )
  )
  if not wrap_longitude:
    windowed = windowed.isel(
        longitude=slice(
            reach_towards_lower, da.sizes['longitude'] - reach_towards_upper
        )
    )
  return windowed
class Tile(InputTransform):
  """Input transform building local neighborhood tiles on the spatial grid.

  After the transformation, each spatial location contains the data from a
  rolling window of size window_size centered at that location.

  A new dimension named `window_dim` is added to the DataArray. Its size
  follows from window_size, e.g. a window_size of 3 gives a new dimension of
  size 3x3=9.

  The edges of the forecast, where windows would be incomplete, are cut off. A
  window size of 3 removes one row/column on the edges of each spatial
  dimension (latitude, longitude). If wrap_longitude is True, the longitude
  dimension is not shrunk and the values are rolled over from the other edge
  instead.
  """

  def __init__(
      self,
      which: str,
      window_size: int = 3,
      window_dim: str = 'window',
      wrap_longitude: bool = False,
  ):
    """Init.

    Args:
      which: Which input to apply the wrapper to. Must be one of 'predictions',
        'targets', or 'both'.
      window_size: The length of the side of the square tile window (in pixels).
      window_dim: The name of the new dimension along which the window pixels
        are concatenated.
      wrap_longitude: If True, assume longitude wraps around. Default: False.
    """
    super().__init__(which)
    self._wrap_longitude = wrap_longitude
    self._window_dim = window_dim
    self._window_size = window_size

  @property
  def unique_name_suffix(self) -> str:
    size_and_wrap = (
        f'tiled_window_size_{self._window_size}_wrap_{self._wrap_longitude}_'
    )
    return size_and_wrap + f'dim_{self._window_dim}'

  def transform_fn(self, da: xr.DataArray) -> xr.DataArray:
    """Returns the tiled version of `da` (see `construct_tiles`)."""
    tile_options = dict(
        window_size=self._window_size,
        window_dim=self._window_dim,
        wrap_longitude=self._wrap_longitude,
    )
    return construct_tiles(da, **tile_options)
[docs]
class WrappedStatistic(base.Statistic):
  """Statistic whose inputs are passed through an input transform first.

  The unique name of the wrapped statistic is extended by the transform's
  `which` and `unique_name_suffix`.

  TODO(srasp): Add option for multiple transforms.
  """

  def __init__(self, statistic: base.Statistic, transform: InputTransform):
    """Init.

    Args:
      statistic: Statistic object to wrap.
      transform: Transform to apply to inputs.
    """
    self.transform = transform
    self.statistic = statistic

  @property
  def unique_name(self) -> str:
    return (
        f'{self.statistic.unique_name}_{self.transform.which}_'
        f'{self.transform.unique_name_suffix}'
    )

  def compute(
      self,
      predictions: Mapping[Hashable, xr.DataArray],
      targets: Mapping[Hashable, xr.DataArray],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Transforms the selected inputs, then computes the wrapped statistic."""
    selected = self.transform.which
    if selected in ('predictions', 'both'):
      predictions = xarray_tree.map_structure(
          self.transform.transform_fn, predictions
      )
    if selected in ('targets', 'both'):
      targets = xarray_tree.map_structure(self.transform.transform_fn, targets)
    return self.statistic.compute(predictions, targets)
class RenamedStatistic(base.Statistic):
  """Statistic that behaves like the wrapped one under another unique name."""

  def __init__(self, statistic: base.Statistic, unique_name: str):
    """Init.

    Args:
      statistic: Statistic object to wrap.
      unique_name: Unique name reported instead of the wrapped statistic's.
    """
    self._unique_name = unique_name
    self._statistic = statistic

  @property
  def unique_name(self) -> str:
    return self._unique_name

  def compute(
      self,
      predictions: Mapping[Hashable, xr.DataArray],
      targets: Mapping[Hashable, xr.DataArray],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Delegates to the wrapped statistic without modification."""
    wrapped = self._statistic
    return wrapped.compute(predictions, targets)
[docs]
class WrappedMetric(base.Metric):
  """Metric whose statistics all receive transformed inputs."""

  def __init__(
      self,
      metric: base.Metric,
      transforms: list[InputTransform],
      unique_name_suffix: str | None = None,
  ):
    """Init.

    Args:
      metric: Metric to wrap.
      transforms: List of input transforms to apply. The transforms will be
        applied in the order they are added to the list. I.e. transforms [f, g,
        h], transform x as h(g(f(x))).
      unique_name_suffix: Optional suffix to use for uniquely naming all
        associated statistics. By default, this is constructed automatically
        from the transforms, which may be overly verbose.
    """
    self.unique_name_suffix = unique_name_suffix
    self.transforms = transforms
    self.metric = metric

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    wrapped_stats = {}
    for key, inner_stat in self.metric.statistics.items():
      inner_name = inner_stat.unique_name
      wrapped = inner_stat
      # The outermost wrapper runs first at compute time. Wrapping in reverse
      # order therefore makes the first transform of the list the first one
      # applied to the data: for stat = W(V(stat)), x is processed as V(W(x)).
      for transform in reversed(self.transforms):
        wrapped = WrappedStatistic(wrapped, transform)
      if self.unique_name_suffix is not None:
        # Replace the auto-generated name by the short user-provided one.
        wrapped = RenamedStatistic(
            wrapped, f'{inner_name}_{self.unique_name_suffix}'
        )
      wrapped_stats[key] = wrapped
    return wrapped_stats

  def values_from_mean_statistics(
      self,
      statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Delegates to the wrapped metric."""
    inner_metric = self.metric
    return inner_metric.values_from_mean_statistics(statistic_values)
class SubselectVariablesForStatistic(base.Statistic):
  """Statistic that is only computed for a subset of the variables."""

  def __init__(self, statistic: base.Statistic, variables: Sequence[str]):
    """Init.

    Args:
      statistic: Statistic object to wrap.
      variables: Variables to compute the statistic for.
    """
    self.variables = variables
    self.statistic = statistic

  @property
  def unique_name(self) -> str:
    # The variable names are appended so the name differs from a non-subsetted
    # statistic of the same kind.
    suffix = '_'.join(self.variables)
    return f'{self.statistic.unique_name}_{suffix}'

  def compute(
      self,
      predictions: Mapping[Hashable, xr.DataArray],
      targets: Mapping[Hashable, xr.DataArray],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Drops all non-selected variables, then computes the statistic."""
    kept_predictions = {}
    for name, values in predictions.items():
      if name in self.variables:
        kept_predictions[name] = values
    kept_targets = {}
    for name, values in targets.items():
      if name in self.variables:
        kept_targets[name] = values
    return self.statistic.compute(kept_predictions, kept_targets)
class SubselectVariables(base.Metric):
  """Metric that is only computed for a subset of the variables."""

  def __init__(self, metric: base.Metric, variables: Sequence[str]):
    """Init.

    Args:
      metric: Metric to wrap.
      variables: Variables to compute the metric for.
    """
    self.variables = variables
    self.metric = metric

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    # Every statistic of the wrapped metric gets restricted to the variables.
    return {
        name: SubselectVariablesForStatistic(stat, self.variables)
        for name, stat in self.metric.statistics.items()
    }

  def values_from_mean_statistics(
      self,
      statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Delegates to the wrapped metric."""
    inner_metric = self.metric
    return inner_metric.values_from_mean_statistics(statistic_values)
# These wrappers are no longer needed, since PerVariableStatistic and
# PerVariableMetric already do this. We maintain aliases for backwards
# compatibility, but they are deprecated and will be removed in the future.
def IntersectPredictionAndTargetVariablesForStatistic(statistic):  # pylint: disable=invalid-name
  """Deprecated no-op: returns `statistic` unchanged."""
  return statistic


def IntersectPredictionAndTargetVariables(metric):  # pylint: disable=invalid-name
  """Deprecated no-op: returns `metric` unchanged."""
  return metric