# Source code for weatherbenchX.beam_pipeline

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the beam pipeline for evaluation."""

from collections.abc import Hashable
import dataclasses
import time
import typing
from typing import Callable, Iterable, Iterator, Literal, Mapping, Never, Optional, Union

from absl import logging
import apache_beam as beam
import numpy as np
from weatherbenchX import aggregation
from weatherbenchX import beam_utils
from weatherbenchX import time_chunks
from weatherbenchX.data_loaders import base as data_loaders_base
from weatherbenchX.metrics import base as metrics_base
import xarray as xr
import xarray_beam as xbeam


class LoadPredictionsAndTargets(beam.DoFn):
  """Loads prediction and target chunks."""

  def __init__(
      self,
      predictions_loader: data_loaders_base.DataLoader,
      targets_loader: data_loaders_base.DataLoader,
      setup_fn: Optional[Callable[[], None]] = None,
  ):
    """Init.

    Args:
      predictions_loader: The data loader for the predictions.
      targets_loader: The data loader for the targets.
      setup_fn: (Optional) A function to call once per worker.
    """
    self.predictions_loader = predictions_loader
    self.targets_loader = targets_loader
    self.setup_fn = setup_fn
    self.is_initialized = False
    self.target_load_time = beam.metrics.Metrics.distribution(
        'LoadPredictionsAndTargets', 'target_load_time'
    )
    self.prediction_load_time = beam.metrics.Metrics.distribution(
        'LoadPredictionsAndTargets', 'prediction_load_time'
    )

  def setup(self):
    """Runs setup_fn (if any), at most once per DoFn instance."""
    # is_initialized guards against setup_fn running twice on this instance.
    if self.setup_fn is not None and not self.is_initialized:
      self.setup_fn()
      self.is_initialized = True

  def process(
      self,
      all_inputs: tuple[
          time_chunks.TimeChunkOffsets,
          tuple[np.ndarray, Union[np.ndarray, slice]],
      ],
  ) -> Iterable[
      tuple[
          time_chunks.TimeChunkOffsets,
          tuple[
              Mapping[Hashable, xr.DataArray], Mapping[Hashable, xr.DataArray]
          ],
      ]
  ]:
    """Returns prediction and target chunks for a chunk of init/lead times.

    Targets are loaded first because the loaded targets chunk is passed to the
    predictions loader. Each load's duration (in milliseconds) is recorded in
    the target_load_time / prediction_load_time distributions.

    Args:
      all_inputs: (time_chunk_offsets, (init_times, lead_times))

    Returns:
      (time_chunk_offsets, (predictions_chunk, targets_chunk))
    """
    logging.log_first_n(
        logging.INFO, 'LoadPredictionsAndTargets inputs: %s', 10, all_inputs
    )
    time_chunk_offsets, (init_times, lead_times) = all_inputs

    # perf_counter is monotonic, so measured durations are not distorted (or
    # made negative) by wall-clock adjustments, unlike time.time().
    start_time = time.perf_counter()
    targets_chunk = self.targets_loader.load_chunk(init_times, lead_times)
    self.target_load_time.update(
        (time.perf_counter() - start_time) * 1000
    )  # In milliseconds because beam counters use longs.

    start_time = time.perf_counter()
    predictions_chunk = self.predictions_loader.load_chunk(
        init_times, lead_times, targets_chunk
    )
    self.prediction_load_time.update(
        (time.perf_counter() - start_time) * 1000
    )  # In milliseconds because beam counters use longs.

    logging.log_first_n(
        logging.INFO,
        'LoadPredictionsAndTargets outputs: %s',
        10,
        (time_chunk_offsets, (predictions_chunk, targets_chunk)),
    )
    return [(time_chunk_offsets, (predictions_chunk, targets_chunk))]


# TODO(matthjw): Consider whether we could reuse xarray_beam.Key here and
# use more of the xarray_beam API to do the aggregation.
@dataclasses.dataclass(frozen=True)
class _AggregationKey:
  """Key under which statistics are aggregated (summed or combine_by_coords)."""

  type: Literal['sum_weighted_statistics', 'sum_weights']
  statistic_name: str
  variable_name: str
  # Offsets for the chunk in the result of the aggregation. Should be None if
  # the relevant dimension is being aggregated over.
  init_time_offset: int | None
  lead_time_offset: int | None

  def drop_offsets(self) -> '_AggregationKey':
    return dataclasses.replace(
        self, init_time_offset=None, lead_time_offset=None
    )


class ComputeStatisticsAggregateAndPrepareForCombine(beam.DoFn):
  """Computes statistics needed for our metrics, for a chunk of init/lead times.

  Then performs the initial per-chunk aggregation on them using the Aggregator,
  then prepares them for further aggregation by breaking the AggregationState
  up into separate DataArrays for each statistic, variable, type (sum_weights or
  sum_weighted_statistics) and chunk offset, keyed by _AggregationKey.
  """

  def __init__(
      self,
      metrics: Mapping[str, metrics_base.Metric],
      aggregator: aggregation.Aggregator,
  ):
    """Init.

    Args:
      metrics: The metrics whose statistics should be computed.
      aggregator: Performs the initial per-chunk aggregation of each
        (statistic, variable) pair via `aggregate_stat_var`.
    """
    self.metrics = metrics
    self.aggregator = aggregator

  def process(
      self,
      all_inputs: tuple[
          time_chunks.TimeChunkOffsets,
          tuple[
              Mapping[Hashable, xr.DataArray],
              Mapping[Hashable, xr.DataArray],
          ],
      ],
  ) -> Iterator[tuple[_AggregationKey, xr.DataArray]]:
    """Yields statistics for further aggregation.

    Args:
      all_inputs: (time_chunk_offsets, (predictions_chunk, targets_chunk))

    Yields:
      Multiple key/value pairs (aggregation_key, data_array), where the
      aggregation_key identifies the scope for further aggregation. For each
      (statistic, variable) whose aggregation is not None, the
      sum_weighted_statistics pair is yielded first, then the sum_weights pair.
    """
    logging.log_first_n(
        logging.INFO,
        'ComputeStatisticsAggregateAndPrepareForCombine inputs: %s',
        10,
        all_inputs,
    )
    time_chunk_offsets, (predictions_chunk, targets_chunk) = all_inputs

    # We use a generator below and yield one at a time, to avoid holding all
    # unaggregated statistics in memory all at once in case of large statistics.
    stats_iter = metrics_base.generate_unique_statistics_for_all_metrics(
        self.metrics, predictions_chunk, targets_chunk
    )

    while True:
      # Statistics are produced lazily by the generator, so timing the next()
      # call times the statistic's computation.
      start_time = time.perf_counter()
      # Keep the try confined to next(): a try around the whole loop body
      # would also swallow a StopIteration raised by the aggregator (or
      # anything else below) and silently drop the remaining statistics.
      try:
        stat_name, stats = next(stats_iter)
      except StopIteration:
        break

      # Create a short name so that it's more readable in dashboards.
      short_stat_name = (
          stat_name[:30] + '...' if len(stat_name) > 30 else stat_name
      )

      beam.metrics.Metrics.distribution(
          'ComputeStatistics', f'compute_{short_stat_name}'
      ).update(
          (time.perf_counter() - start_time) * 1000
      )  # In milliseconds because beam counters use longs.

      for var_name, stat in stats.items():
        start_time = time.perf_counter()
        aggregation_state = self.aggregator.aggregate_stat_var(stat)
        if aggregation_state is None:
          continue
        beam.metrics.Metrics.distribution(
            'ComputeStatistics', f'agg_{var_name}_{short_stat_name}'
        ).update(
            (time.perf_counter() - start_time) * 1000
        )  # In milliseconds because beam counters use longs.

        # A chunk offset is kept only for time dimensions still present after
        # aggregation; None means that dimension was aggregated over.
        dims = aggregation_state.sum_weighted_statistics.dims
        weighted_key = _AggregationKey(
            type='sum_weighted_statistics',
            statistic_name=stat_name,
            variable_name=str(var_name),
            init_time_offset=(
                time_chunk_offsets.init_time if 'init_time' in dims else None
            ),
            lead_time_offset=(
                time_chunk_offsets.lead_time if 'lead_time' in dims else None
            ),
        )
        yield weighted_key, aggregation_state.sum_weighted_statistics
        yield (
            dataclasses.replace(weighted_key, type='sum_weights'),
            aggregation_state.sum_weights,
        )


class ConcatPerStatisticPerVariable(beam.PTransform):
  """Stitches chunked DataArrays back together per statistic and variable.

  Elements arrive as (_AggregationKey, DataArray) pairs whose DataArrays are
  chunks along whichever of the {lead_time, init_time} dimensions are being
  preserved in the result. The chunk offsets are stripped from the keys, the
  chunks sharing a (type, statistic, variable) are grouped, and each group is
  merged with xr.combine_by_coords.
  """

  def expand(
      self, pcoll: beam.PCollection[tuple[_AggregationKey, xr.DataArray]]
  ):

    def strip_offsets(
        key: _AggregationKey, data_array: xr.DataArray
    ) -> tuple[_AggregationKey, xr.DataArray]:
      return key.drop_offsets(), data_array

    def merge_chunks(
        key: _AggregationKey, data_arrays: Iterable[xr.DataArray]
    ) -> tuple[_AggregationKey, xr.DataArray]:
      # Outer-align on every dimension except init_time and lead_time (gaps
      # filled with 0) so that chunks whose other coordinates overlap can be
      # combined.
      aligned = xr.align(
          *data_arrays,
          join='outer',
          fill_value=0,
          exclude=['init_time', 'lead_time'],
      )
      # Named inputs would make combine_by_coords return a Dataset, hence
      # rename(None); zero-sized inputs are unsupported, hence the filter.
      chunks = [chunk.rename(None) for chunk in aligned if chunk.size > 0]
      if not chunks:
        # combine_by_coords would return a Dataset for an empty input, so an
        # empty DataArray is built by hand instead.
        return key, xr.DataArray()
      # combine_by_coords cannot handle mismatched coordinates, so keep only
      # the non-dimension coordinates present on every chunk.
      extra_coords = [set(chunk.coords) - set(chunk.dims) for chunk in chunks]
      common_coords = set.intersection(*extra_coords)
      chunks = [
          chunk.drop_vars(list(extras - common_coords))
          for chunk, extras in zip(chunks, extra_coords)
      ]
      return key, xr.combine_by_coords(chunks)

    # Removing the offsets means grouping happens by statistic name, variable
    # name and type (sum_weighted_statistics or sum_weights) alone.
    keyed_without_offsets = pcoll | 'DropOffsetsFromKey' >> beam.MapTuple(
        strip_offsets
    )
    # GroupByKey rather than CombinePerKey: concatenation needs every chunk in
    # memory at once anyway, so an incremental CombineFn would save nothing.
    grouped = keyed_without_offsets | 'GroupByStatAndVariable' >> (
        beam.GroupByKey()
    )
    return grouped | 'CombineDataArraysByCoords' >> beam.MapTuple(merge_chunks)


def reconstruct_aggregation_state(
    key_value_pairs: Iterable[tuple[_AggregationKey, xr.DataArray]],
) -> aggregation.AggregationState:
  """Reconstructs an AggregationState from (_AggregationKey, DataArray) pairs.

  Args:
    key_value_pairs: Component DataArrays of the AggregationState keyed by
      _AggregationKey, as generated by
      ComputeStatisticsAggregateAndPrepareForCombine above except that all
      chunks over the lead_time and init_time dimensions have been combined
      before we reach this stage.

  Returns:
    The reconstituted AggregationState containing all statistics and all
    variables.

  Raises:
    ValueError: If a key's `type` is neither 'sum_weighted_statistics' nor
      'sum_weights'.
  """
  sum_weighted_statistics = {}
  sum_weights = {}
  # Nested dicts of the form {statistic_name: {variable_name: DataArray}}.
  destination_by_type = {
      'sum_weighted_statistics': sum_weighted_statistics,
      'sum_weights': sum_weights,
  }
  for key, stat in key_value_pairs:
    # An explicit raise rather than `assert False`, which is stripped when
    # Python runs with -O and would then fail obscurely later on.
    try:
      add_to = destination_by_type[key.type]
    except KeyError:
      raise ValueError(
          f'Unexpected _AggregationKey type: {key.type!r}'
      ) from None
    add_to.setdefault(key.statistic_name, {})[key.variable_name] = stat
  return aggregation.AggregationState(sum_weighted_statistics, sum_weights)


class ReconstructAggregationState(beam.PTransform):
  """Gathers every (_AggregationKey, DataArray) into one AggregationState."""

  def expand(
      self, pcoll: beam.PCollection[tuple[_AggregationKey, xr.DataArray]]
  ) -> beam.PCollection[aggregation.AggregationState]:
    # Bring all pairs together into a single element, then rebuild the state
    # from them in one go.
    all_pairs = pcoll | beam_utils.GroupAll()
    return all_pairs | beam.Map(reconstruct_aggregation_state)


class ComputeMetrics(beam.DoFn):
  """Turns the final aggregated statistics into metric values."""

  def __init__(self, metrics: Mapping[str, metrics_base.Metric]):
    # Forwarded unchanged to AggregationState.metric_values in process().
    self.metrics = metrics

  def process(
      self, aggregation_state: aggregation.AggregationState
  ) -> Iterable[xr.Dataset]:
    """Emits the metrics Dataset computed from the final AggregationState."""
    logging.log_first_n(
        logging.INFO, 'ComputeMetrics inputs: %s', 10, aggregation_state
    )
    metric_values = aggregation_state.metric_values(self.metrics)
    return [metric_values]


class WriteMetrics(beam.DoFn):
  """Writes the metrics to a NetCDF file."""

  def __init__(self, out_path: str):
    # Destination passed to beam_utils.atomic_write.
    self.out_path = out_path

  def process(self, metrics: xr.Dataset) -> Iterable[Never]:
    """Serializes `metrics` to NetCDF and writes it to out_path.

    Emits no output elements.
    """
    logging.log_first_n(logging.INFO, 'WriteMetrics inputs: %s', 10, metrics)
    # Strip attributes that may have propagated from the targets or
    # predictions; deep=True also clears them on the individual variables.
    netcdf_payload = metrics.drop_attrs(deep=True).to_netcdf()
    beam_utils.atomic_write(self.out_path, netcdf_payload)
    return []


class WriteAggregationState(beam.DoFn):
  """Writes the final AggregationState to a NetCDF file."""

  def __init__(self, out_path: str):
    # Destination passed to beam_utils.atomic_write.
    self.out_path = out_path

  def process(
      self, aggregation_state: aggregation.AggregationState
  ) -> Iterable[Never]:
    """Converts the state to a Dataset and writes it as NetCDF to out_path.

    Emits no output elements.
    """
    # Strip attributes that may have propagated from the targets or
    # predictions; deep=True also clears them on the individual variables.
    cleaned_ds = aggregation_state.to_dataset().drop_attrs(deep=True)
    beam_utils.atomic_write(self.out_path, cleaned_ds.to_netcdf())
    return []


def define_pipeline(
    root: beam.Pipeline,
    times: time_chunks.TimeChunks,
    predictions_loader: data_loaders_base.DataLoader,
    targets_loader: data_loaders_base.DataLoader,
    metrics: Mapping[str, metrics_base.Metric],
    aggregator: aggregation.Aggregator,
    out_path: str | None = None,
    aggregation_state_out_path: str | None = None,
    setup_fn: Optional[Callable[[], None]] = None,
):
  """Defines a beam pipeline for calculating aggregated metrics.

  Args:
    root: Pipeline root.
    times: TimeChunks instance.
    predictions_loader: DataLoader instance.
    targets_loader: DataLoader instance.
    metrics: A dictionary of metrics to compute.
    aggregator: Aggregation instance.
    out_path: The full path to write the metrics to.
    aggregation_state_out_path: The full path to write the final aggregation
      state to. This can be useful if you want to compute further metrics from
      it later, and if you are preserving init_time, it can be useful to
      compute confidence intervals from later too.
    setup_fn: (Optional) A function to call once per worker in
      LoadPredictionsAndTargets.

  Raises:
    ValueError: If neither out_path nor aggregation_state_out_path is given.
      This is checked before any transforms are attached to `root`.
  """
  # Validate up front so that a bad call does not leave a partially built set
  # of transforms attached to the caller's pipeline.
  if out_path is None and aggregation_state_out_path is None:
    raise ValueError(
        'At least one of (metrics) out_path or aggregation_state_out_path must '
        'be specified.'
    )

  agg_state_pipeline = (
      root
      | 'CreateTimeChunks' >> beam.Create(times.iter_with_chunk_offsets())
      | beam.ParDo(
          LoadPredictionsAndTargets(
              predictions_loader, targets_loader, setup_fn=setup_fn
          )
      )
      # Compute statistics for each chunk, perform the initial per-chunk
      # aggregation on them using the Aggregator, then prepare them for further
      # aggregation by breaking the AggregationState up into separate
      # DataArrays for each statistic, variable, type (sum_weights or
      # sum_weighted_statistics) and chunk offset.
      | beam.ParDo(
          ComputeStatisticsAggregateAndPrepareForCombine(metrics, aggregator)
      )
      # Sum up the statistic DataArrays over dimensions of the TimeChunks that
      # we are reducing over, typically just init_time but can also be
      # lead_time. This is done separately for each statistic, each variable,
      # and each chunk offset along dimensions not being reduced over (e.g.
      # typically lead_time is not reduced over).
      | 'SumPerStatisticPerVariableAndPerUnreducedOffset'
      >> beam.CombinePerKey(beam_utils.CombiningSum())
      # Now we've reduced the size of the data as much as we can by summing,
      # we concatenate the resulting chunks along any remaining dimensions where
      # we know that coordinates will not overlap across chunks.
      | ConcatPerStatisticPerVariable()
      # Finally we gather together all the concatenated chunks for all
      # statistics and variables and reconstitute the full AggregationState
      # from them, which we can use to compute the final values of metrics.
      | ReconstructAggregationState()
  )

  if out_path is not None:
    _ = (
        agg_state_pipeline
        | beam.ParDo(ComputeMetrics(metrics))
        | beam.ParDo(WriteMetrics(out_path))
    )
  if aggregation_state_out_path is not None:
    _ = agg_state_pipeline | beam.ParDo(
        WriteAggregationState(aggregation_state_out_path)
    )


class ComputeAndFormatStatistics(beam.DoFn):
  """Computes statistics and formats them for xarray-beam."""

  def __init__(
      self,
      metrics: Mapping[str, metrics_base.Metric],
      times: time_chunks.TimeChunks,
  ):
    """Init.

    Args:
      metrics: A dictionary of metrics to compute statistics for.
      times: TimeChunks instance. Stored on the DoFn but not read by
        `process`, which takes chunk offsets from each input element.
    """
    self.metrics = metrics
    self.times = times

  def process(
      self,
      element: tuple[
          time_chunks.TimeChunkOffsets,
          tuple[
              Mapping[Hashable, xr.DataArray],
              Mapping[Hashable, xr.DataArray],
          ],
      ],
  ) -> Iterable[tuple[xbeam.Key, xr.Dataset]]:
    """Computes statistics and yields (chunk_key, dataset) tuples.

    Each (statistic, variable) pair becomes its own single-variable Dataset
    named '{stat_name}.{var_name}', transposed so that init_time and lead_time
    (when present) come first, and keyed by the chunk's offsets along those
    dimensions.
    """
    time_chunk_offsets, (predictions_chunk, targets_chunk) = element
    statistics_dict = metrics_base.compute_unique_statistics_for_all_metrics(
        self.metrics, predictions_chunk, targets_chunk
    )
    for stat_name, var_dict in statistics_dict.items():
      for var_name, da in var_dict.items():
        name = f'{stat_name}.{var_name}'
        chunk_ds = xr.Dataset({name: da})
        dim_order = []
        offsets = {}
        if 'init_time' in chunk_ds.dims:
          dim_order.append('init_time')
          offsets['init_time'] = time_chunk_offsets.init_time
        if 'lead_time' in chunk_ds.dims:
          dim_order.append('lead_time')
          offsets['lead_time'] = time_chunk_offsets.lead_time
        chunk_ds = chunk_ds.transpose(*dim_order, ...)
        chunk_key = xbeam.Key(offsets, vars={name})
        yield chunk_key, chunk_ds


def _lead_time_slice_to_array(lead_times: slice) -> np.ndarray:
  """Enumerates the values of a stepped lead_time slice, `stop` inclusive.

  Values run from `start` in increments of `step`; `stop` is included when it
  lies on that grid, and no value beyond `stop` is ever produced.

  Raises:
    ValueError: If the slice has no step, since its values cannot then be
      enumerated.
  """
  if lead_times.step is None:
    raise ValueError(
        'Cannot generate template: a lead_times slice must specify a step, '
        f'got {lead_times!r}'
    )
  # Using `stop + step` as the exclusive arange bound makes `stop` inclusive.
  values = np.arange(
      lead_times.start, lead_times.stop + lead_times.step, lead_times.step
  )
  # When (stop - start) is not a multiple of step, that bound also lets arange
  # emit one value beyond `stop`; keep only the values up to `stop`.
  num_values = int((lead_times.stop - lead_times.start) // lead_times.step) + 1
  return values[: max(num_values, 0)]


def _reexpand_dim(
    template: xr.Dataset, dim: str, values
) -> xr.Dataset:
  """Replaces `dim` with coordinate `values` on the variables that had it."""
  vars_to_expand = [k for k, v in template.items() if dim in v.dims]
  template = template.isel({dim: 0}, drop=True)
  for k in vars_to_expand:
    template[k] = template[k].expand_dims({dim: values})
  return template


def _get_template_dataset(
    metrics: Mapping[str, metrics_base.Metric],
    predictions_loader: data_loaders_base.DataLoader,
    targets_loader: data_loaders_base.DataLoader,
    times: time_chunks.TimeChunks,
    setup_fn: Optional[Callable[[], None]] = None,
) -> xr.Dataset:
  """Computes statistics for the first chunk to create a template dataset.

  The first chunk's statistics fix the variables and non-time coordinates;
  the init_time and lead_time dimensions are then re-expanded to cover all of
  `times`.

  Raises:
    ValueError: If `times` is empty, if the template has a `mask` coordinate,
      or if `times.lead_times` is a slice without a step.
  """
  if setup_fn is not None:
    setup_fn()

  logging.info('Building template with data from first chunk')
  # Evaluate statistics on the first chunk
  first_chunk_index = 0
  try:
    first_init_times, first_lead_times = times[first_chunk_index]
  except IndexError:
    raise ValueError('Cannot generate template: TimeChunks is empty') from None
  targets_chunk = targets_loader.load_chunk(first_init_times, first_lead_times)
  predictions_chunk = predictions_loader.load_chunk(
      first_init_times, first_lead_times, targets_chunk
  )
  statistics_dict = metrics_base.compute_unique_statistics_for_all_metrics(
      metrics, predictions_chunk, targets_chunk
  )
  first_chunk = xr.Dataset()
  for stat_name, var_dict in statistics_dict.items():
    for var_name, da in var_dict.items():
      first_chunk[f'{stat_name}.{var_name}'] = da

  # Convert the first chunk into a template, with the proper init_time and
  # lead_time dimensions
  template = xbeam.make_template(first_chunk)

  if 'mask' in template.coords:
    raise ValueError(
        'mask coordinate found in template. add_nan_mask=True on data loaders '
        'is not supported for unaggregated pipelines.'
    )

  # lead_time is expanded before init_time, so init_time ends up as the
  # leading dimension of variables that have both.
  if 'lead_time' in template.dims:
    lead_times = times.lead_times
    if isinstance(lead_times, slice):
      lead_times = _lead_time_slice_to_array(lead_times)
    template = _reexpand_dim(template, 'lead_time', lead_times)

  if 'init_time' in template.dims:
    template = _reexpand_dim(template, 'init_time', times.init_times)

  if 'init_time' in template.dims and 'lead_time' in template.dims:
    template.coords['valid_time'] = template.init_time + template.lead_time

  return template


# TODO: shoyer - consider renaming this function to refer to "statistics" (vs
# the metrics calculated by define_pipeline)
def define_unaggregated_pipeline(
    root: beam.Pipeline,
    times: time_chunks.TimeChunks,
    predictions_loader: data_loaders_base.DataLoader,
    targets_loader: data_loaders_base.DataLoader,
    metrics: Mapping[str, metrics_base.Metric],
    out_path: str,
    zarr_chunks: Mapping[str, int] | None = None,
    setup_fn: Optional[Callable[[], None]] = None,
):
  """Defines a Beam pipeline that calculates statistics without aggregation.

  Outputs statistics for all predictions and targets to a single Zarr store,
  which assumes that all statistics have compatible coordinates. If this is not
  the case, you'll need to run separate pipelines for incompatible statistics.

  Note that this function eagerly loads the first time chunk and computes its
  statistics in order to build the output template.

  Args:
    root: Pipeline root.
    times: TimeChunks instance. It is indexed (`times[0]`) to build the
      template and iterated via `iter_with_chunk_offsets()`; its `init_times`,
      `lead_times`, `init_time_chunk_size` and `lead_time_chunk_size` determine
      the output coordinates and chunking.
    predictions_loader: DataLoader instance for predictions.
    targets_loader: DataLoader instance for targets.
    metrics: A dictionary of metrics to compute statistics for.
    out_path: The full path to write the output Zarr store to.
    zarr_chunks: (Optional) A mapping of chunk sizes to use for the output Zarr
      store. Dimensions not listed default to the TimeChunks chunking for
      init_time/lead_time and to a single chunk for any other dimension. If
      None, the chunks will match those of TimeChunks.
    setup_fn: (Optional) A function to call once per worker in
      LoadPredictionsAndTargets. It is also called once locally while the
      template is built.
  """
  template = _get_template_dataset(
      metrics, predictions_loader, targets_loader, times, setup_fn
  )
  dim_sizes = typing.cast(Mapping[str, int], template.sizes)

  stat_chunks = {}
  for dim, size in dim_sizes.items():
    if dim == 'init_time':
      stat_chunks[dim] = times.init_time_chunk_size or -1
    elif dim == 'lead_time':
      stat_chunks[dim] = times.lead_time_chunk_size or -1
    else:
      stat_chunks[dim] = size  # unchunked

  if zarr_chunks is None:
    zarr_chunks = {}
  # Use any entries in stat_chunks as defaults for zarr_chunks; explicit
  # zarr_chunks entries take precedence. Unpacking (rather than `dict |`)
  # accepts any Mapping, not just dict.
  # Consider raising an error for missing dimensions instead?
  zarr_chunks = {**stat_chunks, **zarr_chunks}

  _ = (
      root
      | 'CreateTimeChunks' >> beam.Create(times.iter_with_chunk_offsets())
      | 'LoadPredictionsAndTargets'
      >> beam.ParDo(
          LoadPredictionsAndTargets(
              predictions_loader, targets_loader, setup_fn=setup_fn
          )
      )
      | 'ComputeAndFormatStatistics'
      >> beam.ParDo(ComputeAndFormatStatistics(metrics, times))
      | 'Rechunk'
      >> xbeam.Rechunk(
          dim_sizes,
          stat_chunks,
          zarr_chunks,
          itemsize=4,  # assumes float32
      )
      | 'WriteStatisticsToZarr'
      >> xbeam.ChunksToZarr(out_path, template=template, zarr_chunks=zarr_chunks)
  )