# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Binning class definitions."""
import abc
from typing import Any, Hashable, Mapping, Optional, Sequence, Tuple, Union
import numpy as np
import xarray as xr
[docs]
class Binning(abc.ABC):
  """Abstract interface shared by all binning strategies.

  Attributes:
    bin_dim_name: Name of the dimension that indexes the bins.
  """

  def __init__(self, bin_dim_name: str):
    """Stores the name of the bin dimension.

    Args:
      bin_dim_name: Name of binning dimension.
    """
    self.bin_dim_name = bin_dim_name

  @abc.abstractmethod
  def create_bin_mask(self, statistic: xr.DataArray) -> xr.DataArray:
    """Builds the bin mask for a single statistic.

    Everything needed to compute the bins is assumed to be contained in the
    statistic itself.

    Args:
      statistic: Individual DataArray with statistic values.

    Returns:
      Boolean mask whose shape broadcasts against the statistic DataArray.
    """
def _create_lat_mask(
lat: xr.DataArray, lat_lims: Tuple[int, int]
) -> xr.DataArray:
"""Computes a boolean mask for a latitude limits region."""
if lat_lims[0] >= lat_lims[1]:
raise ValueError(
f'`lat_lims[0]` must be smaller than `lat_lims[1]`, got {lat_lims}`'
)
return np.logical_and(lat >= lat_lims[0], lat <= lat_lims[1])
def _create_lon_mask(
lon: xr.DataArray, lon_lims: Tuple[int, int]
) -> xr.DataArray:
"""Computes a boolean mask for a longitude limits region."""
# Make sure we are in the [0, 360] interval.
lon = np.mod(lon, 360)
lon_lims = np.mod(lon_lims[0], 360), np.mod(lon_lims[1], 360)
if lon_lims[1] > lon_lims[0]:
# Same as the latitude.
lon_mask = np.logical_and(lon >= lon_lims[0], lon <= lon_lims[1])
else:
# In this case it means we need to wrap longitude around the other side of
# the globe.
lon_mask = np.logical_or(lon <= lon_lims[1], lon >= lon_lims[0])
return lon_mask
def _region_to_mask(
    lat: xr.DataArray,
    lon: xr.DataArray,
    lat_lims: Tuple[int, int],
    lon_lims: Tuple[int, int],
) -> xr.DataArray:
  """Computes a boolean mask for a rectangular lat/lon region.

  Args:
    lat: Latitude values.
    lon: Longitude values.
    lat_lims: Latitude limits, forwarded to `_create_lat_mask`.
    lon_lims: Longitude limits, forwarded to `_create_lon_mask`.

  Returns:
    Logical AND of the latitude mask and the longitude mask.
  """
  return np.logical_and(
      _create_lat_mask(lat, lat_lims),
      _create_lon_mask(lon, lon_lims),
  )
class LandSea(Binning):
  """Class for land/sea mask binning."""

  def __init__(
      self,
      land_sea_fraction: xr.DataArray,
      land_sea_threshold: float = 0.5,
      bin_dim_name: str = 'land_sea',
      include_global_mask: bool = False,
  ):
    """Init.

    Args:
      land_sea_fraction: Floating point land-sea fraction with same latitude/
        longitude coordinates as the statistic. 100% land is represented as 1
        and 100% sea as 0.
      land_sea_threshold: Threshold to classify as land. Computed as
        land_sea_fraction >= land_sea_threshold. (Default of 0.5 follows ECMWF
        convention).
      bin_dim_name: Name of binning dimension. Default: 'land_sea'
      include_global_mask: If True, the output mask will consist of ['land',
        'sea', 'global'], otherwise ['land', 'sea']. 'global' is the union of
        land and sea. Default: False.
    """
    super().__init__(bin_dim_name)
    # The comparison yields a boolean mask (True = land).
    self._land_mask = land_sea_fraction >= land_sea_threshold
    self._include_global_mask = include_global_mask

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Creates a bin mask for a statistic.

    Args:
      statistic: Individual DataArray with statistic values. It is not read;
        the mask is derived from the land-sea fraction given at construction.

    Returns:
      bin_mask: Boolean mask with output bins ['land', 'sea'], plus 'global'
        when `include_global_mask` was set.
    """
    # Logical negation keeps the sea mask boolean. `1 - mask` would promote
    # it, and therefore the whole concatenated result, to an integer dtype.
    masks = [self._land_mask, ~self._land_mask]
    labels = ['land', 'sea']
    if self._include_global_mask:
      masks.append(xr.ones_like(self._land_mask))
      labels.append('global')
    bin_mask = xr.concat(masks, dim=self.bin_dim_name)
    bin_mask.coords[self.bin_dim_name] = np.array(labels)
    return bin_mask
[docs]
class Regions(Binning):
  """Binning into rectangular lat/lon regions.

  The statistic's coordinates must be named `latitude` and `longitude`.
  """

  def __init__(
      self,
      regions: Mapping[Hashable, Tuple[Tuple[int, int], Tuple[int, int]]],
      bin_dim_name: str = 'region',
      land_sea_mask: Optional[xr.DataArray] = None,
  ):
    """Init.

    Args:
      regions: Dictionary specifying {name: ((lat_lims), (lon_lims))}.
      bin_dim_name: Name of binning dimension. Default: 'region'
      land_sea_mask: (Optional) Boolean mask (land = True) with same
        latitude/longitude coordinates as the statistic. If provided, a
        land-only bin named {region}_land is added for each region.
    """
    super().__init__(bin_dim_name)
    self._regions = regions
    self._land_sea_mask = land_sea_mask

  def _regions_to_masks(
      self,
      lat: xr.DataArray,
      lon: xr.DataArray,
  ) -> xr.DataArray:
    """Builds one mask per region and stacks them along the bin dimension."""
    per_region = []
    for name, (lat_lims, lon_lims) in self._regions.items():
      region_mask = _region_to_mask(lat, lon, lat_lims, lon_lims).expand_dims(
          dim=self.bin_dim_name, axis=0
      )
      region_mask.coords[self.bin_dim_name] = np.array([name])
      per_region.append(region_mask)
    return xr.concat(per_region, dim=self.bin_dim_name)

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Creates the region masks for a statistic.

    Args:
      statistic: Individual DataArray with `latitude` and `longitude`
        coordinates.

    Returns:
      Mask with one bin per region. When a land/sea mask was provided, the
      `{region}_land` bins are appended after the plain region bins.
    """
    masks = self._regions_to_masks(statistic.latitude, statistic.longitude)
    land_sea = self._land_sea_mask
    if land_sea is None:
      return masks

    # Latitudes are compared after sorting, longitudes in their given order.
    assert np.array_equal(
        np.sort(masks.latitude), np.sort(land_sea.latitude)
    ) and np.array_equal(
        masks.longitude, land_sea.longitude
    ), 'Land/sea mask coordinates do not match.'

    land_masks = masks * land_sea.astype(bool)
    land_masks.coords[self.bin_dim_name] = np.array(
        [f'{r}_land' for r in masks.coords[self.bin_dim_name].data]
    )
    return xr.concat([masks, land_masks], dim=self.bin_dim_name)
class LatitudeBins(Binning):
  """Bins a statistic into latitude bands of fixed width."""

  def __init__(
      self,
      degrees: float,
      lat_range: Tuple[int, int] = (-90, 90),
      bin_dim_name: str = 'latitude_bins',
  ):
    """Init.

    Args:
      degrees: Width of each latitude band in degrees.
      lat_range: Tuple of (min_lat, max_lat).
      bin_dim_name: Name of binning dimension.
    """
    super().__init__(bin_dim_name)
    self._degrees = degrees
    lat_min, lat_max = lat_range[0], lat_range[1]
    # Band start values. The last entry only closes the final band and is
    # skipped when the masks are built.
    self._lat_bins = np.arange(
        lat_min, lat_max + self._degrees, self._degrees
    )

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Creates one mask per latitude band.

    Args:
      statistic: Individual DataArray with a `latitude` coordinate.

    Returns:
      Mask broadcast to the statistic's shape with a leading bin dimension
      labelled by the starting latitude of each band.
    """

    def _band_mask(lat_start):
      band = _create_lat_mask(
          statistic.latitude, (lat_start, lat_start + self._degrees)
      )
      # Broadcast to the statistic's shape, then add the bin dimension.
      band = band.broadcast_like(statistic).expand_dims(
          dim=self.bin_dim_name, axis=0
      )
      band.coords[self.bin_dim_name] = np.array([lat_start])
      return band

    return xr.concat(
        [_band_mask(lat_start) for lat_start in self._lat_bins[:-1]],
        dim=self.bin_dim_name,
    )
class LongitudeBins(Binning):
  """Bins a statistic into longitude bands of fixed width."""

  def __init__(
      self,
      degrees: float,
      lon_range: Tuple[int, int] = (0, 360),
      bin_dim_name: str = 'longitude_bins',
  ):
    """Init.

    Args:
      degrees: Width of each longitude band in degrees.
      lon_range: Tuple of (min_lon, max_lon).
      bin_dim_name: Name of binning dimension.
    """
    super().__init__(bin_dim_name)
    self._degrees = degrees
    lon_min, lon_max = lon_range[0], lon_range[1]
    # A range whose end is not larger than its start wraps around the globe,
    # so the end is shifted by a full turn to keep the band starts increasing.
    lon_end = lon_max + 360 if lon_min >= lon_max else lon_max
    self._lon_bins = np.arange(
        lon_min, lon_end + self._degrees, self._degrees
    )

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Creates one mask per longitude band.

    Args:
      statistic: Individual DataArray with a `longitude` coordinate.

    Returns:
      Mask broadcast to the statistic's shape with a leading bin dimension
      labelled by the starting longitude of each band, modulo 360.
    """

    def _band_mask(lon_start):
      band = _create_lon_mask(
          statistic.longitude, (lon_start, lon_start + self._degrees)
      )
      # Broadcast to the statistic's shape, then add the bin dimension.
      band = band.broadcast_like(statistic).expand_dims(
          dim=self.bin_dim_name, axis=0
      )
      band.coords[self.bin_dim_name] = np.array([np.mod(lon_start, 360)])
      return band

    return xr.concat(
        [_band_mask(lon_start) for lon_start in self._lon_bins[:-1]],
        dim=self.bin_dim_name,
    )
def vectorized_coord_mask(
    coord: xr.DataArray,
    coord_name: str,
    bin_dim_name: str,
    add_global_bin: bool = False,
) -> xr.DataArray:
  """Creates one boolean bin mask per unique value of a coordinate.

  Args:
    coord: Coordinate values to bin by.
    coord_name: Name of the coordinate variable that is dropped from the
      global bin before the bin dimension is added to it.
    bin_dim_name: Name of the leading dimension that indexes the bins.
    add_global_bin: If True, prepend a 'global' bin that is True everywhere.

  Returns:
    Boolean DataArray with dims `[bin_dim_name, *coord.dims]`, labelled along
    `bin_dim_name` by the unique coordinate values. If the global bin is
    added and its label dtype differs from the unique values' dtype, all bin
    labels are cast to strings.
  """
  unique_coord = np.unique(coord)
  ndims = len(coord.dims)
  # Compare against the unique values reshaped to (n_unique, 1, ..., 1) so a
  # single vectorized equality produces every mask. This also works in the
  # case of an empty statistic.
  masks = xr.DataArray(
      np.equal(coord.values, unique_coord.reshape((-1,) + (1,) * ndims)),
      coords={bin_dim_name: unique_coord}
      | {dim: coord[dim] for dim in coord.dims},
      dims=[bin_dim_name] + list(coord.dims),
  )
  if not add_global_bin:
    return masks

  # `drop_vars` replaces the deprecated `DataArray.drop` for removing a
  # coordinate variable by name.
  global_mask = (
      xr.ones_like(coord.astype(bool))
      .drop_vars(coord_name)  # Drop the coordinate
      .expand_dims(bin_dim_name)  # Add as a dimension
  )
  global_mask.coords[bin_dim_name] = ['global']
  # Dtypes of bin coordinates need to match. If they don't, cast both to str.
  if global_mask[bin_dim_name].dtype != masks[bin_dim_name].dtype:
    masks.coords[bin_dim_name] = masks[bin_dim_name].astype('str')
    global_mask.coords[bin_dim_name] = global_mask[bin_dim_name].astype('str')
  return xr.concat([global_mask, masks], dim=bin_dim_name)
[docs]
class ByExactCoord(Binning):
  """Binning by unique coordinate values.

  One bin is created per unique value of the coordinate, for example one per
  unique lead time for sparse forecasts where lead_time is a coordinate but
  not a dimension.
  """

  def __init__(self, coord: str, add_global_bin: bool = False):
    """Init.

    Args:
      coord: Name of coordinate to bin by. It is also used as the name of the
        bin dimension.
      add_global_bin: If True, add a global bin containing all data. Default:
        False.
    """
    super().__init__(coord)
    self.coord = coord
    self.add_global_bin = add_global_bin

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Creates one mask per unique value of the configured coordinate.

    Args:
      statistic: Individual DataArray carrying the coordinate as a
        non-dimension coordinate.

    Returns:
      Masks produced by `vectorized_coord_mask`.
    """
    assert (
        self.coord not in statistic.dims
    ), 'For dimensions, specify reduce_dims in aggregation.'
    # The coordinate name doubles as the bin dimension name here.
    return vectorized_coord_mask(
        statistic[self.coord],
        coord_name=self.coord,
        bin_dim_name=self.coord,
        add_global_bin=self.add_global_bin,
    )
def _extract_time_unit(
time_coord: xr.DataArray,
unit: str,
) -> xr.DataArray:
"""Extract time unit values from a datetime/timedelta coordinate.
Args:
time_coord: A datetime64 or timedelta64 xarray DataArray.
unit: Time unit to extract, e.g. 'second', 'minute', 'hour', 'day', 'week',
'year', 'month', 'dayofyear', etc.
Returns:
DataArray containing the extracted time unit values.
Raises:
ValueError: If the unit is not supported for timedelta coordinates.
"""
dt = time_coord.dt
if isinstance(dt, xr.core.accessor_dt.TimedeltaAccessor):
coord = time_coord.dt.total_seconds()
if unit == 'minute':
coord = coord // (60)
elif unit == 'hour':
coord = coord // (60 * 60)
elif unit == 'day':
coord = coord // (60 * 60 * 24)
elif unit == 'week':
coord = coord // (60 * 60 * 24 * 7)
elif unit == 'year':
coord = coord // (60 * 60 * 24 * 365)
elif unit != 'second':
raise ValueError(f'Unsupported unit for timedelta: {unit}')
else:
assert isinstance(dt, xr.core.accessor_dt.DatetimeAccessor)
coord = getattr(time_coord.dt, unit)
return coord
[docs]
class ByTimeUnit(Binning):
  """Bin by time unit for given axis.

  Values are extracted through xarray's `.dt` accessor, so both datetime64
  and timedelta64 coordinates work. Units follow the datetime64 convention,
  i.e. 'second', 'minute', 'hour', etc.

  See:
  https://docs.xarray.dev/en/latest/generated/xarray.core.accessor_dt.DatetimeAccessor.html

  Example:
    ```
    unit = 'hour'
    time_dim = 'init_time'
    ```
    This will aggregate together all data initialized at the same time of day,
    e.g. [0, 1, 2, .., 23].
  """

  def __init__(self, unit: str, time_dim: str, add_global_bin: bool = False):
    """Init.

    Args:
      unit: Time unit to bin by.
      time_dim: Time dimension to bin by.
      add_global_bin: If True, add a global bin containing all data. Default:
        False.
    """
    super().__init__(f'{time_dim}_{unit}')
    self.unit = unit
    self.time_dim = time_dim
    self.add_global_bin = add_global_bin

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Creates one mask per unique value of the extracted time unit.

    Args:
      statistic: Individual DataArray carrying the `time_dim` coordinate.

    Returns:
      Masks produced by `vectorized_coord_mask`, with the bin dimension named
      `{time_dim}_{unit}`.
    """
    unit_values = _extract_time_unit(statistic[self.time_dim], self.unit)
    return vectorized_coord_mask(
        unit_values,
        coord_name=self.time_dim,
        bin_dim_name=f'{self.time_dim}_{self.unit}',
        add_global_bin=self.add_global_bin,
    )
class ByTimeUnitSets(Binning):
  """Bin by sets of time unit values for a given axis.

  Combines the time unit extraction of ByTimeUnit with the set-based binning
  of BySets, so arbitrary groups of time unit values can share a bin, for
  example hours 0 and 12 in one bin and hours 6 and 18 in another.

  Example:
    ```
    sets = {'00/12': [0, 12], '06/18': [6, 18]}
    unit = 'hour'
    dim = 'init_time'
    ```
    This will create two bins: one for data initialized at hours 0 or 12, and
    another for data initialized at hours 6 or 18.
  """

  def __init__(
      self,
      sets: Mapping[str, Sequence[Any] | Any],
      unit: str,
      dim: str,
      bin_dim_name: Optional[str] = None,
      add_global_bin: bool = False,
  ):
    """Init.

    Args:
      sets: Dictionary specifying sets of time unit values to bin by. Keys are
        bin names, values are sequences of time unit values (e.g. hours) or a
        single value.
      unit: Time unit to extract, e.g. 'hour', 'day', 'month', 'dayofyear'.
      dim: Time dimension/coordinate to bin by.
      bin_dim_name: Name of binning dimension. Default: `{dim}_{unit}_sets`.
      add_global_bin: If True, add a global bin containing all data. Default:
        False.
    """
    super().__init__(
        f'{dim}_{unit}_sets' if bin_dim_name is None else bin_dim_name
    )
    self.sets = sets
    self.unit = unit
    self.dim = dim
    self.add_global_bin = add_global_bin

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Creates one mask per named set of time unit values.

    Args:
      statistic: Individual DataArray carrying the `dim` coordinate.

    Returns:
      Masks concatenated along the bin dimension in the order of `sets`,
      followed by the 'global' bin when requested.
    """
    unit_values = _extract_time_unit(statistic[self.dim], self.unit)
    bin_masks = []
    for label, members in self.sets.items():
      # A non-string sequence is a collection of values; anything else is
      # treated as a single value.
      is_collection = isinstance(members, (Sequence,)) and not isinstance(
          members, str
      )
      members = np.array(list(members) if is_collection else [members])
      bin_mask = unit_values.isin(members).expand_dims(
          self.bin_dim_name, axis=0
      )
      bin_mask.coords[self.bin_dim_name] = [label]
      bin_masks.append(bin_mask)

    if self.add_global_bin:
      everything = xr.full_like(unit_values, True, dtype=bool).expand_dims(
          self.bin_dim_name
      )
      everything.coords[self.bin_dim_name] = ['global']
      bin_masks.append(everything)

    return xr.concat(bin_masks, self.bin_dim_name)
class ByTimeUnitFromSeconds(Binning):
  """Similar to ByTimeUnit, but with the coordinate in seconds as a scalar.

  The seconds values are converted to the requested time unit. This is useful
  when the computation is wrapped in jax.jit, which does not support
  datetime64/timedelta64 coordinates.
  """

  def __init__(
      self, unit: str, time_dim: str, bins: Sequence[int] | None = None
  ):
    """Init.

    Args:
      unit: Time unit to bin by, one of 'second', 'minute', 'hour'.
      time_dim: Time dimension to bin by.
      bins: Sequence of bins to bin by. If None, will use default bins depending
        on the unit (e.g. 0 through 23 for hour). Note that these defaults won't
        always make sense (e.g. if binning by lead time, hours can be > 23).
    """
    super().__init__(f'{time_dim}_{unit}')
    self.unit = unit
    self.time_dim = time_dim
    self.bins = bins

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Creates one mask per bin value of the converted seconds coordinate.

    Args:
      statistic: Individual DataArray whose `time_dim` coordinate holds
        seconds.

    Returns:
      Mask that is True where the converted coordinate equals the bin value,
      with the bin dimension named `{time_dim}_{unit}`.

    Raises:
      ValueError: If the unit is not 'second', 'minute' or 'hour'.
    """
    coord = statistic[self.time_dim]
    # unit -> (seconds per unit, number of default bins). Seconds are used
    # as-is, so they have no divisor.
    unit_specs = {
        'second': (None, 60),
        'minute': (60, 60),
        'hour': (60 * 60, 24),
    }
    if self.unit not in unit_specs:
      raise ValueError(f'Unsupported unit: {self.unit}')
    divisor, default_count = unit_specs[self.unit]
    if divisor is not None:
      coord = coord // divisor
    bins = self.bins if self.bins is not None else np.arange(0, default_count)

    bin_dim_name = f'{self.time_dim}_{self.unit}'
    bin_values = xr.DataArray(bins, dims=[bin_dim_name]).broadcast_like(coord)
    masks = coord == bin_values
    return masks.assign_coords({bin_dim_name: bins})
[docs]
class ByCoordBins(Binning):
  """Binning by specified bins over a coordinate.

  Consecutive pairs of `bin_edges` define half-open bins `[left, right)`.
  """

  def __init__(
      self,
      dim_name: str,
      bin_edges: np.ndarray,
      add_global_bin: bool = False,
  ):
    """Init.

    Args:
      dim_name: Name of dimension to bin by.
      bin_edges: Bin edges to bin by.
      add_global_bin: If True, add a global bin containing all data. Default:
        False.
    """
    super().__init__(dim_name)
    self.dim_name = dim_name
    self.bin_edges = bin_edges
    self.add_global_bin = add_global_bin

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Creates one mask per bin defined by consecutive bin edges.

    Args:
      statistic: Individual DataArray carrying the `dim_name` coordinate.

    Returns:
      Boolean masks concatenated along `dim_name`. Each bin is labelled by its
      left edge (as a string when a global bin is requested) and carries
      `{dim_name}_left_edge` / `{dim_name}_right_edge` coordinates. The
      optional 'global' bin comes last and has NaN (or NaT) edges. If no bin
      can be formed, an empty-bin array shaped like the statistic is returned.
    """
    dim = self.dim_name
    left_name = dim + '_left_edge'
    right_name = dim + '_right_edge'

    def _attach_edges(mask, left, right):
      # `assign_coords` returns a new object rather than modifying `mask`, so
      # the result has to be kept for the edge coordinates to take effect.
      return mask.assign_coords({
          left_name: xr.DataArray([left], dims=[dim]),
          right_name: xr.DataArray([right], dims=[dim]),
      })

    masks = []
    # TODO(srasp): Potentially optimize using np.digitize.
    for start, stop in zip(self.bin_edges[:-1], self.bin_edges[1:]):
      mask = np.logical_and(
          statistic.coords[dim] >= start,
          statistic.coords[dim] < stop,
      )
      mask = mask.drop_vars([dim]).expand_dims(dim, axis=0)
      # Labels must share a dtype with the 'global' label when it is added.
      coord_val = str(start) if self.add_global_bin else start
      mask.coords[dim] = np.array([coord_val])
      masks.append(_attach_edges(mask, start, stop))

    if self.add_global_bin:
      mask = xr.full_like(statistic.coords[dim], True, dtype=bool)
      mask = mask.drop_vars([dim]).expand_dims(dim, axis=0)
      mask.coords[dim] = ['global']
      if masks:
        # Concatenation requires every member to carry the edge coordinates.
        # The global bin has no edges, so it gets a missing value that
        # matches the dtype of the stored edges.
        edge_dtype = masks[0].coords[left_name].dtype
        no_edge = edge_dtype.type('NaT') if edge_dtype.kind in 'mM' else np.nan
        mask = _attach_edges(mask, no_edge, no_edge)
      masks.append(mask)

    if not masks:  # Catch possibility of empty input arrays.
      dtype = statistic[dim].dtype
      return (
          xr.ones_like(statistic)
          .drop_vars(dim)
          .expand_dims(
              {dim: xr.DataArray([], dims=[dim]).astype(dtype)},
              axis=0,
          )
      )
    return xr.concat(masks, dim)
[docs]
class BySets(Binning):
  """Bin by sets of values along a coordinate.

  This is, for example, useful for binning by different sets of station names.
  """

  def __init__(
      self,
      sets: Mapping[str, Sequence[Any] | Any],
      coord_name: str,
      bin_dim_name: Optional[str] = None,
      add_set_complements: bool = False,
      add_global_bin: bool = False,
  ):
    """Init.

    Args:
      sets: Dictionary specifying sets of values to bin by.
      coord_name: Name of coordinate to bin over.
      bin_dim_name: Name of binning dimension. Must be given and must differ
        from `coord_name`.
      add_set_complements: If True, for each set, also add a bin for all values
        not in the set.
      add_global_bin: If True, add a global bin containing all data. Default:
        False.

    Raises:
      ValueError: If `bin_dim_name` is None or equal to `coord_name`.
    """
    if bin_dim_name is None or bin_dim_name == coord_name:
      raise ValueError(
          'bin_dim_name must be defined and be different from coord_name.'
      )
    super().__init__(bin_dim_name)
    self.sets = sets
    self.coord_name = coord_name
    self.add_set_complements = add_set_complements
    self.add_global_bin = add_global_bin

  def create_bin_mask(
      self,
      statistic: Union[xr.DataArray, xr.Dataset],
  ) -> xr.DataArray:
    """Creates one mask per named set of coordinate values.

    Args:
      statistic: DataArray or Dataset carrying the `coord_name` coordinate.

    Returns:
      Masks concatenated along the bin dimension in the order of `sets`. Each
      set is directly followed by its `not_in_{name}` complement when
      complements are requested, and the 'global' bin comes last when
      requested.
    """
    bin_dim = self.bin_dim_name
    coord = statistic[self.coord_name]
    bin_masks = []
    for label, members in self.sets.items():
      # Normalize to a numpy array so that different input types behave the
      # same with `isin` and remain compatible with JAX. A non-string
      # sequence is a collection of values; anything else is a single value.
      is_collection = isinstance(members, (Sequence,)) and not isinstance(
          members, str
      )
      members = np.array(list(members) if is_collection else [members])
      in_set = coord.isin(members).expand_dims(bin_dim, axis=0)
      in_set.coords[bin_dim] = [label]
      bin_masks.append(in_set)
      if self.add_set_complements:
        complement = ~in_set.copy()
        complement.coords[bin_dim] = [f'not_in_{label}']
        bin_masks.append(complement)

    if self.add_global_bin:
      # All-True mask with the bin dimension added.
      everything = xr.full_like(coord, True, dtype=bool).expand_dims(bin_dim)
      everything.coords[bin_dim] = ['global']
      bin_masks.append(everything)

    return xr.concat(bin_masks, bin_dim)