Implement a new metric

from typing import Mapping

import numpy as np
import xarray as xr
from weatherbenchX.metrics import base
from weatherbenchX.metrics import deterministic

Metrics in WeatherBench-X are defined by a set of statistics and instructions on how to compute the final metric value from the averaged statistics.

Statistics are computed from the predictions and targets for each element. They are further divided into single-variable statistics (computed separately for each variable; the most common use case) and multivariate statistics (computed as a function of several variables).
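
To make the distinction concrete, here is a minimal sketch in plain NumPy that uses dicts of arrays in place of xarray Datasets. The helpers `per_variable_statistic` and `wind_speed_squared_error` and the variables `u` and `v` are hypothetical and not part of WeatherBench-X: the first applies the same function to every variable independently, while the second needs two variables at once to produce its statistic.

```python
import numpy as np

# Hypothetical helpers (not WeatherBench-X API) illustrating the two kinds of statistics.


def per_variable_statistic(fn, predictions, targets):
    """Single-variable: apply fn separately to each variable."""
    return {name: fn(predictions[name], targets[name]) for name in predictions}


def wind_speed_squared_error(predictions, targets):
    """Multivariate: combine the u and v components into one statistic."""
    pred_speed = np.hypot(predictions['u'], predictions['v'])
    target_speed = np.hypot(targets['u'], targets['v'])
    return {'wind_speed': (pred_speed - target_speed) ** 2}


predictions = {'u': np.array([3.0, 0.0]), 'v': np.array([4.0, 1.0])}
targets = {'u': np.array([0.0, 0.0]), 'v': np.array([0.0, 0.0])}

per_var = per_variable_statistic(lambda p, t: (p - t) ** 2, predictions, targets)
multi = wind_speed_squared_error(predictions, targets)
```

Here `per_var` holds one squared-error array per variable, whereas `multi` holds a single array derived from both components.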

As a simple example, take the RMSE. Here, the statistic is the squared error, which is computed per variable.
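
Written out for N elements with predictions p_i and targets t_i, and assuming a plain, unweighted mean, the statistic is computed per element, averaged, and only then turned into the metric:

```latex
s_i = (p_i - t_i)^2, \qquad
\bar{s} = \frac{1}{N} \sum_{i=1}^{N} s_i, \qquad
\mathrm{RMSE} = \sqrt{\bar{s}}
```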

class SquaredError(base.PerVariableStatistic):
  """Squared error between predictions and targets."""

  def compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    return (predictions - targets) ** 2

The RMSE metric specifies the SquaredError statistic and takes the square root of its aggregated (averaged) value.
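
The order matters: the square root is applied once to the averaged statistic, not to each element. A small NumPy check with made-up values shows that the two orders give different numbers; transforming per element first yields the mean absolute error rather than the RMSE.

```python
import numpy as np

# Made-up squared errors for four elements of a single variable.
squared_error = np.array([0.0, 0.0, 0.0, 16.0])

# Average the statistic first, then apply the metric's transformation (RMSE).
rmse = np.sqrt(squared_error.mean())  # sqrt(4.0) = 2.0

# Transforming each element first gives the mean absolute error instead.
mean_of_roots = np.sqrt(squared_error).mean()  # (0 + 0 + 0 + 4) / 4 = 1.0
```

With this split, only the linear mean has to be aggregated; the non-linear square root is applied a single time at the end.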

class RMSE(base.PerVariableMetric):
  """Root mean squared error."""

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    return {'SquaredError': SquaredError()}

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes metrics from aggregated statistics."""
    return np.sqrt(statistic_values['SquaredError'])
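
Adding another metric follows the same recipe: declare the statistics it needs and state how to turn their averages into the final value. The sketch below illustrates that contract for a mean absolute error using hypothetical plain-Python stand-ins (`AbsoluteError`, `MAE`, `evaluate`) with dicts of NumPy arrays, so it runs without WeatherBench-X; with the library, you would subclass `base.PerVariableStatistic` and `base.PerVariableMetric` as in the RMSE example.

```python
import numpy as np

# Hypothetical stand-ins for the statistic/metric split; NOT the WeatherBench-X classes.


class AbsoluteError:
    """Per-element statistic: |prediction - target| for each variable."""

    def compute(self, predictions, targets):
        return {name: np.abs(predictions[name] - targets[name]) for name in predictions}


class MAE:
    """Mean absolute error: the averaged statistic needs no further transformation."""

    statistics = {'AbsoluteError': AbsoluteError()}

    def values_from_mean_statistics(self, mean_statistics):
        return dict(mean_statistics['AbsoluteError'])


def evaluate(metric, predictions, targets):
    """Compute each statistic, average it over all elements, then derive the metric."""
    mean_statistics = {
        stat_name: {var: values.mean() for var, values in stat.compute(predictions, targets).items()}
        for stat_name, stat in metric.statistics.items()
    }
    return metric.values_from_mean_statistics(mean_statistics)


predictions = {'2m_temperature': np.array([1.0, 2.0, 3.0, 4.0])}
targets = {'2m_temperature': np.array([1.0, 1.0, 1.0, 1.0])}
result = evaluate(MAE(), predictions, targets)  # absolute errors 0, 1, 2, 3 -> mean 1.5
```
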

predictions = xr.Dataset({'2m_temperature': xr.DataArray(np.ones((2, 32, 64)), dims=['init_time', 'latitude', 'longitude'])})
targets = predictions.copy()
predictions
<xarray.Dataset> Size: 33kB
Dimensions:         (init_time: 2, latitude: 32, longitude: 64)
Dimensions without coordinates: init_time, latitude, longitude
Data variables:
    2m_temperature  (init_time, latitude, longitude) float64 33kB 1.0 ... 1.0
rmse = deterministic.RMSE()
statistic_values = {name: statistic.compute(predictions, targets) for name, statistic in rmse.statistics.items()}
statistic_values
{'SquaredError': {'2m_temperature': <xarray.DataArray '2m_temperature' (init_time: 2, latitude: 32, longitude: 64)> Size: 33kB
  array([[[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]],
  
         [[0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          ...,
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.],
          [0., 0., 0., ..., 0., 0., 0.]]])
  Dimensions without coordinates: init_time, latitude, longitude}}

Now take the mean. Here we do this explicitly for a single metric; typically, this step would be handled by compute_unique_statistics_for_all_metrics.

statistic_values['SquaredError'] = {k: v.mean() for k, v in statistic_values['SquaredError'].items()}
statistic_values
{'SquaredError': {'2m_temperature': <xarray.DataArray '2m_temperature' ()> Size: 8B
  array(0.)}}
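
When several metrics are evaluated together, they can declare the same statistic; RMSE and a mean squared error, for instance, both need the squared error. The name compute_unique_statistics_for_all_metrics suggests that each such statistic is computed and averaged only once. The sketch below illustrates that idea with a hypothetical helper, `mean_unique_statistics`, that deduplicates by statistic name; it is an assumption about the general approach, not the library's implementation.

```python
import numpy as np

calls = []  # records how often each statistic is actually computed


def squared_error(predictions, targets):
    calls.append('SquaredError')
    return (predictions - targets) ** 2


# Hypothetical metric definitions: (statistic name -> function, final transformation).
metrics = {
    'rmse': ({'SquaredError': squared_error}, lambda stats: np.sqrt(stats['SquaredError'])),
    'mse': ({'SquaredError': squared_error}, lambda stats: stats['SquaredError']),
}


def mean_unique_statistics(metrics, predictions, targets):
    """Compute and average each distinct statistic once, however many metrics use it."""
    unique = {}
    for statistics, _ in metrics.values():
        unique.update(statistics)  # identical names collapse into a single entry
    return {name: fn(predictions, targets).mean() for name, fn in unique.items()}


predictions = np.array([2.0, 2.0, 4.0, 4.0])
targets = np.zeros(4)
mean_stats = mean_unique_statistics(metrics, predictions, targets)  # SquaredError -> 10.0
values = {name: final(mean_stats) for name, (_, final) in metrics.items()}
```
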

Now we can compute the metric (in this case, by taking the square root) from the averaged statistic.

rmse.values_from_mean_statistics(statistic_values)
{'2m_temperature': <xarray.DataArray '2m_temperature' ()> Size: 8B
 array(0.)}
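
Because the targets are an exact copy of the predictions, every value above is zero. One way to sanity-check the same three steps with a non-trivial error is to shift the predictions by a constant, for which the RMSE must equal the size of the shift. The NumPy sketch below mirrors the steps without WeatherBench-X; the offset of 0.5 is arbitrary, and with the xarray objects above, something like `predictions = targets + 0.5` should serve the same purpose.

```python
import numpy as np

offset = 0.5
targets = np.ones((2, 32, 64))  # same shape as the example: init_time, latitude, longitude
predictions = targets + offset  # every element is off by exactly 0.5

squared_error = (predictions - targets) ** 2  # step 1: per-element statistic
mean_squared_error = squared_error.mean()     # step 2: average over all elements
rmse = np.sqrt(mean_squared_error)            # step 3: metric from the averaged statistic
```
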

Note: Some metrics can have more than one statistic. See, for example, the ensemble CRPS implementation.
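
As an illustration of a metric built from two statistics, one common ensemble estimator of the CRPS combines a skill term E|X - y| and a spread term E|X - X'| as CRPS = E|X - y| - 0.5 E|X - X'|. The NumPy sketch below follows the statistic/metric split with two hypothetical statistic functions, `crps_skill` and `crps_spread`; it is not the WeatherBench-X implementation, whose statistic names, estimator details, and handling of the ensemble dimension may differ.

```python
import numpy as np


def crps_skill(ensemble, target):
    """Per-element mean absolute difference between members (axis 0) and the target."""
    return np.abs(ensemble - target).mean(axis=0)


def crps_spread(ensemble):
    """Per-element mean absolute difference over all member pairs (including i == j)."""
    diffs = np.abs(ensemble[:, None, ...] - ensemble[None, :, ...])
    return diffs.mean(axis=(0, 1))


def crps_from_mean_statistics(mean_skill, mean_spread):
    """Combine the two averaged statistics into the final metric value."""
    return mean_skill - 0.5 * mean_spread


# Two ensemble members (axis 0) at three grid points, and the targets at those points.
ensemble = np.array([[0.0, 1.0, 2.0],
                     [2.0, 1.0, 2.0]])
target = np.array([1.0, 1.0, 0.0])

mean_skill = crps_skill(ensemble, target).mean()  # per element [1, 0, 2] -> 1.0
mean_spread = crps_spread(ensemble).mean()        # per element [1, 0, 0] -> 1/3
crps = crps_from_mean_statistics(mean_skill, mean_spread)  # 1 - 1/6 = 5/6
```

Because this combination is linear, it gives the same result as averaging a per-element CRPS; keeping the two statistics separate matters more for metrics with a non-linear final step, such as the square root in the RMSE.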