Wrap metrics/compute binary metrics
If we want to compute binary metrics such as the CSI (Critical Success Index) from real-valued forecasts, the data must be thresholded first. The metrics and their underlying statistics expect binary inputs, so we do this by wrapping the metric with input transforms.
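For reference, the CSI is defined on the contingency table of binary forecast/observation pairs: hits (event forecast and observed), misses (event observed but not forecast) and false alarms (event forecast but not observed). A real-valued field has no such counts until every value is mapped to event or no-event, which is why the thresholding step is required:

$$
\mathrm{CSI} = \frac{\text{hits}}{\text{hits} + \text{misses} + \text{false alarms}}
$$

For example, 2 hits, 0 misses and 1 false alarm give CSI = 2/3. Correct negatives (no event forecast, none observed) do not enter the score.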
Let’s take the example of an ensemble forecast for which we want to compute the CSI. This requires several transforms on the prediction and target data.
First, the continuous, real-valued data (here total_precipitation_6hr) are converted to binary values based on one or more threshold values; this is done for both the forecasts and the targets. Then, the binary ensemble members are averaged to produce a probability forecast for each threshold. Finally, the probability forecasts are thresholded by probability values to produce a binary output for which we can compute the CSI. Let’s load the data and apply all the wrappers around the CSI metric.
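The three steps can be sketched in plain NumPy on a tiny ensemble to show what each transform does to the array shapes. This is not WeatherBench-X's implementation: the data are made up, the `csi` helper is hypothetical, and the sketch assumes a value counts as an event when it is greater than or equal to its threshold (the library's exact comparison convention may differ).

```python
import numpy as np


def csi(pred, obs, axis=-1):
    """Hypothetical helper: CSI = hits / (hits + misses + false alarms)."""
    hits = np.sum(pred & obs, axis=axis)
    misses = np.sum(~pred & obs, axis=axis)
    false_alarms = np.sum(pred & ~obs, axis=axis)
    denom = hits + misses + false_alarms
    # Return NaN where no event was forecast or observed at all.
    return np.where(denom > 0, hits / np.maximum(denom, 1), np.nan)


# Made-up ensemble: 4 members x 3 grid points, precipitation in metres.
ens = np.array([
    [0.0, 0.002, 0.006],
    [0.0, 0.0005, 0.007],
    [0.002, 0.003, 0.0],
    [0.0, 0.004, 0.0055],
])
obs = np.array([0.0, 0.0025, 0.0052])

precip_thresholds = np.array([0.001, 0.005])  # 1 mm and 5 mm, in metres
prob_thresholds = np.array([0.25, 0.75])

# Step 1: continuous -> binary, for forecasts and targets alike.
# ens_binary: (threshold_precip, member, point); obs_binary: (threshold_precip, point).
ens_binary = ens[np.newaxis] >= precip_thresholds[:, None, None]
obs_binary = obs[np.newaxis] >= precip_thresholds[:, None]

# Step 2: the mean over members of the binary values is the event probability.
prob = ens_binary.mean(axis=1)  # (threshold_precip, point)

# Step 3: threshold the probabilities (forecasts only).
pred_binary = prob[:, None, :] >= prob_thresholds[None, :, None]  # (tp, tq, point)

# One CSI per (precipitation threshold, probability threshold), pooled over points.
scores = csi(pred_binary, obs_binary[:, None, :], axis=-1)
print(scores)
```

Each thresholding step adds one axis, so `scores` has shape (2, 2): one CSI value per combination of precipitation threshold and probability threshold. The target never receives the probability axis; it broadcasts against the forecast instead.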
# IMPORTANT: If you are running this on Colab, uncomment the lines below to access the cloud datasets.
# from google.colab import auth
# auth.authenticate_user()
import numpy as np
from weatherbenchX import aggregation
from weatherbenchX.data_loaders import xarray_loaders
from weatherbenchX.metrics import categorical
from weatherbenchX.metrics import wrappers
prediction_path = 'gs://weatherbench2/datasets/ifs_ens/2018-2022-64x32_equiangular_conservative.zarr'
target_path = 'gs://weatherbench2/datasets/era5/1959-2022-6h-64x32_equiangular_conservative.zarr'
variables = ['total_precipitation_6hr']
target_data_loader = xarray_loaders.TargetsFromXarray(
path=target_path,
variables=variables,
)
prediction_data_loader = xarray_loaders.PredictionsFromXarray(
path=prediction_path,
variables=variables,
)
init_times = np.array(['2020-01-01T00'], dtype='datetime64[ns]')
lead_times = np.array([6], dtype='timedelta64[h]').astype('timedelta64[ns]')  # Convert to ns to silence xarray warnings.
target_chunk = target_data_loader.load_chunk(init_times, lead_times)
prediction_chunk = prediction_data_loader.load_chunk(init_times, lead_times)
target_chunk
<xarray.Dataset> Size: 9kB
Dimensions: (latitude: 32, longitude: 64, init_time: 1,
lead_time: 1)
Coordinates:
* latitude (latitude) float64 256B -87.19 -81.56 ... 87.19
* longitude (longitude) float64 512B 0.0 5.625 ... 348.8 354.4
valid_time (init_time, lead_time) datetime64[ns] 8B 2020-01...
* init_time (init_time) datetime64[ns] 8B 2020-01-01
* lead_time (lead_time) timedelta64[ns] 8B 06:00:00
Data variables:
total_precipitation_6hr (init_time, lead_time, longitude, latitude) float32 8kB ...
Attributes:
long_name: Total precipitation
short_name: tp
units: m
prediction_chunk
<xarray.Dataset> Size: 411kB
Dimensions: (latitude: 32, longitude: 64, number: 50,
lead_time: 1, init_time: 1)
Coordinates:
* latitude (latitude) float64 256B -87.19 -81.56 ... 87.19
* longitude (longitude) float64 512B 0.0 5.625 ... 348.8 354.4
* number (number) int32 200B 1 2 3 4 5 6 ... 46 47 48 49 50
* lead_time (lead_time) timedelta64[ns] 8B 06:00:00
* init_time (init_time) datetime64[ns] 8B 2020-01-01
Data variables:
total_precipitation_6hr (init_time, number, lead_time, longitude, latitude) float32 410kB ...
Note that the transforms are applied in the order in which they appear in the list, so in this case ContinuousToBinary is applied first.
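Why the order matters can be seen in a small sketch that applies a list of transforms first to last, as described above. The functions `apply_in_order`, `to_binary` and `ensemble_mean` are hypothetical stand-ins written for this illustration; they are not the WeatherBench-X wrappers, and the data are made up.

```python
from functools import reduce

import numpy as np


def apply_in_order(x, transforms):
    """Apply each transform to the output of the previous one, first to last."""
    return reduce(lambda acc, transform: transform(acc), transforms, x)


def to_binary(x):
    return x >= 0.001  # hypothetical: event if at least 1 mm (values in metres)


def ensemble_mean(x):
    return np.mean(x, axis=0)  # average over the member axis


# Made-up 4-member ensemble at a single grid point, in metres.
members = np.array([0.0, 0.0, 0.0, 0.008])

threshold_first = apply_in_order(members, [to_binary, ensemble_mean])
mean_first = apply_in_order(members, [ensemble_mean, to_binary])
print(threshold_first, mean_first)
```

Only one of the four members exceeds 1 mm, so thresholding first yields an event probability of 0.25. Averaging first spreads that member's 8 mm over the whole ensemble (mean 2 mm) and returns a single definite event, which is no longer a probability forecast at all.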
wrapped_csi = wrappers.WrappedMetric(
metric=categorical.CSI(),
transforms=[
wrappers.ContinuousToBinary(
which='both',
threshold_value=[1/1000, 5/1000], # 1 mm and 5 mm; raw values are in m
threshold_dim='threshold_precipitation'
),
wrappers.EnsembleMean(
which='predictions', ensemble_dim='number'
),
wrappers.ContinuousToBinary(
which='predictions',
threshold_value=[0.25, 0.75],
threshold_dim='threshold_probability'
),
],
)
metrics = {'csi': wrapped_csi}
aggregator = aggregation.Aggregator(
reduce_dims=['init_time', 'latitude', 'longitude'],
)
aggregation.compute_metric_values_for_single_chunk(
metrics,
aggregator,
prediction_chunk,
target_chunk
)
<xarray.Dataset> Size: 72B
Dimensions: (lead_time: 1, threshold_precipitation: 2,
threshold_probability: 2)
Coordinates:
* lead_time (lead_time) timedelta64[ns] 8B 06:00:00
* threshold_precipitation (threshold_precipitation) float64 16B 0.001 ...
* threshold_probability (threshold_probability) float64 16B 0.25 0.75
Data variables:
csi.total_precipitation_6hr (lead_time, threshold_precipitation, threshold_probability) float64 32B ...
As we can see, the final result has two additional dimensions, threshold_precipitation and threshold_probability, one for each thresholding step.
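A rough NumPy sketch of where this output shape comes from. It assumes the CSI is formed from contingency counts pooled over the reduced dimensions (init_time, latitude, longitude) rather than by averaging per-grid-point CSI values, and it ignores any weighting the aggregator may apply; the binary fields are random placeholders, not the data loaded above.

```python
import numpy as np

rng = np.random.default_rng(0)

# Placeholder binary fields with dimension order
# (init_time, lead_time, threshold_precipitation, threshold_probability, lat, lon).
n_init, n_lead, n_tp, n_tq, n_lat, n_lon = 1, 1, 2, 2, 32, 64
pred = rng.random((n_init, n_lead, n_tp, n_tq, n_lat, n_lon)) < 0.3
# The target is only thresholded by precipitation, so it carries a size-1
# probability axis that broadcasts against the forecast.
obs = rng.random((n_init, n_lead, n_tp, 1, n_lat, n_lon)) < 0.3

reduce_axes = (0, 4, 5)  # init_time, latitude, longitude
hits = np.sum(pred & obs, axis=reduce_axes)
misses = np.sum(~pred & obs, axis=reduce_axes)
false_alarms = np.sum(pred & ~obs, axis=reduce_axes)

csi = hits / (hits + misses + false_alarms)
print(csi.shape)  # (1, 2, 2): lead_time x threshold_precipitation x threshold_probability
```

The ensemble dimension (number) does not appear in reduce_dims because EnsembleMean has already removed it before the metric is computed.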