# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base metrics class."""
import abc
from collections.abc import Iterator, Mapping
from typing import Hashable, final
from weatherbenchX import xarray_tree
import xarray as xr
class Metric(abc.ABC):
  """Interface for a metric computed from aggregated statistics.

  A `Metric` is specified by two things:

  * One or more `Statistic`s, i.e. functions of a prediction/target pair.
    They are returned by the `statistics` property. Statistics are
    implemented in their own classes, separately from the `Metric`, so that
    several `Metric`s can share them.
  * A function that turns the (weighted) *means* of those statistics, taken
    in aggregate over many prediction/target pairs, into the final metric
    value. It is provided by implementing `values_from_mean_statistics`.

  For example, the `RMSE` metric requests the `SquaredError` statistic (the
  squared error of each prediction/target pair) and returns the square root
  of that statistic's mean.

  The Metric does not decide how the statistics are averaged. This is chosen
  independently, for example to get different kinds of disaggregation and
  weighting; see `aggregation.Aggregator`.

  Metrics that are computed for each variable independently should subclass
  `PerVariableMetric`.
  """

  @property
  @abc.abstractmethod
  def statistics(self) -> Mapping[str, 'Statistic']:
    """Statistics whose mean values this metric needs.

    The keys are internal names, local to this Metric instance. The mean
    values of the requested statistics are handed back to
    `values_from_mean_statistics` under these same keys. The names only need
    to be unique within this Metric: outside of it, each Statistic is
    identified by its `.unique_name`, so the internal names can be chosen for
    convenience.
    """
    ...

  @abc.abstractmethod
  def values_from_mean_statistics(
      self,
      statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Turns mean statistic values, keyed by internal name, into metric values.

    Args:
      statistic_values: Maps each internal statistic name (a key of
        `self.statistics`) to that statistic's mean value, which is itself a
        mapping from variable name to DataArray.

    Returns:
      A mapping from variable name to a DataArray of metric values. As with
      statistics, these names need not match variables in the original
      predictions and targets, although they should where that makes sense.
      Metrics defined on a per-variable basis should consider subclassing
      `PerVariableMetric`.
    """
    ...
class Statistic(Metric):
  """Interface for statistics of a predictions/targets chunk.

  A Statistic is a function of one chunk of predictions and the matching
  chunk of targets. Its values are meant to be averaged (possibly with
  weights) over many such chunks, and the resulting means are then used to
  compute a Metric.

  There are two ways to use a Statistic:

  * Directly as a Metric. Statistic implements the Metric interface itself,
    passing the mean of the statistic's values straight through.
  * Wrapped, alone or together with other statistics, in a Metric that does
    further computation on the means (via `values_from_mean_statistics`).

  Predictions/targets chunks may be given either as a dictionary of
  DataArrays or as a Dataset.

  Univariate statistics should subclass `PerVariableStatistic`. Multivariate
  statistics see all variables at once. Their output must still be a Mapping
  from str to xr.DataArray, i.e. every output DataArray has to be named.

  Every Statistic must provide its own `unique_name`. It is used to compute a
  statistic only once when several metrics need it, so any parameter that
  changes the result of the computation must be reflected in `unique_name`.

  A Statistic must keep the dimensions that (a) binnings or weights are
  computed on and (b) the (weighted) mean is taken over. Typically these are
  the time dimensions (when chunking in time) and/or the spatial/observation
  dimensions (when binning or weighting needs them). Any other dimension may
  be reduced.
  """

  @property
  def unique_name(self) -> str:
    """Name identifying this statistic across all statistics being computed.

    When several metrics are computed together, statistics that share a
    `unique_name` are treated as duplicates and computed only once. The name
    must therefore be unique among all statistics likely to be used together.

    The default is the class name. Override it whenever the statistic has
    extra parameters that change the result of the computation.
    """
    cls = type(self)
    return cls.__name__

  @abc.abstractmethod
  def compute(
      self,
      predictions: Mapping[Hashable, xr.DataArray],
      targets: Mapping[Hashable, xr.DataArray],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Computes the statistic for a single predictions/targets chunk.

    Args:
      predictions: Mapping from variable name to DataArray (a dict of
        DataArrays or an xarray Dataset).
      targets: Mapping from variable name to DataArray, in the same form as
        `predictions`.

    Returns:
      The statistic, as a mapping from variable name to DataArray.

      If the values correspond to particular variables of the predictions
      and targets, use those same keys. If the statistic can be computed one
      variable at a time in a generic way, consider subclassing
      `PerVariableStatistic` instead.

      If the values do not correspond to specific variables, choose new
      variable name(s) for them.
    """
    ...

  # A Statistic is trivially a Metric whose value is the mean of the
  # statistic itself: it requests only itself, under the internal name
  # 'self', and passes that mean straight through.
  @final
  @property
  def statistics(self) -> Mapping[str, 'Statistic']:
    return {'self': self}

  @final
  def values_from_mean_statistics(
      self,
      statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
  ) -> Mapping[Hashable, xr.DataArray]:
    own_means = statistic_values['self']
    return own_means
class PerVariableStatistic(Statistic):
  """Base class for statistics that are computed one variable at a time.

  `compute` applies `_compute_per_variable` to every variable present in both
  the predictions and the targets. Variables found in only one of them are
  skipped, as are variables for which `_compute_per_variable` returns None.
  """

  @final
  def compute(
      self,
      predictions: Mapping[Hashable, xr.DataArray],
      targets: Mapping[Hashable, xr.DataArray],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Applies `_compute_per_variable` to each variable shared by both inputs."""
    per_variable = {}
    for name in predictions.keys():
      if name not in targets.keys():
        # No target for this variable, so there is nothing to compare to.
        continue
      value = self._compute_per_variable(predictions[name], targets[name])
      if value is None:
        # The statistic is undefined for this variable.
        continue
      per_variable[name] = value
    return per_variable

  @abc.abstractmethod
  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray | None:
    """Returns the statistic for one variable, or None if it's undefined."""
    ...
class PerVariableMetric(Metric):
  """Abstract base class for metrics that are computed per variable.

  Subclasses implement `_values_from_mean_statistics_per_variable`, which
  receives the mean statistics of one variable at a time. The metric is
  produced for every variable that appears in the mean values of *all* of the
  metric's statistics. Variables missing from any statistic are dropped.
  """

  @final
  def values_from_mean_statistics(
      self,
      statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Computes the metric for every variable common to all statistics.

    Args:
      statistic_values: Mapping from internal statistic name (a key of
        `self.statistics`) to that statistic's mean values, keyed by variable.

    Returns:
      Mapping from variable name to metric value, ordered like the variables
      of the first statistic. Empty if the metric requests no statistics.
    """
    # Read the property once rather than on every loop iteration;
    # implementations may build a new mapping on each access.
    stat_names = list(self.statistics)
    if not stat_names:
      # No statistics means no variables to report on. Previously this
      # raised an opaque TypeError from set.intersection().
      return {}
    first, *rest = stat_names
    # Membership sets for the remaining statistics, built once. Using set()
    # matches the key semantics of the previous intersection-based version.
    other_variables = [set(statistic_values[s]) for s in rest]
    values = {}
    # Iterate the first statistic's variables in their own order. Iterating a
    # set of names would make the output order depend on hash randomization,
    # so it could differ between interpreter runs.
    for v in statistic_values[first]:
      if all(v in names for names in other_variables):
        stats_per_variable = {s: statistic_values[s][v] for s in stat_names}
        values[v] = self._values_from_mean_statistics_per_variable(
            stats_per_variable
        )
    return values

  @abc.abstractmethod
  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Compute metric values for a single variable.

    Args:
      statistic_values: Mapping from your internal statistic names (the keys
        of your self.statistics) to the mean values of the statistics for a
        single specific variable.

    Returns:
      The value of the metric for the given variable.
    """
# Deprecated backwards-compatibility shim. NoOpMetric used to wrap a Statistic
# as a Metric. This is no longer necessary because Statistic implements the
# Metric interface directly, so the shim just returns its argument.
def NoOpMetric(statistic: 'Statistic') -> 'Statistic':  # pylint: disable=invalid-name
  """Deprecated: returns `statistic` unchanged.

  Kept so that existing `NoOpMetric(stat)` call sites keep working. New code
  should use the Statistic directly wherever a Metric is expected.

  Args:
    statistic: The statistic that would previously have been wrapped.

  Returns:
    The same `statistic` object.
  """
  return statistic
def generate_unique_statistics_for_all_metrics(
    metrics: Mapping[str, Metric],
    predictions: Mapping[Hashable, xr.DataArray],
    targets: Mapping[Hashable, xr.DataArray],
) -> Iterator[tuple[str, Mapping[Hashable, xr.DataArray]]]:
  """Lazily computes each unique statistic required by `metrics`.

  Like `compute_unique_statistics_for_all_metrics`, but yields
  `(unique_name, values)` pairs one at a time instead of building a dict.

  Statistics are deduplicated by `.unique_name`. If several metrics request
  statistics with the same unique name, the statistic is computed only once,
  using the last such instance encountered.

  Args:
    metrics: Mapping of metric name to Metric.
    predictions: Xarray Dataset or dictionary of DataArrays.
    targets: Xarray Dataset or dictionary of DataArrays.

  Yields:
    `(unique_name, statistic_values)` pairs, where `statistic_values` is the
    value returned by that statistic's `compute`.

  Raises:
    ValueError: If computing a statistic fails. The original exception is
      chained as the cause.
  """
  unique_statistics = {}
  for metric in metrics.values():
    # Only the Statistic objects are needed; the metric-internal names are
    # irrelevant for deduplication.
    for stat in metric.statistics.values():
      unique_statistics[stat.unique_name] = stat
  for k, stat in unique_statistics.items():
    # Wrap only the computation. Previously the `yield` was inside the `try`,
    # so an exception thrown into this generator by the consumer would have
    # been misreported as a failed statistic computation.
    try:
      values = stat.compute(predictions, targets)
    except Exception as e:
      raise ValueError(
          'Failed to compute statistic'
          f' {k}={stat} from:\n{predictions=}\n{targets=}'
      ) from e
    yield k, values
def compute_unique_statistics_for_all_metrics(
    metrics: Mapping[str, Metric],
    predictions: Mapping[Hashable, xr.DataArray],
    targets: Mapping[Hashable, xr.DataArray],
) -> Mapping[str, Mapping[Hashable, xr.DataArray]]:
  """Computes every unique statistic needed by `metrics`.

  Args:
    metrics: Dictionary of metric instances.
    predictions: Xarray Dataset or dictionary of DataArrays.
    targets: Xarray Dataset or dictionary of DataArrays.

  Returns:
    A dict from each statistic's `.unique_name` to the value returned by that
    statistic's `compute` for these predictions/targets, i.e. a mapping from
    variable name to statistic DataArray. Statistics shared by several
    metrics (same `.unique_name`) are computed only once.

  Raises:
    ValueError: If any statistic fails to compute.
  """
  # The generator already yields (unique_name, values) pairs, so dict() can
  # consume it directly.
  return dict(
      generate_unique_statistics_for_all_metrics(metrics, predictions, targets)
  )
def compute_metric_from_statistics(
    metric: Metric,
    statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
) -> Mapping[Hashable, xr.DataArray]:
  """Evaluates `metric` from mean statistics keyed by `.unique_name`.

  `metric.values_from_mean_statistics` expects its statistics under the
  metric's own internal names (the keys of `metric.statistics`). The values
  are therefore re-keyed from unique names to internal names first.

  Args:
    metric: A Metric.
    statistic_values: Mean statistic values keyed by `.unique_name`, e.g. as
      returned by `compute_unique_statistics_for_all_metrics`.

  Returns:
    The values of `metric`.

  Raises:
    KeyError: If a statistic required by `metric` is missing from
      `statistic_values`.
  """
  by_internal_name = {}
  for internal_name, stat in metric.statistics.items():
    by_internal_name[internal_name] = statistic_values[stat.unique_name]
  return metric.values_from_mean_statistics(by_internal_name)
def compute_metrics_from_statistics(
    metrics: Mapping[str, Metric],
    statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
) -> Mapping[str, Mapping[Hashable, xr.DataArray]]:
  """Evaluates several metrics from mean statistics keyed by `.unique_name`.

  Args:
    metrics: Mapping from metric name to `Metric`.
    statistic_values: Mean statistic values keyed by `.unique_name`, e.g. as
      returned by `compute_unique_statistics_for_all_metrics`.

  Returns:
    A dict from each metric name to that metric's values, computed with
    `compute_metric_from_statistics`.
  """
  results = {}
  for name, metric in metrics.items():
    results[name] = compute_metric_from_statistics(metric, statistic_values)
  return results
class PerVariableStatisticWithClimatology(Statistic):
  """Base class for per-variable statistics that also need a climatology.

  Use this for statistics that are a function of the prediction/target and
  of the climatology. For each variable, the climatology is aligned with the
  prediction/target using the prediction's valid time. It is then passed,
  with the predictions and targets, to
  `_compute_per_variable_with_aligned_climatology`, which subclasses must
  implement.
  """

  def __init__(self, climatology: xr.Dataset):
    """Init.

    Args:
      climatology: The climatology dataset. `compute` selects from it every
        variable present in the predictions, so those variables must exist
        here.
    """
    self._climatology = climatology

  @final
  def compute(
      self,
      predictions: Mapping[Hashable, xr.DataArray],
      targets: Mapping[Hashable, xr.DataArray],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Computes the statistic for each variable of the predictions.

    Args:
      predictions: Dict of DataArrays or Dataset.
      targets: Dict of DataArrays or Dataset.

    Returns:
      The result of mapping `_compute_per_variable` over the predictions,
      targets and matching climatology variables with
      `xarray_tree.map_structure`.
    """
    # Convert both inputs to plain dicts: mask coordinates can get lost if
    # xarray_tree combines the variables into a Dataset.
    pred_dict = dict(predictions)
    target_dict = dict(targets)
    wanted_variables = list(pred_dict.keys())
    clim_dict = dict(self._climatology[wanted_variables])
    return xarray_tree.map_structure(
        self._compute_per_variable, pred_dict, target_dict, clim_dict
    )

  @final
  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
      climatology: xr.DataArray,
  ) -> xr.DataArray:
    """Aligns `climatology` to the predictions' valid times, then delegates."""
    # Work out the valid time: either it is given directly, or it is
    # initialization time plus lead time.
    if hasattr(predictions, 'valid_time'):
      valid_time = predictions.valid_time
    elif not (
        hasattr(predictions, 'init_time') and hasattr(predictions, 'lead_time')
    ):
      raise ValueError(
          'Predictions should have either valid_time or init/lead_time'
          ' dimensions.'
      )
    else:
      valid_time = predictions.init_time + predictions.lead_time

    # The climatology is indexed by time, by dayofyear, or by dayofyear and
    # hour. Build the matching selection.
    if hasattr(climatology, 'time'):
      indexers = {'time': valid_time}
    else:
      indexers = {'dayofyear': valid_time.dt.dayofyear}
      if hasattr(climatology, 'hour'):
        indexers['hour'] = valid_time.dt.hour
    aligned = climatology.sel(indexers).compute()

    return self._compute_per_variable_with_aligned_climatology(
        predictions, targets, aligned
    )

  @abc.abstractmethod
  def _compute_per_variable_with_aligned_climatology(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
      aligned_climatology: xr.DataArray,
  ) -> xr.DataArray:
    """Computes the statistic for a single variable.

    Args:
      predictions: Predictions for one variable.
      targets: Targets for the same variable.
      aligned_climatology: That variable's climatology, already selected at
        the predictions' valid times and materialized with `.compute()`.

    Returns:
      The statistic for this variable.
    """
    ...