# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Defines the beam pipeline for evaluation."""
import logging
from typing import Callable, Hashable, Iterable, Mapping, Optional, Tuple, Union
import apache_beam as beam
import fsspec
import numpy as np
from weatherbenchX import aggregation
from weatherbenchX import beam_utils
from weatherbenchX import time_chunks
from weatherbenchX.data_loaders import base as data_loaders_base
from weatherbenchX.metrics import base as metrics_base
import xarray as xr
class LoadPredictionsAndTargets(beam.DoFn):
"""Loads prediction and target chunks from their respective data loaders."""
def __init__(
self,
predictions_loader: data_loaders_base.DataLoader,
targets_loader: data_loaders_base.DataLoader,
setup_fn: Optional[Callable[[], None]] = None,
):
"""Init.
Args:
predictions_loader: The data loader for the predictions.
targets_loader: The data loader for the targets.
setup_fn: (Optional) A function to call once per worker.
"""
self.predictions_loader = predictions_loader
self.targets_loader = targets_loader
self.setup_fn = setup_fn
self.is_initialized = False
def setup(self):
# Call this function once per process.
if self.setup_fn is not None:
if not self.is_initialized:
self.setup_fn()
self.is_initialized = True
def process(
self, all_inputs: Tuple[int, Tuple[np.ndarray, Union[np.ndarray, slice]]]
) -> Iterable[
Tuple[
int,
Tuple[
Mapping[Hashable, xr.DataArray],
Mapping[Hashable, xr.DataArray],
],
]
]:
"""Returns the predictions and targets chunks for a given init/lead time.
Args:
all_inputs: (chunk_index, (init_times, lead_times))
Returns:
(chunk_index, (predictions_chunk, targets_chunk))
"""
logging.info('LoadPredictionsAndTargets inputs: %s', all_inputs)
chunk_index, (init_times, lead_times) = all_inputs
targets_chunk = self.targets_loader.load_chunk(init_times, lead_times)
predictions_chunk = self.predictions_loader.load_chunk(
init_times, lead_times, targets_chunk
)
logging.info(
'LoadPredictionsAndTargets outputs: %s',
(chunk_index, (predictions_chunk, targets_chunk)),
)
return [(chunk_index, (predictions_chunk, targets_chunk))]
class ComputeStatisticsAndAggregateChunks(beam.DoFn):
"""Computes the statistics for each metric and aggregates chunks."""
def __init__(
self,
metrics: Mapping[str, metrics_base.Metric],
aggregator: aggregation.Aggregator,
):
"""Init.
Args:
metrics: A dictionary of metrics to compute.
aggregator: Aggregation instance.
"""
self.metrics = metrics
self.aggregator = aggregator
def process(
self,
all_inputs: Tuple[
int,
Tuple[
Mapping[Hashable, xr.DataArray],
Mapping[Hashable, xr.DataArray],
],
],
) -> Iterable[Tuple[int, aggregation.AggregationState]]:
"""Returns AggregationState for given predictions and targets chunks.
Args:
all_inputs: (chunk_index, (predictions_chunk, targets_chunk))
Returns:
(chunk_index, aggregation_state)
"""
logging.info('ComputeStatisticsAndAggregateChunks inputs: %s', all_inputs)
chunk_index, (predictions_chunk, targets_chunk) = all_inputs
statistics = metrics_base.compute_unique_statistics_for_all_metrics(
self.metrics, predictions_chunk, targets_chunk
)
aggregation_state = self.aggregator.aggregate_statistics(statistics)
logging.info(
'ComputeStatisticsAndAggregateChunks outputs: %s',
(chunk_index, aggregation_state),
)
return [(chunk_index, aggregation_state)]
class ComputeMetrics(beam.DoFn):
"""Computes the metrics from the aggregated statistics."""
def __init__(self, metrics: Mapping[Hashable, metrics_base.Metric]):
"""Init.
Args:
metrics: A dictionary of metrics to compute. Same as passed to the
ComputeStatisticsAndAggregateChunks.
"""
self.metrics = metrics
def process(
self, aggregation_state: aggregation.AggregationState
) -> Iterable[xr.Dataset]:
"""Returns results Dataset from AggregationState.
Args:
aggregation_state: The AggregationState to compute the metrics from.
Returns:
A Dataset with the metrics (in a list for Beam).
"""
logging.info('ComputeMetrics inputs: %s', aggregation_state)
return [aggregation_state.metric_values(self.metrics)]
class WriteMetrics(beam.DoFn):
"""Writes the metrics to a file."""
def __init__(self, out_path: str):
"""Init.
Args:
out_path: The full path to write the metrics to.
"""
self.out_path = out_path
def process(self, metrics: xr.Dataset):
"""Writes the metrics to a NetCDF file.
Args:
metrics: Metrics dataset to write to disc.
"""
logging.info('WriteMetrics inputs: %s', metrics)
with fsspec.open(self.out_path, 'wb', auto_mkdir=True) as f:
f.write(metrics.to_netcdf())
[docs]
def define_pipeline(
root: beam.Pipeline,
times: time_chunks.TimeChunks,
predictions_loader: data_loaders_base.DataLoader,
targets_loader: data_loaders_base.DataLoader,
metrics: Mapping[str, metrics_base.Metric],
aggregator: aggregation.Aggregator,
out_path: str,
max_chunks_per_aggregation_stage: Optional[int] = 10,
setup_fn: Optional[Callable[[], None]] = None,
):
"""Defines the beam pipeline.
Args:
root: Pipeline root.
times: TimeChunks instance.
predictions_loader: DataLoader instance.
targets_loader: DataLoader instance.
metrics: A dictionary of metrics to compute.
aggregator: Aggregation instance.
out_path: The full path to write the metrics to.
max_chunks_per_aggregation_stage: The maximum number of chunks to aggregate
in a single worker. If None, does aggregation in a single step. Default:
10
setup_fn: (Optional) A function to call once per worker in
LoadPredictionsAndTargets.
"""
if max_chunks_per_aggregation_stage is None:
max_chunks_per_aggregation_stage = len(times)
_ = (
root
| 'CreateTimeChunks' >> beam.Create(enumerate(times)) # pytype: disable=wrong-arg-types
| 'LoadPredictionsAndTargets'
>> beam.ParDo(
LoadPredictionsAndTargets(
predictions_loader, targets_loader, setup_fn=setup_fn
)
)
| 'ComputeStatisticsAndAggregateChunks'
>> beam.ParDo(ComputeStatisticsAndAggregateChunks(metrics, aggregator))
| 'AggregateStates'
>> beam_utils.CombineMultiStage(
total_num_elements=len(times),
max_bin_size=max_chunks_per_aggregation_stage,
combine_fn=beam_utils.SumAggregationStates(),
)
| 'ComputeMetrics' >> beam.ParDo(ComputeMetrics(metrics))
| 'WriteMetrics' >> beam.ParDo(WriteMetrics(out_path))
)