# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of categorical metrics."""
from typing import Hashable, Mapping, Sequence, Union
import numpy as np
from weatherbenchX.metrics import base
import xarray as xr
[docs]
class TruePositives(base.PerVariableStatistic):
  """Element-wise true positives for binary predictions and targets.

  An element counts as a true positive when both the prediction and the
  target are truthy after casting to bool. Elements for which the product
  `predictions * targets` is NaN are returned as NaN.
  """

  @property
  def unique_name(self) -> str:
    """Name under which this statistic is identified."""
    return 'TruePositives'

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Flags elements where prediction and target are both positive.

    Args:
      predictions: Binary predictions.
      targets: Binary targets.

    Returns:
      The product of the bool-cast inputs, masked wherever the product of
      the raw inputs is NaN.
    """
    # Validity is taken from the raw inputs: a NaN becomes True once cast to
    # bool, so it could not be detected on the boolean arrays.
    is_valid = ~np.isnan(predictions * targets)
    predicted_positive = predictions.astype(bool)
    observed_positive = targets.astype(bool)
    hits = predicted_positive * observed_positive
    return hits.where(is_valid)
[docs]
class TrueNegatives(base.PerVariableStatistic):
  """Element-wise true negatives for binary predictions and targets.

  An element counts as a true negative when both the prediction and the
  target are falsy after casting to bool. Elements for which the product
  `predictions * targets` is NaN are returned as NaN.
  """

  @property
  def unique_name(self) -> str:
    """Name under which this statistic is identified."""
    return 'TrueNegatives'

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Flags elements where prediction and target are both negative.

    Args:
      predictions: Binary predictions.
      targets: Binary targets.

    Returns:
      The product of the negated bool-cast inputs, masked wherever the
      product of the raw inputs is NaN.
    """
    # Validity is taken from the raw inputs: a NaN becomes True once cast to
    # bool, so it could not be detected on the boolean arrays.
    is_valid = ~np.isnan(predictions * targets)
    predicted_negative = ~predictions.astype(bool)
    observed_negative = ~targets.astype(bool)
    correct_rejections = predicted_negative * observed_negative
    return correct_rejections.where(is_valid)
[docs]
class FalsePositives(base.PerVariableStatistic):
  """Element-wise false positives for binary predictions and targets.

  An element counts as a false positive when the prediction is truthy and
  the target is falsy after casting to bool. Elements for which the product
  `predictions * targets` is NaN are returned as NaN.
  """

  @property
  def unique_name(self) -> str:
    """Name under which this statistic is identified."""
    return 'FalsePositives'

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Flags elements predicted positive whose target is negative.

    Args:
      predictions: Binary predictions.
      targets: Binary targets.

    Returns:
      The product of the bool-cast predictions and the negated bool-cast
      targets, masked wherever the product of the raw inputs is NaN.
    """
    # Validity is taken from the raw inputs: a NaN becomes True once cast to
    # bool, so it could not be detected on the boolean arrays.
    is_valid = ~np.isnan(predictions * targets)
    predicted_positive = predictions.astype(bool)
    observed_negative = ~targets.astype(bool)
    false_alarms = predicted_positive * observed_negative
    return false_alarms.where(is_valid)
[docs]
class FalseNegatives(base.PerVariableStatistic):
  """Element-wise false negatives for binary predictions and targets.

  An element counts as a false negative when the prediction is falsy and the
  target is truthy after casting to bool. Elements for which the product
  `predictions * targets` is NaN are returned as NaN.
  """

  @property
  def unique_name(self) -> str:
    """Name under which this statistic is identified."""
    return 'FalseNegatives'

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Flags elements predicted negative whose target is positive.

    Args:
      predictions: Binary predictions.
      targets: Binary targets.

    Returns:
      The product of the negated bool-cast predictions and the bool-cast
      targets, masked wherever the product of the raw inputs is NaN.
    """
    # Validity is taken from the raw inputs: a NaN becomes True once cast to
    # bool, so it could not be detected on the boolean arrays.
    is_valid = ~np.isnan(predictions * targets)
    predicted_negative = ~predictions.astype(bool)
    observed_positive = targets.astype(bool)
    misses = predicted_negative * observed_positive
    return misses.where(is_valid)
[docs]
class SEEPSStatistic(base.Statistic):
  """Computes the SEEPS statistic. See the SEEPS metric class for details."""

  def __init__(
      self,
      variables: Sequence[str],
      climatology: xr.Dataset,
      dry_threshold_mm: Union[float, Sequence[float]] = 0.25,
      min_p1: Union[float, Sequence[float]] = 0.1,
      max_p1: Union[float, Sequence[float]] = 0.85,
  ):
    """Init.

    Args:
      variables: Names of the variables to compute the statistic for.
      climatology: Dataset holding `{variable}_seeps_threshold` and
        `{variable}_seeps_dry_fraction` for every variable.
      dry_threshold_mm: Dry threshold in mm, either one value shared by all
        variables or one value per variable.
      min_p1: Lower bound on p1 for a point to be kept, scalar or
        per-variable.
      max_p1: Upper bound on p1 for a point to be kept, scalar or
        per-variable.
    """

    def _per_variable(value):
      # Sequences are stored untouched; a scalar is repeated once per
      # variable.
      if isinstance(value, Sequence):
        return value
      return [value] * len(variables)

    self._variables = variables
    self._climatology = climatology
    self._dry_threshold_mm = _per_variable(dry_threshold_mm)
    self._min_p1 = _per_variable(min_p1)
    self._max_p1 = _per_variable(max_p1)

    distinct_lengths = {
        len(self._variables),
        len(self._dry_threshold_mm),
        len(self._min_p1),
        len(self._max_p1),
    }
    assert len(distinct_lengths) == 1, 'All arguments must have the same length.'

  @property
  def unique_name(self) -> str:
    """Name encoding the variables and all per-variable thresholds."""
    variables = '_'.join(self._variables)
    dry_thresholds = '_'.join(map(str, self._dry_threshold_mm))
    lower_bounds = '_'.join(map(str, self._min_p1))
    upper_bounds = '_'.join(map(str, self._max_p1))
    return (
        f'SEEPS_{variables}'
        f'_dry_threshold_mm_{dry_thresholds}'
        f'_min_p1_{lower_bounds}'
        f'_max_p1_{upper_bounds}'
    )

  def compute(
      self,
      predictions: Mapping[Hashable, xr.DataArray],
      targets: Mapping[Hashable, xr.DataArray],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Maps computation over all variables listed in self._variables."""
    per_variable_settings = zip(
        self._variables, self._dry_threshold_mm, self._min_p1, self._max_p1
    )
    return {
        name: self._compute_seeps_per_variable(
            predictions[name],
            targets[name],
            name,
            dry_threshold,
            lower_bound,
            upper_bound,
        )
        for name, dry_threshold, lower_bound, upper_bound in per_variable_settings
    }

  def _convert_precip_to_seeps_cat(
      self,
      da: xr.DataArray,
      wet_threshold_for_valid_time: xr.DataArray,
      dry_threshold_mm: float,
  ):
    """Helper for the SEEPS computation. Converts values to categories.

    Args:
      da: Precipitation values to categorize.
      wet_threshold_for_valid_time: Threshold separating light from heavy.
      dry_threshold_mm: Threshold in mm at or below which a value is dry.

    Returns:
      Category indicators stacked along a new `seeps_cat` dimension labelled
      'dry', 'light' and 'heavy', with NaN wherever `da` is null.
    """
    # Convert to SI units [meters]
    dry_threshold = dry_threshold_mm / 1000.0

    is_dry = da <= dry_threshold
    is_heavy = da >= wet_threshold_for_valid_time
    is_light = np.logical_and(
        da > dry_threshold, da < wet_threshold_for_valid_time
    )

    category_dim = xr.DataArray(['dry', 'light', 'heavy'], dims=['seeps_cat'])
    categories = xr.concat([is_dry, is_light, is_heavy], dim=category_dim)

    # Comparisons against NaN are False, so restore NaNs from the input.
    # .where() will convert to float type. Note that in the WB2
    # implementation, there was an additional .astype('int') before the
    # .where(). It seems to work fine without it though.
    return categories.where(da.notnull())

  def _compute_seeps_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
      variable: str,
      dry_threshold_mm: float,
      min_p1: float,
      max_p1: float,
  ) -> xr.DataArray:
    """Computes the SEEPS statistic for a single variable.

    Args:
      predictions: Predictions; must expose `init_time` and `lead_time`.
      targets: Targets for the same variable.
      variable: Variable name, used to look up the climatology entries.
      dry_threshold_mm: Dry threshold in mm.
      min_p1: Points with p1 below this value are set to NaN.
      max_p1: Points with p1 above this value are set to NaN.

    Returns:
      The statistic with a boolean `mask` coordinate attached.

    Raises:
      ValueError: If both predictions and targets expose a `mask`.
    """
    climatology = self._climatology

    # Pick the wet threshold matching each forecast's valid time.
    valid_time = predictions.init_time + predictions.lead_time
    wet_threshold_for_valid_time = (
        climatology[f'{variable}_seeps_threshold']
        .sel(dayofyear=valid_time.dt.dayofyear, hour=valid_time.dt.hour)
        .load()
    )

    forecast_categories = self._convert_precip_to_seeps_cat(
        predictions, wet_threshold_for_valid_time, dry_threshold_mm
    ).rename({'seeps_cat': 'forecast_cat'})
    truth_categories = self._convert_precip_to_seeps_cat(
        targets, wet_threshold_for_valid_time, dry_threshold_mm
    ).rename({'seeps_cat': 'truth_cat'})

    # Compute contingency table
    contingency = (forecast_categories * truth_categories).compute()

    p1 = (
        climatology[f'{variable}_seeps_dry_fraction']
        .mean(('hour', 'dayofyear'))
        .compute()
    )

    # Compute scoring matrix
    # The contingency table and p1 should have matching spatial dimensions.
    zeros = xr.zeros_like(p1)
    scoring_rows = [
        [zeros, 1 / (1 - p1), 4 / (1 - p1)],
        [1 / p1, zeros, 3 / (1 - p1)],
        [1 / p1 + 3 / (2 + p1), 3 / (2 + p1), zeros],
    ]
    # Entries within a row run along truth_cat; rows run along forecast_cat.
    stacked_rows = [
        xr.concat(row, dim=contingency.truth_cat) for row in scoring_rows
    ]
    scoring_matrix = 0.5 * xr.concat(stacked_rows, dim=contingency.forecast_cat)
    scoring_matrix = scoring_matrix.compute()

    # Take dot product
    result = xr.dot(
        contingency, scoring_matrix, dims=('forecast_cat', 'truth_cat')
    )

    # Mask out p1 thresholds
    mask = (p1 >= min_p1) & (p1 <= max_p1)
    result = result.where(mask, np.nan)

    # Add NaN mask. If mask coordinate already exists, combine them.
    predictions_masked = hasattr(predictions, 'mask')
    targets_masked = hasattr(targets, 'mask')
    if predictions_masked and targets_masked:
      raise ValueError(
          'Both predictions and targets have masks. This should not happen.'
      )
    if predictions_masked:
      mask = mask & predictions.mask
    elif targets_masked:
      mask = mask & targets.mask
    result.coords['mask'] = mask
    return result
# Metrics
[docs]
class CSI(base.PerVariableMetric):
  """Critical Success Index.

  Also called Threat Score (TS).

  CSI = TP / (TP + FP + FN).
  """

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics needed by this metric, keyed by name."""
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Computes CSI from the aggregated statistics."""
    hits = statistic_values['TruePositives']
    false_alarms = statistic_values['FalsePositives']
    misses = statistic_values['FalseNegatives']
    return hits / (hits + false_alarms + misses)
[docs]
class Accuracy(base.PerVariableMetric):
  """Accuracy.

  ACC = (TP + TN) / (TP + FP + FN + TN).
  """

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics needed by this metric, keyed by name."""
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
        'TrueNegatives': TrueNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Computes accuracy from the aggregated statistics."""
    hits = statistic_values['TruePositives']
    false_alarms = statistic_values['FalsePositives']
    misses = statistic_values['FalseNegatives']
    correct_rejections = statistic_values['TrueNegatives']
    correct = hits + correct_rejections
    total = hits + false_alarms + misses + correct_rejections
    return correct / total
[docs]
class Recall(base.PerVariableMetric):
  """Recall, also called True Positive Rate (TPR) or Sensitivity.

  Recall = TP / (TP + FN).
  """

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics needed by this metric, keyed by name."""
    return {
        'TruePositives': TruePositives(),
        'FalseNegatives': FalseNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Computes recall from the aggregated statistics."""
    hits = statistic_values['TruePositives']
    misses = statistic_values['FalseNegatives']
    return hits / (hits + misses)
[docs]
class Precision(base.PerVariableMetric):
  """Precision, also called Positive Predictive Value (PPV).

  Precision = TP / (TP + FP).
  """

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics needed by this metric, keyed by name."""
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Computes precision from the aggregated statistics."""
    hits = statistic_values['TruePositives']
    false_alarms = statistic_values['FalsePositives']
    return hits / (hits + false_alarms)
[docs]
class F1Score(base.PerVariableMetric):
  """F1 score.

  F1 = 2 * Precision * Recall / (Precision + Recall)
     = 2 * TP / (2 * TP + FP + FN).
  """

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics needed by this metric, keyed by name."""
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the F1 score from the aggregated statistics."""
    hits = statistic_values['TruePositives']
    false_alarms = statistic_values['FalsePositives']
    misses = statistic_values['FalseNegatives']
    numerator = 2 * hits
    denominator = 2 * hits + false_alarms + misses
    return numerator / denominator
[docs]
class FrequencyBias(base.PerVariableMetric):
  """Frequency bias.

  FB = PP / P = (TP + FP) / (TP + FN).
  """

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """Statistics needed by this metric, keyed by name."""
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[Hashable, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the frequency bias from the aggregated statistics."""
    hits = statistic_values['TruePositives']
    false_alarms = statistic_values['FalsePositives']
    misses = statistic_values['FalseNegatives']
    predicted_positives = hits + false_alarms
    observed_positives = hits + misses
    return predicted_positives / observed_positives
[docs]
class SEEPS(base.Metric):
  """Computes Stable Equitable Error in Probability Space.

  Definition in Rodwell et al. (2010):
  https://www.ecmwf.int/en/elibrary/76205-new-equitable-score-suitable-verifying-precipitation-nwp

  Important: In most cases, the statistic will contain NaNs because of the
  masking of high and low p1 values. For this reason, a `mask` coordinate will
  be added to the resulting statistic to be used in combination with
  `masked=True` in the aggregator. If a mask already exists in either the
  predictions or targets, it will be combined with the p1 mask.
  """

  def __init__(
      self,
      variables: Sequence[str],
      climatology: xr.Dataset,
      dry_threshold_mm: Union[float, Sequence[float]] = 0.25,
      min_p1: Union[float, Sequence[float]] = 0.1,
      max_p1: Union[float, Sequence[float]] = 0.85,
  ):
    # pyformat: disable
    """Init.

    Args:
      variables: List of precipitation variables to compute SEEPS for.
      climatology: Climatology containing `*_seeps_dry_fraction` and
        `*_seeps_threshold` for each of the precipitation variables with
        dimensions `dayofyear` and `hour`, as well as `latitude` and
        `longitude` corresponding to the predictions/targets coordinates, see
        example below.
      dry_threshold_mm: Values smaller than or equal to this are considered
        dry. Unit: mm. Can be a list with one entry per variable; must then
        have the same length as `variables`. Default: 0.25
      min_p1: Mask out p1 values below this threshold. Can be a list with one
        entry per variable. Default: 0.1
      max_p1: Mask out p1 values above this threshold. Can be a list with one
        entry per variable. Default: 0.85

    Example:

      >>> climatology
      <xarray.Dataset> Size: 24MB
      Dimensions:          (hour: 4, dayofyear: 366,
                            longitude: 64, latitude: 32)
      Coordinates:
        * dayofyear        (dayofyear) int64 3kB 1 ... 366
        * hour             (hour) int64 32B 0 6 12 18
        * latitude         (latitude) float64 256B -87.1...
        * longitude        (longitude) float64 512B 0.0 ...
      Data variables:
          total_precipitation_6hr_seeps_dry_fraction  (hour, dayofyear, longitude, latitude) ...
          total_precipitation_6hr_seeps_threshold     (hour, dayofyear, longitude, latitude) ...
    """
    # pyformat: enable
    # Arguments are stored untouched; SEEPSStatistic receives them as given.
    self._variables = variables
    self._climatology = climatology
    self._dry_threshold_mm = dry_threshold_mm
    self._min_p1 = min_p1
    self._max_p1 = max_p1

  @property
  def statistics(self) -> Mapping[Hashable, base.Statistic]:
    """The single SEEPS statistic, built from the stored settings."""
    seeps_statistic = SEEPSStatistic(
        variables=self._variables,
        climatology=self._climatology,
        dry_threshold_mm=self._dry_threshold_mm,
        min_p1=self._min_p1,
        max_p1=self._max_p1,
    )
    return {'SEEPSStatistic': seeps_statistic}

  def _values_from_mean_statistics_with_internal_names(
      self,
      statistic_values: Mapping[str, Mapping[Hashable, xr.DataArray]],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Returns the aggregated SEEPS statistic unchanged as the metric value."""
    return statistic_values['SEEPSStatistic']