# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base metrics class."""
import abc
from typing import Hashable, Mapping
from weatherbenchX import xarray_tree
import xarray as xr
[docs]
class Statistic(abc.ABC):
"""Abstract base class for statistics.
Statistics are computed for a pair of predictions/targets chunks. The
resulting statistics chunks will then be averaged (potentially weighted)
across chunks.
The incoming predictions/targets chunks can either be a dictionary of
DataArrays or a Dataset.
For univariate metrics, a PerVariableStatistic should be implemented.
Multivariate metrics have access to all variables. The output should also be a
Mapping from str to xr.DataArray. In other words, the DataArray has to be
named.
Statistics are required to assign their own unique name. In the case of
additional parameters, these should be in self.unique_name.
Statistics should preserve dimensions that are a) required to compute binnings
or weights on and b) over which the (weighted) mean is computed. These will
typically be the time dimensions (if chunking is done in time) and/or the
spatial/observation dimensions (if these are needed for binning or weighting).
Other dimensions can be reduced.
Typically, one or more statistics are assiciated with a metric which then
uses the averaged statistic(s) to compute the final metric values.
"""
@property
def unique_name(self) -> str:
"""Name of the statistic.
Defaults to class name. Remember to change to a unique identifier in case
statistic has additional parameters.
"""
return type(self).__name__
@abc.abstractmethod
def compute(
self,
predictions: Mapping[Hashable, xr.DataArray],
targets: Mapping[Hashable, xr.DataArray],
) -> Mapping[Hashable, xr.DataArray]:
"""Computes statistics per predictions/targets chunk.
Args:
predictions: Xarray Dataset or DataArray.
targets: Xarray Dataset or DataArray.
Returns:
statistic: Corresponding statistic
"""
[docs]
class PerVariableStatistic(Statistic):
"""Abstract base class for statistics that are computed per variable."""
def compute(
self,
predictions: Mapping[Hashable, xr.DataArray],
targets: Mapping[Hashable, xr.DataArray],
) -> Mapping[Hashable, xr.DataArray]:
"""Maps computation over all variables."""
# Ensure both inputs are dictionaries.
# This is because sometimes mask coordinates can get lost if xarray_tree
# combines variables into a Dataset.
predictions = dict(predictions)
targets = dict(targets)
return xarray_tree.map_structure(
self._compute_per_variable, predictions, targets
)
@abc.abstractmethod
def _compute_per_variable(
self,
predictions: xr.DataArray,
targets: xr.DataArray,
) -> xr.DataArray:
"""Computes statistics per variable."""
[docs]
class Metric(abc.ABC):
"""Abstract base class for metrics.
Metrics define one or more statistics to be computed. Their names can be
chosen freely inside the metric. Before the computation of the metrics from
the aggregated statistics, the unique statistic names will be renamed to the
internal names. Metrics computed for each variable independently should
be implemented as PerVariableMetric classes.
"""
@property
@abc.abstractmethod
def statistics(self) -> Mapping[str, Statistic]:
"""Dictionary of required statistics."""
def values_from_mean_statistics(
self,
statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
) -> Mapping[Hashable, xr.DataArray]:
"""Computes metrics from averaged statistics."""
# Rename statistics from unique to internal names.
statistic_values = {
k: statistic_values[v.unique_name] for k, v in self.statistics.items()
}
return self._values_from_mean_statistics_with_internal_names(
statistic_values
)
@abc.abstractmethod
def _values_from_mean_statistics_with_internal_names(
self,
statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
) -> Mapping[Hashable, xr.DataArray]:
"""Computes metric values from statistics after renaming to internal names."""
[docs]
class PerVariableMetric(Metric):
"""Abstract base class for metrics that are computed per variable."""
def _values_from_mean_statistics_with_internal_names(
self,
statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
) -> Mapping[Hashable, xr.DataArray]:
# Get list of common variables present for all statistics.
common_variables = set.intersection(
*[set(statistic_values[s]) for s in self.statistics]
)
values = {}
# Compute values for all common variables.
for v in common_variables:
stats_per_variable = {s: statistic_values[s][v] for s in self.statistics}
values[v] = self._values_from_mean_statistics_per_variable(
stats_per_variable
)
return values
@abc.abstractmethod
def _values_from_mean_statistics_per_variable(
self,
statistic_values: Mapping[Hashable, xr.DataArray],
) -> xr.DataArray:
"""Compute metric values for a single variable."""
class NoOpMetric(PerVariableMetric):
"""General metric wrapper that simply returns the mean statistics."""
def __init__(self, statistic: Statistic):
"""Init.
Args:
statistic: Statistic to be wrapped.
"""
self._statistic = statistic
@property
def statistics(self) -> Mapping[str, Statistic]:
return {'statistic': self._statistic}
def _values_from_mean_statistics_per_variable(
self,
statistic_values: Mapping[Hashable, xr.DataArray],
) -> xr.DataArray:
"""Computes metrics from aggregated statistics."""
return statistic_values['statistic']
def compute_unique_statistics_for_all_metrics(
metrics: Mapping[str, Metric],
predictions: Mapping[Hashable, xr.DataArray],
targets: Mapping[Hashable, xr.DataArray],
) -> Mapping[str, Mapping[Hashable, xr.DataArray]]:
"""Computes unique statistics for all metrics.
Args:
metrics: Dictionary of metrics instances.
predictions: Xarray Dataset or dictionary of DataArrays.
targets: Xarray Dataset or dictionary of DataArrays.
Returns:
statistic_values: Unique statistics computed for each input element. If
inputs are Datasets, returns a dict of statistic_name to statistic
Dataset; if inputs are dictionaries, returns a nested dictionary of
statistic_name to variable to statistic DataArrays.
"""
unique_statistics = {}
for m in metrics.values():
for _, stat in m.statistics.items():
unique_statistics[stat.unique_name] = stat
statistic_values = {
k: stat.compute(predictions, targets)
for k, stat in unique_statistics.items()
}
return statistic_values
[docs]
class PerVariableStatisticWithClimatology(Statistic):
"""Base class for per-variable statistics with climatology.
This class provides a convenient way to compute statistics that are a function
of both the prediction/target and the climatology. The climatology is aligned
with the prediction/target based on the prediction's valid_time.
Subclasses must implement the `_compute_per_variable_with_aligned_climatology`
method, which takes the predictions, targets, and aligned climatology as
arguments.
"""
def __init__(self, climatology: xr.Dataset):
"""Init.
Args:
climatology: The climatology dataset.
"""
self._climatology = climatology
def compute(
self,
predictions: Mapping[Hashable, xr.DataArray],
targets: Mapping[Hashable, xr.DataArray],
) -> Mapping[Hashable, xr.DataArray]:
# Ensure both inputs are dictionaries.
# This is because sometimes mask coordinates can get lost if xarray_tree
# combines variables into a Dataset.
predictions = dict(predictions)
targets = dict(targets)
climatology = dict(self._climatology[list(predictions.keys())])
return xarray_tree.map_structure(
self._compute_per_variable, predictions, targets, climatology
)
def _compute_per_variable(
self,
predictions: xr.DataArray,
targets: xr.DataArray,
climatology: xr.DataArray,
) -> xr.DataArray:
"""Compute statistics per variable."""
# Predictions/targets can either have a single time dimension: valid_time
if hasattr(predictions, 'valid_time'):
valid_time = predictions.valid_time
# Or init and lead time dimensions.
elif hasattr(predictions, 'init_time') and hasattr(
predictions, 'lead_time'
):
valid_time = predictions.init_time + predictions.lead_time
else:
raise ValueError(
'Predictions should have either valid_time or init/lead_time'
' dimensions.'
)
# Climatology either has dayofyear or dayofyear/hour dimensions
sel_kwargs = {'dayofyear': valid_time.dt.dayofyear}
if hasattr(climatology, 'hour'):
sel_kwargs['hour'] = valid_time.dt.hour
aligned_climatology = climatology.sel(**sel_kwargs).compute()
return self._compute_per_variable_with_aligned_climatology(
predictions, targets, aligned_climatology
)
@abc.abstractmethod
def _compute_per_variable_with_aligned_climatology(
self,
predictions: xr.DataArray,
targets: xr.DataArray,
aligned_climatology: xr.DataArray,
) -> xr.DataArray:
"""Computes statistics per variable."""