# Source code for weatherbenchX.metrics.deterministic

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementation of deterministic metrics and associated statistics."""
from absl import logging
from collections.abc import Hashable
from typing import Iterable, Mapping, Sequence, Union
import jax
import numpy as np
import jax.numpy as jnp
from weatherbenchX import xarray_tree
from weatherbenchX.metrics import base
import xarray as xr
import xarray.ufuncs as xu


### Statistics


class RelativeIntensity(base.PerVariableStatistic):
  """Relative intensity error of predictions.

  Relative intensity is defined as the ratio of the mean of the predictions
  to the mean of the targets (means taken over `spatial_dims`). The statistic
  returned is the absolute deviation of this ratio from its ideal value of 1:

      |(mean(predictions) + epsilon) / (mean(targets) + epsilon) - 1|

  `epsilon` is added to both numerator and denominator so the ratio stays
  finite when the target mean is zero, and so the result is exactly 0 when
  both means are zero.

  Helpful in capturing e.g. strobing effects observed in a precipitation
  output. Notably, this statistic is intended for predictions and targets that
  are non-negative, such as precipitation.

  If `targets` carries a `mask` coordinate, the means are taken only over the
  points where `mask == 1`. The output then gets a `mask` coordinate that is 1
  where at least one such point was aggregated and 0 otherwise.
  """

  def __init__(
      self,
      spatial_dims: Sequence[str] = ('latitude', 'longitude'),
      epsilon: float = 1e-6,
  ):
    """Init.

    Args:
      spatial_dims: The dimensions to compute the relative intensity over.
      epsilon: Small constant added to both the prediction mean and the target
        mean before taking their ratio. Defaults to 1e-6.
    """
    self._spatial_dims = spatial_dims
    self._epsilon = epsilon

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Returns |ratio of spatial means - 1| for one variable."""
    spatial_dims = self._spatial_dims
    mask_coord = None

    if 'mask' in targets.coords:
      # Restrict the means to the mask == 1 region. Masked-out points are set
      # to 0 before summing so they do not contribute, while NaNs inside the
      # region still propagate (skipna=False) and make the mean NaN.
      mask = targets.mask == 1
      count = mask.sum(dim=spatial_dims, skipna=False)
      has_data = count > 0
      prediction_sum = predictions.where(mask, 0).sum(
          dim=spatial_dims, skipna=False
      )
      target_sum = targets.where(mask, 0).sum(dim=spatial_dims, skipna=False)
      # Where nothing is unmasked the division is 0/0; both means are then
      # forced to 0, so the result below is exactly 0 at those points.
      prediction_mean = (prediction_sum / count).where(has_data, 0.0)
      target_mean = (target_sum / count).where(has_data, 0.0)
      mask_coord = has_data.astype(int)
    else:
      prediction_mean = predictions.mean(dim=spatial_dims, skipna=False)
      target_mean = targets.mean(dim=spatial_dims, skipna=False)

    ratio = (prediction_mean + self._epsilon) / (target_mean + self._epsilon)
    result = abs(ratio - 1)
    if mask_coord is not None:
      # 1 if there was at least one valid value in the aggregation, else 0.
      result.coords['mask'] = mask_coord
    return result


class Error(base.PerVariableStatistic):
  """Signed error between predictions and targets.

  The per-point value is `predictions - targets`. Its mean is the bias.
  """

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    return predictions - targets


class AbsoluteError(base.PerVariableStatistic):
  """Absolute error between predictions and targets, i.e. |predictions - targets|."""

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    return abs(predictions - targets)


class SquaredError(base.PerVariableStatistic):
  """Squared error between predictions and targets, i.e. (predictions - targets)**2."""

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    difference = predictions - targets
    return difference ** 2


class PredictionPassthrough(base.PerVariableStatistic):
  """Returns the predictions unchanged (apart from broadcasting with targets)."""

  def __init__(self, copy_nans_from_targets: bool = False):
    """Init.

    Args:
      copy_nans_from_targets: If True, the output is set to NaN wherever the
        targets are NaN.
    """
    self._copy_nans_from_targets = copy_nans_from_targets

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    # Adding an all-zero array shaped like the targets combines the
    # predictions with the targets' coordinates, so any coordinates carried
    # only by the targets are kept on the output.
    combined = predictions + xr.zeros_like(targets)
    if not self._copy_nans_from_targets:
      return combined
    return combined.where(targets.notnull())


class TargetPassthrough(base.PerVariableStatistic):
  """Returns the targets unchanged (apart from broadcasting with predictions)."""

  def __init__(self, copy_nans_from_predictions: bool = False):
    """Init.

    Args:
      copy_nans_from_predictions: If True, copy any nans from the predictions
        to the targets, i.e. the output is NaN wherever the predictions are
        NaN.
    """
    self._copy_nans_from_predictions = copy_nans_from_predictions

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    # Make sure potential coordinates from predictions are preserved: adding
    # an all-zero array shaped like the predictions combines their
    # coordinates with the targets.
    result = targets + xr.zeros_like(predictions)
    if self._copy_nans_from_predictions:
      result = result.where(~predictions.isnull())
    return result


class WindVectorSquaredError(base.Statistic):
  """Computes squared error between two wind components.

  SE = (u_pred - u_target) ** 2 + (v_pred - v_target) ** 2

  Several wind vectors can be evaluated at once by passing equal-length
  sequences of names. The i-th entries of `u_name`, `v_name` and `vector_name`
  form one (u, v, output) triple.
  """

  def __init__(
      self,
      u_name: Union[str, Sequence[str]],
      v_name: Union[str, Sequence[str]],
      vector_name: Union[str, Sequence[str]],
  ):
    """Init.

    Each argument may be a single name or a sequence of names. A single string
    is treated as a one-element list (as in `WindVectorRMSE`), not as a
    sequence of characters.

    Args:
      u_name: Name of the u wind component, e.g. [`u_component_of_wind`].
      v_name: Name of the v wind component, e.g. [`v_component_of_wind`].
      vector_name: Name to give output variable, e.g. [`wind`].

    Raises:
      ValueError: If u_name, v_name and vector_name differ in length.
    """
    # A bare str is itself a Sequence, so it must be wrapped explicitly;
    # otherwise len() and zip() would operate on its characters.
    self._u_name = [u_name] if isinstance(u_name, str) else list(u_name)
    self._v_name = [v_name] if isinstance(v_name, str) else list(v_name)
    self._vector_name = (
        [vector_name] if isinstance(vector_name, str) else list(vector_name)
    )
    if not len(self._u_name) == len(self._v_name) == len(self._vector_name):
      raise ValueError(
          'u_name, v_name, and vector_name must have the same length'
      )

  @property
  def unique_name(self) -> str:
    suffix = '_'.join(self._vector_name)
    return 'WindVectorSquaredError_' + suffix

  def compute(
      self,
      predictions: Mapping[Hashable, xr.DataArray],
      targets: Mapping[Hashable, xr.DataArray],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Returns a mapping from each vector name to its squared vector error."""
    out = {}
    for u, v, vector in zip(self._u_name, self._v_name, self._vector_name):
      u_error = predictions[u] - targets[u]
      v_error = predictions[v] - targets[v]
      out[vector] = u_error ** 2 + v_error ** 2
    return out


class SquaredPredictionAnomaly(base.PerVariableStatisticWithClimatology):
  """Squared prediction anomaly, (predictions - climatology)**2."""

  def _compute_per_variable_with_aligned_climatology(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
      aligned_climatology: xr.DataArray,
  ) -> xr.DataArray:
    # `targets` is part of the shared interface but is not needed here.
    return (predictions - aligned_climatology) ** 2


class SquaredTargetAnomaly(base.PerVariableStatisticWithClimatology):
  """Squared target anomaly, (targets - climatology)**2."""

  def _compute_per_variable_with_aligned_climatology(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
      aligned_climatology: xr.DataArray,
  ) -> xr.DataArray:
    # `predictions` is part of the shared interface but is not needed here.
    return (targets - aligned_climatology) ** 2


class AnomalyCovariance(base.PerVariableStatisticWithClimatology):
  """Product of anomalies, (predictions - climatology) * (targets - climatology)."""

  def _compute_per_variable_with_aligned_climatology(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
      aligned_climatology: xr.DataArray,
  ) -> xr.DataArray:
    return (predictions - aligned_climatology) * (
        targets - aligned_climatology
    )


class ErrorExceedance(base.PerVariableStatistic):
  """Computes absolute errors exceeding thresholds.

  For each threshold t the statistic is 1.0 where |predictions - targets| > t
  and 0.0 otherwise. Averaging it therefore gives the fraction of points whose
  error exceeds t. The output is NaN wherever the error or the threshold is
  NaN.
  """

  def __init__(
      self,
      thresholds: Sequence[float] | np.ndarray | xr.DataArray | xr.Dataset,
  ):
    """Init.

    Args:
      thresholds: The thresholds to use for error exceedance.
        - A list, tuple or 1-D numpy array is converted to an xr.DataArray
          with dim (and coordinate) `error_exceedance_thresholds`.
        - An xr.DataArray is used as-is and broadcast against the error.
        - An xr.Dataset provides per-variable thresholds, looked up by the
          name of the error DataArray.

    Raises:
      ValueError: If thresholds that are not an xarray object are not
        one-dimensional.
    """
    if not isinstance(thresholds, (xr.DataArray, xr.Dataset)):
      # Accept any array-like, not just `typing.Sequence`, so numpy arrays
      # are converted too instead of failing later in compute.
      values = np.asarray(thresholds)
      if values.ndim != 1:
        raise ValueError(
            'thresholds must be one-dimensional, got shape'
            f' {values.shape}.'
        )
      thresholds = xr.DataArray(
          values,
          dims='error_exceedance_thresholds',
          coords={'error_exceedance_thresholds': values},
      )
    self._thresholds = thresholds

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    abs_error = abs(predictions - targets)
    if isinstance(self._thresholds, xr.Dataset):
      thresholds = self._thresholds[abs_error.name]
    else:
      thresholds = self._thresholds
    out = (abs_error > thresholds).astype(float)
    # Comparisons with NaN evaluate to False; restore NaN so that missing
    # errors or thresholds are not counted as non-exceedances.
    out = out.where(~abs_error.isnull())
    out = out.where(~thresholds.isnull())
    return out


### Metrics

# The following metrics are just the mean of a Statistic defined above, and
# so we can just use the Statistic directly as the Metric. We provide
# convenience aliases here, however:
Bias = Error  # Bias is the mean Error.
MAE = AbsoluteError  # MAE is the Mean Absolute Error.
MSE = SquaredError  # MSE is the Mean Squared Error.
PredictionAverage = PredictionPassthrough
TargetAverage = TargetPassthrough


class RMSE(base.PerVariableMetric):
  """Root mean squared error.

  Obtained as the square root of the aggregated `SquaredError` statistic.
  """

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    return dict(SquaredError=SquaredError())

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Takes the square root of the mean squared error."""
    mean_squared_error = statistic_values['SquaredError']
    return xu.sqrt(mean_squared_error)


class WindVectorRMSE(base.Metric):
  """Computes vector RMSE between two wind components.

  This is the square root of the mean `WindVectorSquaredError`.
  """

  def __init__(
      self,
      u_name: Union[str, list[str]],
      v_name: Union[str, list[str]],
      vector_name: Union[str, list[str]],
  ):
    """Init.

    Each argument may be a single string or a list. With lists, the metric is
    computed separately for each position, e.g.
    `u_name=['u_component_of_wind', '10m_u_component_of_wind']`.

    Args:
      u_name: Name of the u wind component, e.g. `u_component_of_wind`.
      v_name: Name of the v wind component, e.g. `v_component_of_wind`.
      vector_name: Name to give output variable, e.g. `wind`.

    Raises:
      ValueError: If the three arguments have different lengths.
    """

    def as_list(name):
      return [name] if isinstance(name, str) else name

    self._u_name = as_list(u_name)
    self._v_name = as_list(v_name)
    self._vector_name = as_list(vector_name)
    lengths = {len(self._u_name), len(self._v_name), len(self._vector_name)}
    if len(lengths) != 1:
      raise ValueError(
          'u_name, v_name, and vector_name must have the same length'
      )

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    squared_error = WindVectorSquaredError(
        self._u_name, self._v_name, self._vector_name
    )
    return {'WindVectorSquaredError': squared_error}

  def values_from_mean_statistics(
      self,
      statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Takes the square root of every mean squared vector error."""
    mean_squared_errors = statistic_values['WindVectorSquaredError']
    return xarray_tree.map_structure(xu.sqrt, mean_squared_errors)


class ACC(base.PerVariableMetric):
  """Anomaly correlation coefficient.

  ACC = mean(pred_anom * target_anom)
        / (sqrt(mean(pred_anom**2)) * sqrt(mean(target_anom**2)))
  where anomalies are taken relative to the given climatology.
  """

  def __init__(self, climatology: xr.Dataset):
    self._climatology = climatology

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    clim = self._climatology
    return {
        'SquaredPredictionAnomaly': SquaredPredictionAnomaly(climatology=clim),
        'SquaredTargetAnomaly': SquaredTargetAnomaly(climatology=clim),
        'AnomalyCovariance': AnomalyCovariance(climatology=clim),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Combines the aggregated anomaly statistics into the ACC."""
    covariance = statistic_values['AnomalyCovariance']
    prediction_rms = xu.sqrt(statistic_values['SquaredPredictionAnomaly'])
    target_rms = xu.sqrt(statistic_values['SquaredTargetAnomaly'])
    return covariance / (prediction_rms * target_rms)


class PredictionActivity(base.PerVariableMetric):
  """Activity in predictions defined as the std dev of the prediction anomalies.

  Computed as the square root of the mean squared prediction anomaly relative
  to the given climatology. This is used e.g. by ECMWF:
  https://arxiv.org/abs/2307.10128
  """

  def __init__(self, climatology: xr.Dataset):
    self._climatology = climatology

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    squared_anomaly = SquaredPredictionAnomaly(climatology=self._climatology)
    return {'SquaredPredictionAnomaly': squared_anomaly}

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Takes the square root of the mean squared prediction anomaly."""
    mean_squared_anomaly = statistic_values['SquaredPredictionAnomaly']
    return xu.sqrt(mean_squared_anomaly)