Source code for weatherbenchX.binning

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Binning class definitions."""

import abc
from typing import Any, Hashable, Mapping, Optional, Sequence, Tuple, Union
import numpy as np
import xarray as xr


class Binning(abc.ABC):
  """Abstract interface shared by all binning strategies.

  A binning owns the name of the dimension along which its bins are laid out
  and knows how to turn a statistic into a boolean mask over those bins.
  """

  def __init__(self, bin_dim_name: str):
    """Stores the name of the bin dimension.

    Args:
      bin_dim_name: Name of the dimension that enumerates the bins.
    """
    self.bin_dim_name = bin_dim_name

  @abc.abstractmethod
  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Builds the bin mask for a single statistic.

    Implementations are expected to derive the bins solely from information
    carried by `statistic` itself.

    Args:
      statistic: Individual DataArray with statistic values.

    Returns:
      Boolean mask whose shape broadcasts against `statistic`.
    """
def _region_to_mask( lat: xr.DataArray, lon: xr.DataArray, lat_lims: Tuple[int, int], lon_lims: Tuple[int, int], ) -> xr.DataArray: """Computes a boolean mask for a lat/lon limits region.""" if lat_lims[0] >= lat_lims[1]: raise ValueError( f'`lat_lims[0]` must be smaller than `lat_lims[1]`, got {lat_lims}`' ) lat_mask = np.logical_and(lat >= lat_lims[0], lat <= lat_lims[1]) # Make sure we are in the [0, 360] interval. lon = np.mod(lon, 360) lon_lims = np.mod(lon_lims[0], 360), np.mod(lon_lims[1], 360) if lon_lims[1] > lon_lims[0]: # Same as the latitude. lon_mask = np.logical_and(lon >= lon_lims[0], lon <= lon_lims[1]) else: # In this case it means we need to wrap longitude around the other side of # the globe. lon_mask = np.logical_or(lon <= lon_lims[1], lon >= lon_lims[0]) return np.logical_and(lat_mask, lon_mask)
class Regions(Binning):
  """Bins a statistic into rectangular lat/lon regions.

  The statistic's coordinates must be named `latitude` and `longitude`.
  """

  def __init__(
      self,
      regions: Mapping[Hashable, Tuple[Tuple[int, int], Tuple[int, int]]],
      bin_dim_name: str = 'region',
      land_sea_mask: Optional[xr.DataArray] = None,
  ):
    """Init.

    Args:
      regions: Dictionary specifying {name: ((lat_lims), (lon_lims))}.
      bin_dim_name: Name of binning dimension. Default: 'region'
      land_sea_mask: (Optional) Boolean mask (land = True) with the same
        latitude/longitude coordinates as the statistic. If provided, every
        region additionally gets a land-only bin named {region}_land.
    """
    super().__init__(bin_dim_name)
    self._regions = regions
    self._land_sea_mask = land_sea_mask

  def _regions_to_masks(
      self,
      lat: xr.DataArray,
      lon: xr.DataArray,
  ) -> xr.DataArray:
    """Builds one mask per region and stacks them along the bin dimension."""
    bin_dim = self.bin_dim_name
    per_region = []
    for name, limits in self._regions.items():
      lat_lims, lon_lims = limits
      region_mask = _region_to_mask(lat, lon, lat_lims, lon_lims).expand_dims(
          dim=bin_dim, axis=0
      )
      # Label the length-one bin dimension with the region name.
      region_mask.coords[bin_dim] = np.array([name])
      per_region.append(region_mask)
    return xr.concat(per_region, dim=bin_dim)

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Returns region masks, plus `{region}_land` masks if a mask was given."""
    bin_dim = self.bin_dim_name
    region_masks = self._regions_to_masks(
        statistic.latitude, statistic.longitude
    )
    land_sea_mask = self._land_sea_mask
    if land_sea_mask is None:
      return region_masks

    # Latitudes are compared after sorting, longitudes in their given order.
    assert (
        np.array_equal(
            np.sort(region_masks.latitude), np.sort(land_sea_mask.latitude)
        )
        and np.array_equal(region_masks.longitude, land_sea_mask.longitude)
    ), 'Land/sea mask coordinates do not match.'

    land_only = region_masks * land_sea_mask.astype(bool)
    land_only.coords[bin_dim] = np.array(
        [f'{name}_land' for name in region_masks.coords[bin_dim].data]
    )
    return xr.concat([region_masks, land_only], dim=bin_dim)
def vectorized_coord_mask(
    coord: xr.DataArray,
    coord_name: str,
    bin_dim_name: str,
    add_global_bin: bool = False,
) -> xr.DataArray:
  """Helper to create bin masks for unique coordinate values.

  Args:
    coord: Coordinate values to bin by; one bin is created per unique value.
    coord_name: Name of the coordinate variable that is dropped from the
      global mask before it is expanded along the bin dimension.
    bin_dim_name: Name of the new leading bin dimension.
    add_global_bin: If True, prepend an all-True bin labelled 'global'.

  Returns:
    Boolean DataArray with dimensions `[bin_dim_name, *coord.dims]`.
  """
  unique_coord = np.unique(coord)
  ndims = len(coord.dims)
  # Use vectorized equal. This also works in the case of empty statistic.
  masks = xr.DataArray(
      np.equal(coord.values, unique_coord.reshape((-1,) + (1,) * ndims)),
      coords={bin_dim_name: unique_coord}
      | {dim: coord[dim] for dim in coord.dims},
      dims=[bin_dim_name] + list(coord.dims),
  )
  if not add_global_bin:
    return masks

  global_mask = (
      xr.ones_like(coord.astype(bool))
      # `drop_vars` is the non-deprecated way of removing a coordinate
      # variable (`DataArray.drop` is deprecated for this purpose).
      .drop_vars(coord_name)
      .expand_dims(bin_dim_name)  # Add as a dimension
  )
  global_mask.coords[bin_dim_name] = ['global']
  # Dtypes of bin coordinates need to match. If they don't, cast both to str.
  if global_mask[bin_dim_name].dtype != masks[bin_dim_name].dtype:
    masks.coords[bin_dim_name] = masks[bin_dim_name].astype('str')
    global_mask.coords[bin_dim_name] = global_mask[bin_dim_name].astype('str')
  return xr.concat([global_mask, masks], dim=bin_dim_name)
class ByExactCoord(Binning):
  """One bin per unique value of a coordinate.

  Useful, for example, for sparse forecasts where `lead_time` is a coordinate
  but not a dimension: every distinct lead time gets its own bin.
  """

  def __init__(self, coord: str, add_global_bin: bool = False):
    """Init.

    Args:
      coord: Name of coordinate to bin by. It is also used as the name of the
        bin dimension.
      add_global_bin: If True, add a global bin containing all data. Default:
        False.
    """
    super().__init__(coord)
    self.coord = coord
    self.add_global_bin = add_global_bin

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Returns a mask with one bin for each unique value of the coordinate."""
    name = self.coord
    assert (
        name not in statistic.dims
    ), 'For dimensions, specify reduce_dims in aggregation.'
    # The coordinate name doubles as the bin dimension name here.
    return vectorized_coord_mask(
        coord=statistic[name],
        coord_name=name,
        bin_dim_name=name,
        add_global_bin=self.add_global_bin,
    )
class ByTimeUnit(Binning):
  """Bins by a datetime component of a time coordinate.

  Relies on xarray's `.dt` datetime accessor, so the coordinate must be
  datetime64. See:
  https://docs.xarray.dev/en/latest/generated/xarray.core.accessor_dt.DatetimeAccessor.html

  Example:
    ```
    unit = 'hour'
    time_dim = 'init_time'
    ```
    groups together all data initialized at the same time of day, e.g.
    [0, 1, 2, .., 23].
  """

  def __init__(self, unit: str, time_dim: str, add_global_bin: bool = False):
    # TODO(srasp): Add support for sequence of units.
    """Init.

    Args:
      unit: Time unit to bin by; looked up as an attribute of the `.dt`
        accessor.
      time_dim: Time dimension to bin by.
      add_global_bin: If True, add a global bin containing all data. Default:
        False.
    """
    super().__init__(f'{time_dim}_{unit}')
    self.unit = unit
    self.time_dim = time_dim
    self.add_global_bin = add_global_bin

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Returns a mask with one bin per unique value of the time unit."""
    datetime_accessor = statistic[self.time_dim].dt
    unit_values = getattr(datetime_accessor, self.unit)
    return vectorized_coord_mask(
        coord=unit_values,
        coord_name=self.time_dim,
        bin_dim_name=f'{self.time_dim}_{self.unit}',
        add_global_bin=self.add_global_bin,
    )
class ByCoordBins(Binning):
  """Binning by specified bins over a coordinate.

  Consecutive entries of `bin_edges` define half-open intervals
  `[left, right)`. Each bin is labelled by its left edge, and the edges are
  also attached as `<dim_name>_left_edge` / `<dim_name>_right_edge`
  coordinates along the bin dimension.
  """

  def __init__(self, dim_name: str, bin_edges: np.ndarray):
    """Init.

    Args:
      dim_name: Name of the coordinate to bin by. The same name is used for
        the bin dimension of the returned mask.
        NOTE(review): since a new dimension of this name is added to the mask,
        this presumably has to be a non-dimension coordinate of the statistic
        -- confirm against callers.
      bin_edges: Bin edges to bin by.
    """
    super().__init__(dim_name)
    self.dim_name = dim_name
    self.bin_edges = bin_edges

  def create_bin_mask(
      self,
      statistic: xr.DataArray,
  ) -> xr.DataArray:
    """Returns one boolean mask per `[left, right)` interval of `bin_edges`."""
    coord = statistic[self.dim_name]
    masks = []
    # TODO(srasp): Potentially optimize using np.digitize.
    for start, stop in zip(self.bin_edges[:-1], self.bin_edges[1:]):
      mask = np.logical_and(coord >= start, coord < stop)
      # Replace the original coordinate by a length-one bin dimension that is
      # labelled with the left edge of the bin.
      mask = mask.drop_vars([self.dim_name]).expand_dims(
          self.dim_name, axis=0
      )
      mask.coords[self.dim_name] = np.array([start])
      # `assign_coords` returns a new object, so the result must be kept;
      # otherwise the edge coordinates are silently lost.
      mask = mask.assign_coords({
          self.dim_name + '_left_edge': xr.DataArray(
              [start], dims=[self.dim_name]
          ),
          self.dim_name + '_right_edge': xr.DataArray(
              [stop], dims=[self.dim_name]
          ),
      })
      masks.append(mask)

    if masks:
      return xr.concat(masks, self.dim_name)

    # No bins were produced (`bin_edges` holds fewer than two edges): return
    # an array shaped like the statistic with a zero-length bin dimension.
    empty_bins = xr.DataArray([], dims=[self.dim_name]).astype(coord.dtype)
    return (
        xr.ones_like(statistic)
        .drop_vars(self.dim_name)
        .expand_dims({self.dim_name: empty_bins}, axis=0)
    )
class BySets(Binning):
  """Bins by membership in named sets of coordinate values.

  This is, for example, useful for binning by different sets of station names.
  """

  def __init__(
      self,
      sets: Mapping[str, Sequence[Any]],
      coord_name: str,
      bin_dim_name: Optional[str] = None,
      add_global_bin: bool = False,
  ):
    """Init.

    Args:
      sets: Dictionary mapping a bin name to the values belonging to that bin.
      coord_name: Name of coordinate to bin over.
      bin_dim_name: Name of binning dimension. Must be provided and must
        differ from `coord_name`.
      add_global_bin: If True, add a global bin containing all data. Default:
        False.

    Raises:
      ValueError: If `bin_dim_name` is None or equal to `coord_name`.
    """
    if bin_dim_name in (None, coord_name):
      raise ValueError(
          'bin_dim_name must be defined and be different from coord_name.'
      )
    super().__init__(bin_dim_name)
    self.sets = sets
    self.coord_name = coord_name
    self.add_global_bin = add_global_bin

  def create_bin_mask(
      self,
      statistic: Union[xr.DataArray, xr.Dataset],
  ) -> xr.DataArray:
    """Returns one membership mask per set, plus 'global' last if requested."""
    bin_dim = self.bin_dim_name
    stacked = []
    for set_name, members in self.sets.items():
      membership = statistic[self.coord_name].isin(members).expand_dims(
          bin_dim, axis=0
      )
      membership.coords[bin_dim] = [set_name]
      stacked.append(membership)

    if self.add_global_bin:
      # All-True mask over the coordinate, added as a further bin.
      everything = xr.full_like(statistic[self.coord_name], True, dtype=bool)
      everything = everything.expand_dims(bin_dim)
      everything.coords[bin_dim] = ['global']
      stacked.append(everything)

    return xr.concat(stacked, bin_dim)