Source code for weatherbenchX.metrics.deterministic

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of deterministic metrics and assiciated statistics."""

from typing import Hashable, Mapping, Sequence, Union
import numpy as np
from weatherbenchX import xarray_tree
from weatherbenchX.metrics import base
import xarray as xr


### Statistics


class Error(base.PerVariableStatistic):
  """Signed difference between predictions and targets."""

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Returns `predictions - targets` for a single variable."""
    return predictions - targets
class AbsoluteError(base.PerVariableStatistic):
  """Magnitude of the difference between predictions and targets."""

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Returns `abs(predictions - targets)` for a single variable."""
    return abs(predictions - targets)
class SquaredError(base.PerVariableStatistic):
  """Square of the difference between predictions and targets."""

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Returns `(predictions - targets) ** 2` for a single variable."""
    difference = predictions - targets
    return difference**2
class PredictionPassthrough(base.PerVariableStatistic):
  """Passes the predictions through as the statistic value."""

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Returns the predictions combined with a zero array shaped like targets."""
    # Adding zeros derived from the targets keeps the prediction values while
    # making sure potential coordinates from targets are preserved.
    target_shaped_zeros = xr.zeros_like(targets)
    return predictions + target_shaped_zeros
class TargetPassthrough(base.PerVariableStatistic):
  """Passes the targets through as the statistic value."""

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Returns the targets combined with a zero array shaped like predictions."""
    # Adding zeros derived from the predictions keeps the target values while
    # making sure potential coordinates from predictions are preserved.
    prediction_shaped_zeros = xr.zeros_like(predictions)
    return targets + prediction_shaped_zeros
class WindVectorSquaredError(base.Statistic):
  """Computes squared error between two wind components.

  SE = (u_pred - u_target) ** 2 + (v_pred - v_target) ** 2
  """

  def __init__(
      self,
      u_name: Union[str, Sequence[str]],
      v_name: Union[str, Sequence[str]],
      vector_name: Union[str, Sequence[str]],
  ):
    """Init.

    Each argument can be a single string or a sequence of strings. With
    sequences, the statistic is computed separately for each (u, v, vector)
    triple taken position by position.

    Args:
      u_name: Name(s) of the u wind component, e.g. [`u_component_of_wind`].
      v_name: Name(s) of the v wind component, e.g. [`v_component_of_wind`].
      vector_name: Name(s) to give output variable, e.g. [`wind`].

    Raises:
      ValueError: If the three arguments do not contain the same number of
        names.
    """
    # A bare string is itself a sequence of characters. Without wrapping it,
    # the length check and the zip in `compute` would operate on single
    # characters instead of on the variable name. Non-string sequences are
    # stored as given.
    self._u_name = [u_name] if isinstance(u_name, str) else u_name
    self._v_name = [v_name] if isinstance(v_name, str) else v_name
    self._vector_name = (
        [vector_name] if isinstance(vector_name, str) else vector_name
    )
    if not len(self._u_name) == len(self._v_name) == len(self._vector_name):
      raise ValueError(
          'u_name, v_name, and vector_name must have the same length'
      )

  @property
  def unique_name(self) -> str:
    """Class name followed by the underscore-joined output variable names."""
    suffix = '_'.join(self._vector_name)
    return 'WindVectorSquaredError_' + suffix

  def compute(
      self,
      predictions: Mapping[Hashable, xr.DataArray],
      targets: Mapping[Hashable, xr.DataArray],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Computes the vector squared error for every configured wind pair.

    Args:
      predictions: Mapping from variable name to predicted values. Must contain
        every configured u and v name.
      targets: Mapping from variable name to target values. Must contain every
        configured u and v name.

    Returns:
      Dict mapping each output vector name to the sum of the squared u and
      squared v errors.
    """
    out = {}
    for u, v, vector in zip(self._u_name, self._v_name, self._vector_name):
      u_error = predictions[u] - targets[u]
      v_error = predictions[v] - targets[v]
      out[vector] = u_error**2 + v_error**2
    return out
class SquaredPredictionAnomaly(base.PerVariableStatisticWithClimatology):
  """Squared prediction anomaly: (predictions - climatology)**2."""

  def _compute_per_variable_with_aligned_climatology(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
      aligned_climatology: xr.DataArray,
  ) -> xr.DataArray:
    """Returns the squared departure of predictions from the climatology.

    `targets` is accepted as part of the method signature but not used.
    """
    return (predictions - aligned_climatology) ** 2
class SquaredTargetAnomaly(base.PerVariableStatisticWithClimatology):
  """Squared target anomaly: (targets - climatology)**2."""

  def _compute_per_variable_with_aligned_climatology(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
      aligned_climatology: xr.DataArray,
  ) -> xr.DataArray:
    """Returns the squared departure of targets from the climatology.

    `predictions` is accepted as part of the method signature but not used.
    """
    return (targets - aligned_climatology) ** 2
class AnomalyCovariance(base.PerVariableStatisticWithClimatology):
  """Product of anomalies: (predictions - climatology) * (targets - climatology)."""

  def _compute_per_variable_with_aligned_climatology(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
      aligned_climatology: xr.DataArray,
  ) -> xr.DataArray:
    """Returns the prediction anomaly multiplied by the target anomaly."""
    return (predictions - aligned_climatology) * (
        targets - aligned_climatology
    )
### Metrics
class Bias(base.PerVariableMetric):
  """Bias, i.e. the mean of the signed error."""

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs, keyed by internal name."""
    error_statistic = Error()
    return {'Error': error_statistic}

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Returns the aggregated `Error` statistic unchanged."""
    mean_error = statistic_values['Error']
    return mean_error
class MAE(base.PerVariableMetric):
  """Mean of the absolute error."""

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs, keyed by internal name."""
    absolute_error_statistic = AbsoluteError()
    return {'AbsoluteError': absolute_error_statistic}

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Returns the aggregated `AbsoluteError` statistic unchanged."""
    mean_absolute_error = statistic_values['AbsoluteError']
    return mean_absolute_error
class MSE(base.PerVariableMetric):
  """Mean of the squared error.

  When the inputs are probability forecasts, this quantity is the Brier Score.
  """

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs, keyed by internal name."""
    squared_error_statistic = SquaredError()
    return {'SquaredError': squared_error_statistic}

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Returns the aggregated `SquaredError` statistic unchanged."""
    mean_squared_error = statistic_values['SquaredError']
    return mean_squared_error
class RMSE(base.PerVariableMetric):
  """Square root of the mean squared error."""

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs, keyed by internal name."""
    squared_error_statistic = SquaredError()
    return {'SquaredError': squared_error_statistic}

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Returns the square root of the aggregated `SquaredError` statistic."""
    mean_squared_error = statistic_values['SquaredError']
    return np.sqrt(mean_squared_error)
class PredictionAverage(base.PerVariableMetric):
  """Mean of the prediction values."""

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs, keyed by internal name."""
    passthrough_statistic = PredictionPassthrough()
    return {'PredictionPassthrough': passthrough_statistic}

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Returns the aggregated `PredictionPassthrough` statistic unchanged."""
    mean_prediction = statistic_values['PredictionPassthrough']
    return mean_prediction
class TargetAverage(base.PerVariableMetric):
  """Mean of the target values."""

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs, keyed by internal name."""
    passthrough_statistic = TargetPassthrough()
    return {'TargetPassthrough': passthrough_statistic}

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Returns the aggregated `TargetPassthrough` statistic unchanged."""
    mean_target = statistic_values['TargetPassthrough']
    return mean_target
class WindVectorRMSE(base.Metric):
  """Vector RMSE computed from the u and v wind components."""

  def __init__(
      self,
      u_name: Union[str, list[str]],
      v_name: Union[str, list[str]],
      vector_name: Union[str, list[str]],
  ):
    """Init.

    Every argument accepts either one string or a list of strings. When lists
    are given, the metric is evaluated independently for each position, e.g.
    `u_name=['u_component_of_wind', '10m_u_component_of_wind']`.

    Args:
      u_name: Name of the u wind component, e.g. `u_component_of_wind`.
      v_name: Name of the v wind component, e.g. `v_component_of_wind`.
      vector_name: Name to give output variable, e.g. `wind`.

    Raises:
      ValueError: If the (normalised) arguments differ in length.
    """

    def _wrap_single_name(name):
      # A lone string becomes a one-element list; anything else is kept as is.
      if isinstance(name, str):
        return [name]
      return name

    self._u_name = _wrap_single_name(u_name)
    self._v_name = _wrap_single_name(v_name)
    self._vector_name = _wrap_single_name(vector_name)

    v_count = len(self._v_name)
    if len(self._u_name) != v_count or v_count != len(self._vector_name):
      raise ValueError(
          'u_name, v_name, and vector_name must have the same length'
      )

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs, keyed by internal name."""
    squared_error_statistic = WindVectorSquaredError(
        self._u_name, self._v_name, self._vector_name
    )
    return {'WindVectorSquaredError': squared_error_statistic}

  def _values_from_mean_statistics_with_internal_names(
      self,
      statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Applies a square root to every aggregated vector squared error."""
    mean_squared_errors = statistic_values['WindVectorSquaredError']
    return xarray_tree.map_structure(np.sqrt, mean_squared_errors)
class ACC(base.PerVariableMetric):
  """Anomaly correlation coefficient with respect to a climatology."""

  def __init__(self, climatology: xr.Dataset):
    """Init.

    Args:
      climatology: Climatology dataset handed to each anomaly statistic.
    """
    self._climatology = climatology

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs, keyed by internal name."""
    required = {}
    required['SquaredPredictionAnomaly'] = SquaredPredictionAnomaly(
        climatology=self._climatology
    )
    required['SquaredTargetAnomaly'] = SquaredTargetAnomaly(
        climatology=self._climatology
    )
    required['AnomalyCovariance'] = AnomalyCovariance(
        climatology=self._climatology
    )
    return required

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Divides the anomaly covariance by the product of anomaly root means.

    The result is
    AnomalyCovariance / (sqrt(SquaredPredictionAnomaly) *
    sqrt(SquaredTargetAnomaly)), evaluated on the aggregated statistics.
    """
    covariance = statistic_values['AnomalyCovariance']
    prediction_scale = np.sqrt(statistic_values['SquaredPredictionAnomaly'])
    target_scale = np.sqrt(statistic_values['SquaredTargetAnomaly'])
    return covariance / (prediction_scale * target_scale)
class PredictionActivity(base.PerVariableMetric):
  """Prediction activity: the std dev of the prediction anomalies.

  This is used e.g. by ECMWF: https://arxiv.org/abs/2307.10128
  """

  def __init__(self, climatology: xr.Dataset):
    """Init.

    Args:
      climatology: Climatology dataset handed to the anomaly statistic.
    """
    self._climatology = climatology

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics this metric needs, keyed by internal name."""
    anomaly_statistic = SquaredPredictionAnomaly(
        climatology=self._climatology
    )
    return {'SquaredPredictionAnomaly': anomaly_statistic}

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Returns the square root of the aggregated squared prediction anomaly."""
    mean_squared_anomaly = statistic_values['SquaredPredictionAnomaly']
    return np.sqrt(mean_squared_anomaly)