Source code for weatherbenchX.aggregation

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of aggregation methods and AggregationState."""

import dataclasses
import math
from typing import Collection, Hashable, Mapping, Optional, Sequence

import numpy as np
from weatherbenchX import binning
from weatherbenchX import weighting
from weatherbenchX import xarray_tree
from weatherbenchX.metrics import base as metrics_base
import xarray as xr


def _combining_sum(
    data_arrays: Sequence[Optional[xr.DataArray]],
) -> Optional[xr.DataArray]:
  """A sum which combines / aligns coordinates in case they don't match.

  It's semantically equivalent to

  sum(xarray.align(*data_arrays, join='outer', fill_value=0))

  but when all summands have the same coordinates it's just a plain sum.

  Args:
    data_arrays: To be summed/combined. All must have the same set of
      dimensions and must have coordinates for all dimensions so we can be sure
      we have aligned them correctly. None values are allowed and will be
      dropped from the sum (treated as empty/zero).

  Returns:
    A data_array whose index coordinates along each dimension are the union of
    the index coordinates of all the arguments along that dimension, or None
    if there were no non-None arguments. A single non-None argument is
    returned as-is, without any validation.

  Raises:
    ValueError: If the non-None arrays don't all share the same set of
      dimensions, or if any of those dimensions lacks a coordinate.
  """
  # Arrays for individual statistics can be None if an AggregationMethod wasn't
  # able to handle that statistic (e.g. aggregation.Aggregator will return None
  # for statistics that don't contain the requested preserve_dims).
  data_arrays = [s for s in data_arrays if s is not None]
  if not data_arrays:
    return None

  if len(data_arrays) == 1:
    return data_arrays[0]

  dims = set(data_arrays[0].dims)
  for a in data_arrays[1:]:
    if set(a.dims) != dims:
      raise ValueError(
          f'Different dims encountered by _combining_sum: {a.dims} vs {dims}.'
      )

  for a in data_arrays:
    for dim in dims:
      if dim not in a.coords:
        raise ValueError(
            'All dimensions must have coordinates to ensure alignment when '
            f'summing statistics, but dimension {dim} lacked coordinates.'
        )

  # Fast path when index coordinates are all the same: with an 'exact' join,
  # xarray raises ValueError instead of silently aligning mismatched coords.
  with xr.set_options(arithmetic_join='exact'):
    try:
      return sum(data_arrays[1:], start=data_arrays[0])
    except ValueError:
      # Coordinates were not exactly aligned.
      pass

  # Potentially-slow but general path, the other paths above do the same thing
  # as this but may be faster.
  # This will extend each array to use the union of all the coordinates, padding
  # with zeros for any missing coordinates, and only sum after padding each
  # array. As such it may be quadratic in len(data_arrays) in the worst case.
  data_arrays = xr.align(*data_arrays, join='outer', fill_value=0, copy=False)
  return sum(data_arrays[1:], start=data_arrays[0])


def _fast_dot(
    a: xr.DataArray, b: xr.DataArray, reduce_dims: set[str]
) -> xr.DataArray:
  """Implementation of dot product semantically almost equivalent to xr.dot.

  Difference one is that this will only work if all reduce dims are present in
  both arrays. xr.dot will also sum out any reduce dims present in only one of
  the arrays. Here this is already taken care of in the calling function, so
  an error would be raised if this is not the case.

  Difference two is that this implementation requires that all common dims
  between a and b are also part of reduce_dims. In other words, any binning
  dimension must also be reduced.

  Difference three is that arrays are matched positionally along the reduce
  dims (after transposing), not aligned by coordinate labels.

  Preserved (non-reduced) dimensions don't need index coordinates; as with
  xr.dot, such dimensions simply stay coordinate-less in the output.

  Args:
    a: First array.
    b: Second array.
    reduce_dims: Set of dimensions to reduce over.

  Returns:
    The dot product of a and b, with dims ordered as a's sorted non-reduced
    dims followed by b's sorted non-reduced dims. Non-dimension coordinates
    that don't depend on any reduced dim are carried over.

  Raises:
    AssertionError: If a common dim of a and b is not in reduce_dims, or if a
      reduce dim is missing from either (non-empty) array.
  """
  # If data is empty, reshaping array below will fail. In this case, we can just
  # use the xarray implementation.
  if a.size == 0 or b.size == 0:
    return xr.dot(a, b, dim=reduce_dims)

  assert (
      not set(a.dims).intersection(set(b.dims)) - reduce_dims
  ), 'Not all common dims between a and b are also part of reduce_dims.'

  def reshape_data(x):
    """Flattens x into a 2D numpy array (non-reduce dims, reduce dims)."""
    array_dims = set(x.dims)
    assert reduce_dims.issubset(
        array_dims
    ), 'Not all reduce_dims are present in the array.'
    non_reduce_dims = sorted(array_dims - reduce_dims)
    # Only dims that actually have a coordinate can contribute one; indexing
    # x.coords with a coordinate-less dim would raise KeyError.
    coords = {c: x.coords[c] for c in non_reduce_dims if c in x.coords}
    ordered_dims = non_reduce_dims + sorted(reduce_dims)

    # Make sure we are also preserving non-dimension coordinates. But only if
    # they don't contain any of the reduce_dims.
    other_coords = {
        c: x.coords[c]
        for c in x.coords
        if c not in x.dims
        and not set(x.coords[c].dims) - set(non_reduce_dims)  # Must be empty.
    }
    len_dims = [x.sizes[c] for c in non_reduce_dims]
    x = x.transpose(*ordered_dims)
    # Switch to numpy for reshaping since this is faster than xarray's stack.
    # Array should now have two dimensions with shape:
    # (product(non_reduce_dims), product(reduce_dims)).
    x = x.values.reshape(math.prod(len_dims), -1)
    return x, non_reduce_dims, coords, other_coords, len_dims

  a, non_reduce_dims_a, coords_a, other_coords_a, len_dims_a = reshape_data(a)
  b, non_reduce_dims_b, coords_b, other_coords_b, len_dims_b = reshape_data(b)
  len_dims_out = len_dims_a + len_dims_b

  # Both arrays order their reduce dims identically (sorted), so contracting
  # the trailing axis of a against the trailing axis of b is a valid sum over
  # reduce_dims.
  out = np.dot(a, b.T)
  out = out.reshape(len_dims_out)
  out = xr.DataArray(
      out,
      dims=non_reduce_dims_a + non_reduce_dims_b,
      coords=coords_a | coords_b | other_coords_a | other_coords_b,
  )
  return out


@dataclasses.dataclass
class AggregationState:
  """Running totals of weighted statistics together with their weights.

  States from separate chunks (e.g. inside a Beam pipeline) can be added
  together; once everything has been summed, the totals are normalized to
  obtain mean statistics and, from those, metric values.

  Attributes:
    sum_weighted_statistics: Nested mapping (statistic name -> key ->
      DataArray) holding the summed weighted statistics, or None for the zero
      state.
    sum_weights: Nested mapping with the matching summed weights, or None for
      the zero state.
  """

  sum_weighted_statistics: Optional[
      Mapping[str, Mapping[Hashable, xr.DataArray]]
  ]
  sum_weights: Optional[Mapping[str, Mapping[Hashable, xr.DataArray]]]

  @classmethod
  def zero(cls) -> 'AggregationState':
    """Returns the additive identity: a state that holds no statistics."""
    return cls(sum_weighted_statistics=None, sum_weights=None)

  def __add__(self, other: 'AggregationState') -> 'AggregationState':
    return self.sum([self, other])

  @classmethod
  def sum(
      cls, aggregation_states: list['AggregationState']
  ) -> 'AggregationState':
    """Adds up a collection of aggregation states.

    States whose sum_weighted_statistics is None (zero states) are skipped.
    Leaves are combined with _combining_sum, which aligns mismatched
    coordinates.

    Args:
      aggregation_states: States to add together.

    Returns:
      The summed state, or the zero state if every input was a zero state.
    """
    pairs = []
    for state in aggregation_states:
      if state.sum_weighted_statistics is None:
        continue
      pairs.append((state.sum_weighted_statistics, state.sum_weights))

    # Beam sometimes reduces over nothing but zero states, leaving no pairs to
    # combine; the sum is then itself the zero state.
    if not pairs:
      return cls.zero()

    # Leaf-wise sum across all (statistics, weights) pairs at once.
    summed_statistics, summed_weights = xarray_tree.map_structure(
        lambda *leaves: _combining_sum(leaves),
        *pairs,
    )
    return cls(summed_statistics, summed_weights)

  def mean_statistics(self) -> Mapping[str, Mapping[Hashable, xr.DataArray]]:
    """Returns every summed weighted statistic divided by its summed weight."""
    return xarray_tree.map_structure(
        lambda weighted, weights: weighted / weights,
        self.sum_weighted_statistics,
        self.sum_weights,
    )

  def metric_values(
      self, metrics: Mapping[Hashable, metrics_base.Metric]
  ) -> xr.Dataset:
    """Computes metric values from the normalized (mean) statistics.

    Args:
      metrics: Mapping from metric name to metric instance.

    Returns:
      A Dataset whose variables are named '<metric>.<variable>'.
    """
    means = self.mean_statistics()
    result = xr.Dataset()
    for metric_name, metric in metrics.items():
      per_variable = metric.values_from_mean_statistics(means)
      for var_name, da in per_variable.items():
        result[f'{metric_name}.{var_name}'] = da
    return result
@dataclasses.dataclass
class Aggregator:
  """Defines aggregation over set of dataset dimensions.

  Note on NaNs: By default, all reductions are performed with skipna=False,
  meaning that the aggregated statistics will be NaN if any of the input
  statistics are NaN. Currently, there is one awkward use case, where even if
  the input NaNs are outside the binning mask, e.g. if NaNs appear in a
  different region from the binning region, the aggregated statistics will
  still be NaN. Use the masking option to avoid this.

  Attributes:
    reduce_dims: Dimensions to average over. Any variables that don't have
      these dimensions will be filtered out during aggregation.
    bin_by: List of binning instances. All bins will be multiplied.
    weigh_by: List of weighting instance. All weights will be multiplied.
    masked: If True, aggregation will only be performed for non-masked (True on
      the mask) values. This uses a 'mask' coordinate on the statistics passed
      to aggregate_statistics; statistics without a 'mask' coordinate are
      aggregated unmasked.
    skipna: If True, NaNs will be omitted in the aggregation. This option is
      not recommended, as it won't catch unexpected NaNs.
  """

  reduce_dims: Collection[str]
  bin_by: Sequence[binning.Binning] | None = None
  weigh_by: Sequence[weighting.Weighting] | None = None
  masked: bool = False
  skipna: bool = False

  def aggregation_fn(
      self,
      stat: xr.DataArray,
  ) -> xr.DataArray | None:
    """Weights, optionally bins, and sums a single statistic over reduce_dims.

    Args:
      stat: Statistic for a single variable. Masked-out values (and NaNs when
        skipna is set) are expected to already be zero; aggregate_statistics
        takes care of this.

    Returns:
      The weighted sum over reduce_dims (with bin dimensions added when bin_by
      is set), or None if stat lacks one of reduce_dims or one of the
      dimensions the bin masks depend on.

    Raises:
      ValueError: If the binning instances don't have unique bin_dim_names.
    """
    # Recall that masked out values have already been set to zero in
    # aggregate_statistics. The logic below has to respect this.

    # Take product of all weights and bins.
    if self.weigh_by is not None:
      weights = math.prod(
          [weighting_method.weights(stat) for weighting_method in self.weigh_by]
      )
      stat = stat * weights

    reduce_dims_set = set(self.reduce_dims)
    eval_unit_dims = set(stat.dims)
    if not reduce_dims_set.issubset(eval_unit_dims):
      # Can't reduce over dims that aren't present as evaluation unit dims.
      return None

    if self.bin_by is None:
      # Simple sum when no binning is applied.
      return stat.sum(reduce_dims_set, skipna=False)

    bin_masks = xr.DataArray(
        math.prod([
            binning_method.create_bin_mask(stat)
            for binning_method in self.bin_by
        ])
    )
    # The loop variable must not be called `binning`, which would read as the
    # module of the same name.
    bin_dim_names = {
        binning_method.bin_dim_name for binning_method in self.bin_by
    }
    if len(bin_dim_names) != len(self.bin_by):
      raise ValueError('Bin dimension names must be unique.')

    # bin_masks_dims are all of the dims the mask operate with on the input
    # data (e.g. the actual bin dimension does not count).
    bin_masks_dims = set(bin_masks.dims) - bin_dim_names
    if not bin_masks_dims.issubset(eval_unit_dims):
      # Can't bin based on dims that aren't present as evaluation unit dims:
      return None

    # These dimensions don't need preserving, and are also not explicitly
    # used by the masks, so we just sum them first.
    non_bin_index_reduce_dims = reduce_dims_set - bin_masks_dims
    stat = stat.sum(non_bin_index_reduce_dims, skipna=False)

    # Finally we compute the element-wise product, reducing only across
    # the bin masks dimensions that we are not preserving. The
    # bin_dim_names dimensions will always be preserved.
    bin_index_reduce_dims = reduce_dims_set - non_bin_index_reduce_dims
    shared_dims = set(bin_masks.dims) & set(stat.dims)
    if shared_dims <= bin_index_reduce_dims:
      # Every dim shared by masks and stat is reduced: _fast_dot applies.
      return _fast_dot(bin_masks, stat, bin_index_reduce_dims)
    # A mask depends on a dim that is preserved (not in reduce_dims).
    # _fast_dot can't contract such arrays, so use xr.dot, which keeps shared
    # non-reduced dims element-wise and sums only bin_index_reduce_dims.
    return xr.dot(bin_masks, stat, dim=bin_index_reduce_dims)

  def aggregate_statistics(
      self,
      statistics: Mapping[str, Mapping[Hashable, xr.DataArray]],
  ) -> AggregationState:
    """Aggregate all statistics for a batch.

    Args:
      statistics: Full statistics for a batch.

    Returns:
      AggregationState instance with a sum of weighted statistics and a sum of
      weights for the current batch. These can be summed over multiple batches,
      and then used to compute weighted mean statistics, and from these the
      final values of the metrics. Statistics for which aggregation_fn returns
      None are dropped.
    """

    # Different aggregator for each variable
    def batch_aggregator_for_var_and_stat(stat):
      if self.skipna:
        # Set NaNs to zero, so that they will be ignored in the sum.
        stat = stat.where(~stat.isnull(), 0)
      if self.masked and hasattr(stat, 'mask'):
        # Set masked values to Zero for stat and weights, which will therefore
        # be ignored in mean_statistics(). This is equivalent to multiplying by
        # the mask, but avoids NaN * 0 -> NaN in cases where there are NaNs in
        # masked positions. Only for variables with a mask attribute.
        stat = stat.where(stat.mask, 0)
      return self.aggregation_fn(stat)

    def batch_aggregator_weights_for_var_and_stat(stat):
      ones = xr.ones_like(stat)
      # Make sure the weights are also zero for skipna and masked aggregation.
      if self.skipna:
        ones = ones.where(~stat.isnull(), 0)
      if self.masked and hasattr(stat, 'mask'):
        ones = ones.where(stat.mask, 0)
      return batch_aggregator_for_var_and_stat(ones)

    def filter_nones(x):
      result = {}
      for name, values in x.items():
        if isinstance(values, xr.Dataset):
          # Dataset has already had None's filtered out by xarray_tree,
          # but we want to preserve its type:
          result[name] = values
        else:
          result[name] = {k: v for k, v in values.items() if v is not None}
      return result

    sum_weighted_statistics = filter_nones(
        xarray_tree.map_structure(batch_aggregator_for_var_and_stat, statistics)
    )
    sum_weights = filter_nones(
        xarray_tree.map_structure(
            batch_aggregator_weights_for_var_and_stat, statistics
        )
    )

    # Aggregator for every dataset in statistics
    return AggregationState(sum_weighted_statistics, sum_weights)
def compute_metric_values_for_single_chunk(
    metrics: Mapping[str, metrics_base.Metric],
    aggregator: Aggregator,
    predictions: Mapping[Hashable, xr.DataArray],
    targets: Mapping[Hashable, xr.DataArray],
) -> xr.Dataset:
  """Computes metric values for one predictions/targets pair in a single pass.

  Runs the full statistics -> aggregation -> metric values pipeline on one
  chunk. It does not accumulate over multiple chunks; for that, sum the
  AggregationState objects returned by aggregator.aggregate_statistics.

  Args:
    metrics: Mapping from metric name to metric instance.
    aggregator: Aggregator used to reduce the statistics.
    predictions: Xarray Dataset or dictionary of DataArrays.
    targets: Xarray Dataset or dictionary of DataArrays.

  Returns:
    Xarray Dataset of metric values, named '<metric>.<variable>'.
  """
  return aggregator.aggregate_statistics(
      metrics_base.compute_unique_statistics_for_all_metrics(
          metrics, predictions, targets
      )
  ).metric_values(metrics)