Source code for weatherbenchX.metrics.categorical

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of categorical metrics."""

from typing import Hashable, Mapping, Sequence, Union, final

import numpy as np
from weatherbenchX.metrics import base
from weatherbenchX.metrics import wrappers
import xarray as xr
import xarray.ufuncs as xu


class TruePositives(base.PerVariableStatistic):
  """Element-wise true-positive indicator for binary predictions and targets.

  An element is 1.0 where both the prediction and the target are truthy and
  0.0 otherwise. Elements where `predictions * targets` is NaN are NaN.
  """

  @property
  def unique_name(self) -> str:
    return 'TruePositives'

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Returns a float32 array flagging events that were predicted and seen."""
    hits = predictions.astype(bool) * targets.astype(bool)
    # The boolean cast loses NaNs (NaN casts to True), so missing values are
    # restored from the product of the raw inputs.
    is_valid = ~xu.isnan(predictions * targets)
    return hits.where(is_valid).astype(np.float32)
class TrueNegatives(base.PerVariableStatistic):
  """Element-wise true-negative indicator for binary predictions and targets.

  An element is 1.0 where both the prediction and the target are falsy and
  0.0 otherwise. Elements where `predictions * targets` is NaN are NaN.
  """

  @property
  def unique_name(self) -> str:
    return 'TrueNegatives'

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Returns a float32 array flagging events neither predicted nor seen."""
    no_forecast = ~predictions.astype(bool)
    no_event = ~targets.astype(bool)
    correct_rejections = no_forecast * no_event
    # The boolean cast loses NaNs (NaN casts to True), so missing values are
    # restored from the product of the raw inputs.
    is_valid = ~xu.isnan(predictions * targets)
    return correct_rejections.where(is_valid).astype(np.float32)
class FalsePositives(base.PerVariableStatistic):
  """Element-wise false-positive indicator for binary predictions and targets.

  An element is 1.0 where the prediction is truthy but the target is falsy and
  0.0 otherwise. Elements where `predictions * targets` is NaN are NaN.
  """

  @property
  def unique_name(self) -> str:
    return 'FalsePositives'

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Returns a float32 array flagging events predicted but not seen."""
    forecast = predictions.astype(bool)
    no_event = ~targets.astype(bool)
    false_alarms = forecast * no_event
    # The boolean cast loses NaNs (NaN casts to True), so missing values are
    # restored from the product of the raw inputs.
    is_valid = ~xu.isnan(predictions * targets)
    return false_alarms.where(is_valid).astype(np.float32)
class FalseNegatives(base.PerVariableStatistic):
  """Element-wise false-negative indicator for binary predictions and targets.

  An element is 1.0 where the prediction is falsy but the target is truthy and
  0.0 otherwise. Elements where `predictions * targets` is NaN are NaN.
  """

  @property
  def unique_name(self) -> str:
    return 'FalseNegatives'

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Returns a float32 array flagging events seen but not predicted."""
    no_forecast = ~predictions.astype(bool)
    event = targets.astype(bool)
    misses = no_forecast * event
    # The boolean cast loses NaNs (NaN casts to True), so missing values are
    # restored from the product of the raw inputs.
    is_valid = ~xu.isnan(predictions * targets)
    return misses.where(is_valid).astype(np.float32)
class SEEPS(base.Statistic):
  """Stable Equitable Error in Probability Space (SEEPS).

  Definition in Rodwell et al. (2010):
  https://www.ecmwf.int/en/elibrary/76205-new-equitable-score-suitable-verifying-precipitation-nwp

  Important: In most cases the statistic contains NaNs, because locations whose
  climatological dry fraction p1 lies outside `[min_p1, max_p1]` are masked
  out. For this reason a `mask` coordinate is attached to the resulting
  statistic, intended to be used together with `masked=True` in the
  aggregator. If either the predictions or the targets already carry a mask,
  it is combined with the p1 mask; if both carry one, a ValueError is raised.
  """

  def __init__(
      self,
      variables: Sequence[str],
      climatology: xr.Dataset,
      dry_threshold_mm: Union[float, Sequence[float]] = 0.25,
      min_p1: Union[float, Sequence[float]] = 0.1,
      max_p1: Union[float, Sequence[float]] = 0.85,
  ):  # pyformat: disable
    """Init.

    Args:
      variables: List of precipitation variables to compute SEEPS for.
      climatology: Climatology containing `*_seeps_dry_fraction` and
        `*_seeps_threshold` for each of the precipitation variables with
        dimensions `dayofyear` and `hour`, as well as `latitude` and
        `longitude` corresponding to the predictions/targets coordinates, see
        example below.
      dry_threshold_mm: Values smaller or equal are considered dry. Unit: mm.
        Either a single value applied to every variable or one value per
        variable. Default: 0.25
      min_p1: Mask out p1 values below this threshold. Either a single value
        or one value per variable. Default: 0.1
      max_p1: Mask out p1 values above this threshold. Either a single value
        or one value per variable. Default: 0.85

    Example:

        >>> climatology
        <xarray.Dataset> Size: 24MB
        Dimensions:          (hour: 4, dayofyear: 366, longitude: 64, latitude: 32)
        Coordinates:
        * dayofyear        (dayofyear) int64 3kB 1 ... 366
        * hour             (hour) int64 32B 0 6 12 18
        * latitude         (latitude) float64 256B -87.1...
        * longitude        (longitude) float64 512B 0.0 ...
        Data variables:
            total_precipitation_6hr_seeps_dry_fraction  (hour, dayofyear, longitude, latitude) ...
            total_precipitation_6hr_seeps_threshold  (hour, dayofyear, longitude, latitude) ...
    """
    # pyformat: enable

    def per_variable(value):
      # Scalars are repeated once per variable; sequences are used as given.
      if isinstance(value, Sequence):
        return value
      return [value] * len(variables)

    self._variables = variables
    self._climatology = climatology
    self._dry_threshold_mm = per_variable(dry_threshold_mm)
    self._min_p1 = per_variable(min_p1)
    self._max_p1 = per_variable(max_p1)

    per_variable_settings = (self._dry_threshold_mm, self._min_p1, self._max_p1)
    assert all(
        len(setting) == len(self._variables)
        for setting in per_variable_settings
    ), 'All arguments must have the same length.'

  @property
  def unique_name(self) -> str:
    variables = '_'.join(self._variables)
    dry_thresholds = '_'.join(str(v) for v in self._dry_threshold_mm)
    min_p1s = '_'.join(str(v) for v in self._min_p1)
    max_p1s = '_'.join(str(v) for v in self._max_p1)
    return (
        f'SEEPS_{variables}'
        f'_dry_threshold_mm_{dry_thresholds}'
        f'_min_p1_{min_p1s}'
        f'_max_p1_{max_p1s}'
    )

  def compute(
      self,
      predictions: Mapping[Hashable, xr.DataArray],
      targets: Mapping[Hashable, xr.DataArray],
  ) -> Mapping[Hashable, xr.DataArray]:
    """Computes SEEPS for every variable listed in `self._variables`."""
    settings = zip(
        self._variables, self._dry_threshold_mm, self._min_p1, self._max_p1
    )
    return {
        name: self._compute_seeps_per_variable(
            predictions[name], targets[name], name, dry_mm, low_p1, high_p1
        )
        for name, dry_mm, low_p1, high_p1 in settings
    }

  def _convert_precip_to_seeps_cat(
      self,
      da: xr.DataArray,
      wet_threshold_for_valid_time: xr.DataArray,
      dry_threshold_mm: float,
  ):
    """Helper for SEEPS: one-hot encodes values as dry / light / heavy.

    Args:
      da: Values to categorize.
      wet_threshold_for_valid_time: Threshold separating light from heavy.
      dry_threshold_mm: Threshold in mm at or below which a value is dry.

    Returns:
      Array with an extra `seeps_cat` dimension labelled 'dry', 'light' and
      'heavy'; NaN wherever `da` is null.
    """
    # The threshold is given in mm; divide by 1000 to compare in meters.
    dry_threshold = dry_threshold_mm / 1000.0

    is_dry = da <= dry_threshold
    is_light = xu.logical_and(
        da > dry_threshold, da < wet_threshold_for_valid_time
    )
    is_heavy = da >= wet_threshold_for_valid_time

    category_dim = xr.DataArray(['dry', 'light', 'heavy'], dims=['seeps_cat'])
    categories = xr.concat([is_dry, is_light, is_heavy], dim=category_dim)

    # The comparisons turn NaNs into False; .where() restores them (and
    # converts to float). The WB2 implementation had an extra .astype('int')
    # before the .where(), which does not appear to be needed.
    return categories.where(da.notnull())

  def _compute_seeps_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
      variable: str,
      dry_threshold_mm: float,
      min_p1: float,
      max_p1: float,
  ) -> xr.DataArray:
    """Computes SEEPS for a single variable.

    Args:
      predictions: Predictions; must expose `init_time` and `lead_time`.
      targets: Targets matching the predictions.
      variable: Variable name, used to look up `{variable}_seeps_threshold`
        and `{variable}_seeps_dry_fraction` in the climatology.
      dry_threshold_mm: Dry threshold in mm.
      min_p1: Smallest p1 that is kept.
      max_p1: Largest p1 that is kept.

    Returns:
      SEEPS values, NaN where p1 is outside `[min_p1, max_p1]`, with a `mask`
      coordinate attached.

    Raises:
      ValueError: If both predictions and targets have a `mask` attribute.
    """
    valid_time = predictions.init_time + predictions.lead_time
    wet_threshold_for_valid_time = (
        self._climatology[f'{variable}_seeps_threshold']
        .sel(dayofyear=valid_time.dt.dayofyear, hour=valid_time.dt.hour)
        .load()
    )

    forecast_cat = self._convert_precip_to_seeps_cat(
        predictions, wet_threshold_for_valid_time, dry_threshold_mm
    ).rename({'seeps_cat': 'forecast_cat'})
    truth_cat = self._convert_precip_to_seeps_cat(
        targets, wet_threshold_for_valid_time, dry_threshold_mm
    ).rename({'seeps_cat': 'truth_cat'})

    # Contingency table: outer product over forecast and truth categories.
    contingency = (forecast_cat * truth_cat).compute()

    dry_fraction = self._climatology[f'{variable}_seeps_dry_fraction']
    p1 = dry_fraction.mean(('hour', 'dayofyear')).compute()

    # Scoring matrix: rows follow forecast_cat, columns follow truth_cat.
    # The contingency table and p1 should have matching spatial dimensions.
    zeros = xr.zeros_like(p1)
    rows = [
        [zeros, 1 / (1 - p1), 4 / (1 - p1)],
        [1 / p1, zeros, 3 / (1 - p1)],
        [1 / p1 + 3 / (2 + p1), 3 / (2 + p1), zeros],
    ]
    stacked_rows = [
        xr.concat(row, dim=contingency.truth_cat) for row in rows
    ]
    scoring_matrix = 0.5 * xr.concat(
        stacked_rows, dim=contingency.forecast_cat
    )
    scoring_matrix = scoring_matrix.compute()

    # Weight the contingency table by the scoring matrix and sum over both
    # category dimensions.
    score = xr.dot(
        contingency, scoring_matrix, dims=('forecast_cat', 'truth_cat')
    )

    # Discard locations whose dry fraction lies outside the accepted range.
    mask = (p1 >= min_p1) & (p1 <= max_p1)
    score = score.where(mask, np.nan)

    # Fold an existing mask from the inputs into the p1 mask.
    predictions_have_mask = hasattr(predictions, 'mask')
    targets_have_mask = hasattr(targets, 'mask')
    if predictions_have_mask and targets_have_mask:
      raise ValueError(
          'Both predictions and targets have masks. This should not happen.'
      )
    if predictions_have_mask:
      mask = mask & predictions.mask
    elif targets_have_mask:
      mask = mask & targets.mask

    score.coords['mask'] = mask
    return score
class RankedProbabilityScore(base.PerVariableStatistic):
  """Ranked probability score for cumulative distribution functions.

  Given a ground truth scalar random variable Y, a prediction random variable
  X, and a sequence of bin boundaries b_0 < b_1 < ... < b_K with b_0 = -inf
  and b_K = +inf, the Ranked Probability Score is defined as

    RPS = E[ sum_k (CDF(Y)(b_k) - CDF(X)(b_k))^2 ]

  where the sum runs over k = 1, 2, ..., K, and CDF(X) and CDF(Y) are the
  cumulative distribution functions of X and Y, respectively.

  Predictions and targets are assumed to already hold the CDF along the
  `bin_dim` dimension. For an implementation that computes the RPS from
  ensemble predictions, see `probabilistic.EnsembleRankedProbabilityScore`.
  """

  def __init__(
      self,
      bin_dim: str,
  ):
    """Init.

    Args:
      bin_dim: Name of the dimension over which squared differences are
        summed.
    """
    self._bin_dim = bin_dim

  @property
  def unique_name(self) -> str:
    return 'RankedProbabilityScore'

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Returns the squared difference summed over the bin dimension."""
    cdf_difference = predictions - targets
    squared_difference = cdf_difference**2
    return squared_difference.sum(self._bin_dim)


# Metrics
class CSI(base.PerVariableMetric):
  """Critical Success Index, also called Threat Score (TS).

  CSI = TP / (TP + FP + FN).
  """

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the CSI from aggregated statistics.

    Args:
      statistic_values: Aggregated values keyed like `statistics`.

    Returns:
      TP / (TP + FP + FN).
    """
    tp = statistic_values['TruePositives']
    fp = statistic_values['FalsePositives']
    fn = statistic_values['FalseNegatives']
    return tp / (tp + fp + fn)
class Accuracy(base.PerVariableMetric):
  """Accuracy: share of all cases that were classified correctly.

  ACC = (TP + TN) / (TP + FP + FN + TN).
  """

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
        'TrueNegatives': TrueNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the accuracy from aggregated statistics.

    Args:
      statistic_values: Aggregated values keyed like `statistics`.

    Returns:
      (TP + TN) / (TP + FP + FN + TN).
    """
    tp = statistic_values['TruePositives']
    fp = statistic_values['FalsePositives']
    fn = statistic_values['FalseNegatives']
    tn = statistic_values['TrueNegatives']
    correct = tp + tn
    total = tp + fp + fn + tn
    return correct / total
class Recall(base.PerVariableMetric):
  """Recall, also called True Positive Rate (TPR) or Sensitivity.

  Recall = TP / (TP + FN).
  """

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    return {
        'TruePositives': TruePositives(),
        'FalseNegatives': FalseNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the recall from aggregated statistics.

    Args:
      statistic_values: Aggregated values keyed like `statistics`.

    Returns:
      TP / (TP + FN).
    """
    tp = statistic_values['TruePositives']
    fn = statistic_values['FalseNegatives']
    return tp / (tp + fn)
class FalseAlarmRate(base.PerVariableMetric):
  """False Alarm Rate.

  FAR = FP / (TP + FP).

  Note that this is the share of positive predictions that were wrong (often
  called the false alarm ratio). It is not the same quantity as the
  F = FP / (FP + TN) used by `SEDI` in this module.
  """

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the false alarm rate from aggregated statistics.

    Args:
      statistic_values: Aggregated values keyed like `statistics`.

    Returns:
      FP / (TP + FP).
    """
    tp = statistic_values['TruePositives']
    fp = statistic_values['FalsePositives']
    return fp / (tp + fp)
class Precision(base.PerVariableMetric):
  """Precision, also called Positive Predictive Value (PPV).

  Precision = TP / (TP + FP).
  """

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the precision from aggregated statistics.

    Args:
      statistic_values: Aggregated values keyed like `statistics`.

    Returns:
      TP / (TP + FP).
    """
    tp = statistic_values['TruePositives']
    fp = statistic_values['FalsePositives']
    return tp / (tp + fp)
class F1Score(base.PerVariableMetric):
  """F1 score: harmonic mean of precision and recall.

  F1 = 2 * Precision * Recall / (Precision + Recall)
     = 2 * TP / (2 * TP + FP + FN).
  """

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the F1 score from aggregated statistics.

    Args:
      statistic_values: Aggregated values keyed like `statistics`.

    Returns:
      2 * TP / (2 * TP + FP + FN).
    """
    tp = statistic_values['TruePositives']
    fp = statistic_values['FalsePositives']
    fn = statistic_values['FalseNegatives']
    doubled_hits = 2 * tp
    return doubled_hits / (2 * tp + fp + fn)
class FrequencyBias(base.PerVariableMetric):
  """Frequency bias: predicted positives relative to actual positives.

  FB = PP / P = (TP + FP) / (TP + FN)
  """

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the frequency bias from aggregated statistics.

    Args:
      statistic_values: Aggregated values keyed like `statistics`.

    Returns:
      (TP + FP) / (TP + FN).
    """
    tp = statistic_values['TruePositives']
    fp = statistic_values['FalsePositives']
    fn = statistic_values['FalseNegatives']
    predicted_positives = tp + fp
    actual_positives = tp + fn
    return predicted_positives / actual_positives
class HSS(base.PerVariableMetric):
  """Heidke Skill Score.

  HSS = 2 * (TP * TN - FP * FN)
        / ((TP + FN) * (FN + TN) + (TP + FP) * (FP + TN))
  """

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
        'TrueNegatives': TrueNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the HSS from aggregated statistics.

    Args:
      statistic_values: Aggregated values keyed like `statistics`.

    Returns:
      The Heidke Skill Score as given in the class docstring.
    """
    tp, tn, fp, fn = (
        statistic_values[key]
        for key in (
            'TruePositives',
            'TrueNegatives',
            'FalsePositives',
            'FalseNegatives',
        )
    )
    scaled_cross_difference = 2 * (tp * tn - fp * fn)
    normalizer = (tp + fn) * (fn + tn) + (tp + fp) * (fp + tn)
    return scaled_cross_difference / normalizer


class ETS(base.PerVariableMetric):
  """Equitable Threat Score (also called Gilbert Skill Score).

  ETS = (TP - TP_random) / (TP + FP + FN - TP_random)

  where TP_random = ((TP + FP) * (TP + FN)) / (TP + FP + FN + TN).
  """

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
        'TrueNegatives': TrueNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the ETS from aggregated statistics.

    Args:
      statistic_values: Aggregated values keyed like `statistics`.

    Returns:
      The Equitable Threat Score as given in the class docstring.
    """
    tp, tn, fp, fn = (
        statistic_values[key]
        for key in (
            'TruePositives',
            'TrueNegatives',
            'FalsePositives',
            'FalseNegatives',
        )
    )
    total = tp + fp + fn + tn
    # TP_random term of the class docstring.
    tp_random = ((tp + fp) * (tp + fn)) / total
    return (tp - tp_random) / (tp + fp + fn - tp_random)


class SEDI(base.PerVariableMetric):
  """Symmetric extremal dependency index.

  SEDI = (ln(F) - ln(H) + ln(1-H) - ln(1-F))
         / (ln(H) + ln(F) + ln(1-H) + ln(1-F))

  where H = TP/(TP+FN) (hit rate) and F = FP/(FP+TN) (false alarm rate).

  See Ferro and Stephenson (2011)
  https://journals.ametsoc.org/view/journals/wefo/26/5/waf-d-10-05030_1.pdf
  """

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    return {
        'TruePositives': TruePositives(),
        'FalsePositives': FalsePositives(),
        'FalseNegatives': FalseNegatives(),
        'TrueNegatives': TrueNegatives(),
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the SEDI from aggregated statistics.

    Args:
      statistic_values: Aggregated values keyed like `statistics`.

    Returns:
      The SEDI as given in the class docstring, with H and F clipped to
      [1e-6, 1 - 1e-6].
    """
    tp, tn, fp, fn = (
        statistic_values[key]
        for key in (
            'TruePositives',
            'TrueNegatives',
            'FalsePositives',
            'FalseNegatives',
        )
    )

    # Clip rates to avoid log(0) errors and division by zero, following
    # Ferro and Stephenson (2011).
    eps = 1e-6
    h = (tp / (tp + fn)).clip(eps, 1 - eps)
    f = (fp / (fp + tn)).clip(eps, 1 - eps)

    ln_h = xu.log(h)
    ln_f = xu.log(f)
    ln_not_h = xu.log(1 - h)
    ln_not_f = xu.log(1 - f)

    numerator = ln_f - ln_h + ln_not_h - ln_not_f
    denominator = ln_h + ln_f + ln_not_h + ln_not_f
    return numerator / denominator


class Reliability(base.PerVariableMetric):
  """Reliability / calibration curve.

  E.g. see
  https://scikit-learn.org/stable/auto_examples/calibration/plot_calibration_curve.html.

  This metric should be used for binary ground truth and probability
  predictions. Predictions are binned with `wrappers.ContinuousToBins`; the
  default `bin_values` supply eleven edges (-inf, 0.1, 0.2, ..., 1.0), which
  assumes the predictions are in [0, 1]. The edges and the name of the bin
  dimension can be changed through `bin_values` and `bin_dim`.

  For each bin of predicted probabilities, the metric computes the probability
  of the positive class according to the ground truth.
  """

  def __init__(
      self,
      bin_values: Sequence[float] = (
          -np.inf, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1.0
      ),
      bin_dim: str = 'reliability_bin',
      statistic_suffix: str | None = None,
  ):
    """Init.

    Args:
      bin_values: Bin edges handed to `wrappers.ContinuousToBins`.
      bin_dim: Name of the bin dimension handed to
        `wrappers.ContinuousToBins`.
      statistic_suffix: Forwarded to `wrappers.ContinuousToBins` as
        `unique_name_suffix`.
    """
    self._bin_values = bin_values
    self._bin_dim = bin_dim
    self._unique_name_suffix = statistic_suffix

  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    # One wrapper instance is shared by both wrapped statistics.
    bin_predictions = wrappers.ContinuousToBins(
        which='predictions',
        bin_values=self._bin_values,
        bin_dim=self._bin_dim,
        unique_name_suffix=self._unique_name_suffix,
    )
    true_positives = wrappers.WrappedStatistic(
        TruePositives(), bin_predictions
    )
    false_positives = wrappers.WrappedStatistic(
        FalsePositives(), bin_predictions
    )
    return {
        'TruePositives': true_positives,
        'FalsePositives': false_positives,
    }

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes the reliability curve from aggregated statistics.

    Args:
      statistic_values: Aggregated values keyed like `statistics`.

    Returns:
      TP / (TP + FP).
    """
    tp = statistic_values['TruePositives']
    fp = statistic_values['FalsePositives']
    return tp / (tp + fp)


class Confident(base.PerVariableStatisticWithClimatology):
  """Forecast confidence.

  Whether the prediction spread < threshold * climatological spread.
  """

  def __init__(
      self,
      ensemble_dim: str,
      climatology: xr.Dataset,
      spread_quantile_boundaries: tuple[float, float] = (0.1, 0.9),
      confidence_threshold: float = 0.7,
  ):
    """Init.

    Args:
      ensemble_dim: Dimension over which prediction quantiles are taken.
      climatology: Climatology passed on to the base class.
      spread_quantile_boundaries: (low, high) quantiles whose difference
        defines the spread.
      confidence_threshold: Factor applied to the climatological spread.
    """
    super().__init__(climatology)
    self._ensemble_dim = ensemble_dim
    self._spread_low, self._spread_high = spread_quantile_boundaries
    self._confidence_threshold = confidence_threshold

  @property
  def unique_name(self) -> str:
    return (
        f'Confident_conf_thres={self._confidence_threshold}'
        f'_spread_low={self._spread_low}'
        f'_spread_high={self._spread_high}'
    )

  def _compute_per_variable_with_aligned_climatology(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
      aligned_climatology: xr.DataArray,
  ) -> xr.DataArray:
    """Computes confidence per variable.

    Args:
      predictions: Predictions containing the ensemble dimension.
      targets: Unused.
      aligned_climatology: Climatology with a `quantile` coordinate that
        contains both spread quantiles.

    Returns:
      Boolean array: prediction spread < threshold * climatological spread.
    """
    del targets  # Unused.

    # Spread of the predictions: difference between the two quantiles taken
    # over the ensemble dimension.
    upper = predictions.quantile(self._spread_high, dim=self._ensemble_dim)
    lower = predictions.quantile(self._spread_low, dim=self._ensemble_dim)
    predictions_spread = upper - lower

    # Climatologies are already quantiles, so they are selected directly.
    climatology_upper = aligned_climatology.sel(quantile=self._spread_high)
    climatology_lower = aligned_climatology.sel(quantile=self._spread_low)
    climatology_spread = climatology_upper - climatology_lower

    return predictions_spread < self._confidence_threshold * climatology_spread


class Covered(base.PerVariableStatistic):
  """Forecast coverage.

  Whether the target lies within a prediction interval with specified quantile
  boundaries.
  """

  def __init__(
      self,
      ensemble_dim: str,
      interval_quantile_boundaries: tuple[float, float] = (0.1, 0.9),
  ):
    """Init.

    Args:
      ensemble_dim: Dimension over which prediction quantiles are taken.
      interval_quantile_boundaries: (low, high) quantiles bounding the
        prediction interval.
    """
    self._ensemble_dim = ensemble_dim
    self._interval_low, self._interval_high = interval_quantile_boundaries

  @property
  def unique_name(self) -> str:
    return (
        f'Covered_interval_low={self._interval_low}'
        f'_interval_high={self._interval_high}'
    )

  def _compute_per_variable(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
  ) -> xr.DataArray:
    """Computes coverage per variable.

    Returns:
      Boolean array that is True where the target lies inside the closed
      interval between the low and high prediction quantiles.
    """
    lower_bound = predictions.quantile(
        self._interval_low, dim=self._ensemble_dim
    )
    upper_bound = predictions.quantile(
        self._interval_high, dim=self._ensemble_dim
    )
    above_lower = lower_bound <= targets
    below_upper = targets <= upper_bound
    return above_lower & below_upper


class JaccardDistant(base.PerVariableStatisticWithClimatology):
  """Thresholded Jaccard distance of prediction interval from climatology.

  Whether the Jaccard distance between the forecast prediction interval and
  climatology prediction interval is greater than a threshold.

  Jaccard Distance is defined as 1 - |A ∩ B| / |A ∪ B|, where A is the set of
  points in the forecast interval and B is the set of points in the climatology
  interval.
  """

  def __init__(
      self,
      ensemble_dim: str,
      climatology: xr.Dataset,
      threshold: float = 0.75,
      interval_quantile_boundaries: tuple[float, float] = (0.1, 0.9),
  ):
    """Init.

    Args:
      ensemble_dim: Dimension over which prediction quantiles are taken.
      climatology: Climatology passed on to the base class.
      threshold: Jaccard distance above which the result is True.
      interval_quantile_boundaries: (low, high) quantiles bounding both
        intervals.
    """
    super().__init__(climatology)
    self._ensemble_dim = ensemble_dim
    self._threshold = threshold
    self._interval_low, self._interval_high = interval_quantile_boundaries

  @property
  def unique_name(self) -> str:
    return (
        f'JaccardDistant_threshold={self._threshold}'
        f'_interval_low={self._interval_low}'
        f'_interval_high={self._interval_high}'
    )

  def _compute_per_variable_with_aligned_climatology(
      self,
      predictions: xr.DataArray,
      targets: xr.DataArray,
      aligned_climatology: xr.DataArray,
  ) -> xr.DataArray:
    """Computes the thresholded Jaccard distance per variable.

    Args:
      predictions: Predictions containing the ensemble dimension.
      targets: Unused.
      aligned_climatology: Climatology with a `quantile` coordinate that
        contains both interval quantiles.

    Returns:
      Boolean array: Jaccard distance > threshold.
    """
    del targets  # Unused.

    forecast_low = predictions.quantile(
        self._interval_low, dim=self._ensemble_dim
    )
    forecast_high = predictions.quantile(
        self._interval_high, dim=self._ensemble_dim
    )
    climatology_low = aligned_climatology.sel(quantile=self._interval_low)
    climatology_high = aligned_climatology.sel(quantile=self._interval_high)

    # |A ∩ B| = min(max(A), max(B)) - max(min(A), min(B)). For disjoint
    # intervals this difference is negative, hence the floor at 0.
    highest_low = xu.maximum(forecast_low, climatology_low)
    lowest_high = xu.minimum(forecast_high, climatology_high)
    intersection_length = xu.maximum(0, lowest_high - highest_low)

    # |A ∪ B| = |A| + |B| - |A ∩ B|
    forecast_length = forecast_high - forecast_low
    climatology_length = climatology_high - climatology_low
    union_length = (
        forecast_length + climatology_length
    ) - intersection_length

    # A union of length 0 occurs if both intervals are identical single points
    # (e.g., [5, 5] and [5, 5]); the intersection length is then 0 as well.
    # That case is treated as perfect overlap, i.e. a Jaccard index of 1.
    jaccard_index = xr.where(
        union_length > 0, intersection_length / union_length, 1.0
    )
    return 1 - jaccard_index > self._threshold


class Opportunism(base.PerVariableMetric):
  """Opportunism.

  Fraction of forecast that is (un)confident, (un)covered, and
  (un)jaccard-distant. It is computed as the product of the aggregated
  `Confident`, `Covered` and `JaccardDistant` values, each replaced by its
  complement when the corresponding flag is False.
  """

  def __init__(
      self,
      ensemble_dim: str,
      climatology: xr.Dataset,
      is_confident: bool,
      is_covered: bool | None = None,
      is_jaccard_distant: bool | None = None,
      confidence_quantile_boundaries: tuple[float, float] = (0.1, 0.9),
      coverage_quantile_boundaries: tuple[float, float] = (0.1, 0.9),
      jaccard_distance_quantile_boundaries: tuple[float, float] = (0.1, 0.9),
      confidence_threshold: float = 0.7,
      jaccard_distance_threshold: float = 0.75,
  ):
    """Initializes the Opportunism metric.

    Args:
      ensemble_dim: The dimension name of the ensemble.
      climatology: The climatology dataset.
      is_confident: If True the confident fraction is used, otherwise its
        complement.
      is_covered: If True the covered fraction is used, if False its
        complement. If not set, the coverage will not be computed.
      is_jaccard_distant: If True the jaccard-distant fraction is used, if
        False its complement. If not set, the jaccard-distance will not be
        computed.
      confidence_quantile_boundaries: Quantile boundaries for `Confident`.
      coverage_quantile_boundaries: Quantile boundaries for `Covered`.
      jaccard_distance_quantile_boundaries: Quantile boundaries for
        `JaccardDistant`.
      confidence_threshold: The threshold to use for confidence.
      jaccard_distance_threshold: The threshold to use for jaccard-distance.
    """
    self._ensemble_dim = ensemble_dim
    self._climatology = climatology

    self._is_confident = is_confident
    self._is_covered = is_covered
    self._is_jaccard_distant = is_jaccard_distant

    self._confidence_quantile_boundaries = confidence_quantile_boundaries
    self._coverage_quantile_boundaries = coverage_quantile_boundaries
    self._jaccard_distance_quantile_boundaries = (
        jaccard_distance_quantile_boundaries
    )

    self._confidence_threshold = confidence_threshold
    self._jaccard_distance_threshold = jaccard_distance_threshold

  @final
  @property
  def statistics(self) -> Mapping[str, base.Statistic]:
    # Confidence is always needed.
    required = {
        'Confident': Confident(
            ensemble_dim=self._ensemble_dim,
            climatology=self._climatology,
            spread_quantile_boundaries=self._confidence_quantile_boundaries,
            confidence_threshold=self._confidence_threshold,
        ),
    }
    # Coverage and jaccard-distance are only requested when their flag is set.
    if self._is_covered is not None:
      required['Covered'] = Covered(
          ensemble_dim=self._ensemble_dim,
          interval_quantile_boundaries=self._coverage_quantile_boundaries,
      )
    if self._is_jaccard_distant is not None:
      required['JaccardDistant'] = JaccardDistant(
          ensemble_dim=self._ensemble_dim,
          climatology=self._climatology,
          threshold=self._jaccard_distance_threshold,
          interval_quantile_boundaries=self._jaccard_distance_quantile_boundaries,
      )
    return required

  def _values_from_mean_statistics_per_variable(
      self,
      statistic_values: Mapping[str, xr.DataArray],
  ) -> xr.DataArray:
    """Computes opportunism per variable.

    Args:
      statistic_values: Aggregated values keyed like `statistics`.

    Returns:
      Product of the selected values (or their complements).
    """

    def oriented(value, wanted):
      # Keep the value when the property is wanted, else use its complement.
      return value if wanted else 1 - value

    opportunism = oriented(statistic_values['Confident'], self._is_confident)
    if self._is_covered is not None:
      opportunism = opportunism * oriented(
          statistic_values['Covered'], self._is_covered
      )
    if self._is_jaccard_distant is not None:
      opportunism = opportunism * oriented(
          statistic_values['JaccardDistant'], self._is_jaccard_distant
      )
    return opportunism