Source code for weatherbenchX.data_loaders.base

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base data loader class."""

import abc
from typing import Collection, Hashable, Mapping, Optional, Union
import numpy as np
from weatherbenchX import interpolations
from weatherbenchX import xarray_tree
import xarray as xr


def add_nan_mask_to_data(
    data: Mapping[Hashable, xr.DataArray],
    variable_subset: Collection[str] | None = None,
) -> Mapping[Hashable, xr.DataArray]:
  """Adds a boolean coordinate named 'mask' to each variable with False indicating NaN values.

  Args:
    data: Data to add the mask to.
    variable_subset: If provided, only add the mask to the variables in this
      list. All other variables will have a mask that is always True (note that
      this is a bit wasteful in terms of memory!)

  Returns:
    The data with the mask added.
  """
  data = dict(data)
  for var in data:
    if not variable_subset or var in variable_subset:
      data[var].coords['mask'] = ~np.isnan(data[var])
    else:
      data[var].coords['mask'] = np.ones_like(data[var], dtype=bool)
  return data


[docs] class DataLoader(abc.ABC): """Base class for data loaders. Data loaders return chunks of data compatible with the rest of the evaluation framework. Specifically, this should be an xr.Dataset or a dictionary of xr.DataArray's. It is the data loaders' job to return target and prediction chunks that can be broadcast against each other. If interpolation is required to map one dataset to another, e.g. interpolating a gridded dataset to sparse points, a reference dataset can be provided for this purpose. """ def __init__( self, interpolation: Optional[interpolations.Interpolation] = None, compute: bool = True, add_nan_mask: bool = False, ): """Shared initialization for data loaders. Args: interpolation: (Optional) Interpolation to be applied to the data. compute: Load chunk into memory. Default: True. add_nan_mask: Adds a boolean coordinate named 'mask' to each variable (variables will be split into DataArrays if they aren't already), with False indicating NaN values. To be used for masked aggregation. Default: False. """ self._interpolation = interpolation self._compute = compute self._add_nan_mask = add_nan_mask @abc.abstractmethod def _load_chunk_from_source( self, init_times: np.ndarray, lead_times: Optional[Union[np.ndarray, slice]] = None, ) -> Mapping[Hashable, xr.DataArray]: """Method to be implemented by data loaders.""" def load_chunk( self, init_times: np.ndarray, lead_times: Optional[Union[np.ndarray, slice]] = None, reference: Optional[Mapping[Hashable, xr.DataArray]] = None, ) -> Mapping[Hashable, xr.DataArray]: """Shared API for loading data chunk. Implements interpolation, splitting variables and loading data into memory. Args: init_times: List of init_times. lead_times: (Optional) List of exact lead times or lead_time interval. The exact behavior for each option depends on the data loader and should be documented there. reference: (Optional) A reference dataset to be used inside the data loader, e.g. for interpolation. Returns: data_chunk: Xarray Dataset or dictionary of DataArrays containing data for given times. """ chunk = self._load_chunk_from_source(init_times, lead_times) if self._interpolation is not None: # TODO(srasp): Potentially implement consistency check between lead_times # and lead_time coordinate on reference. chunk = self._interpolation.interpolate(chunk, reference) # Compute after interpolation avoids loading unnecessary data. if self._compute: chunk = xarray_tree.map_structure(lambda x: x.compute(), chunk) if self._add_nan_mask: chunk = add_nan_mask_to_data(chunk) return chunk