# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base data loader class."""

import abc
from typing import Collection, Hashable, Mapping, Optional, Union

import numpy as np
from weatherbenchX import interpolations
from weatherbenchX import xarray_tree
import xarray as xr


def add_nan_mask_to_data(
    data: Mapping[Hashable, xr.DataArray],
    variable_subset: Collection[str] | None = None,
) -> Mapping[Hashable, xr.DataArray]:
    """Adds a boolean coordinate named 'mask' to each variable.

    In the mask, False indicates a NaN value in the variable's data.

    The input mapping and the DataArrays it holds are left untouched: each
    returned DataArray is a new object carrying the extra coordinate.

    Args:
        data: Data to add the mask to.
        variable_subset: If provided (and non-empty), only compute the NaN
            mask for the variables in this collection. All other variables
            get a mask that is always True (note that this is a bit wasteful
            in terms of memory!). None or an empty collection means the NaN
            mask is computed for every variable.

    Returns:
        A new dictionary with the same keys as `data`, where every DataArray
        has an additional boolean 'mask' coordinate.
    """
    masked = {}
    for name, array in data.items():
        if not variable_subset or name in variable_subset:
            mask = ~np.isnan(array)
        else:
            # xr.ones_like keeps the dims/coords of `array`, so the mask can
            # be attached as a coordinate regardless of dimensionality.
            mask = xr.ones_like(array, dtype=bool)
        # assign_coords returns a new DataArray instead of mutating the
        # caller's object in place.
        masked[name] = array.assign_coords(mask=mask)
    return masked
class DataLoader(abc.ABC):
    """Abstract base for all data loaders.

    A data loader produces chunks of data that the rest of the evaluation
    framework can consume: either an xr.Dataset or a dictionary of
    xr.DataArray's.

    Data loaders are responsible for returning target and prediction chunks
    that can be broadcast against each other. When one dataset has to be
    interpolated onto another (e.g. a gridded dataset onto sparse points), a
    reference dataset can be passed in for that purpose.
    """

    def __init__(
        self,
        interpolation: Optional[interpolations.Interpolation] = None,
        compute: bool = True,
        add_nan_mask: bool = False,
    ):
        """Stores the options shared by all data loaders.

        Args:
            interpolation: (Optional) Interpolation to be applied to the data.
            compute: Load chunk into memory. Default: True.
            add_nan_mask: Adds a boolean coordinate named 'mask' to each
                variable (variables will be split into DataArrays if they
                aren't already), with False indicating NaN values. To be used
                for masked aggregation. Default: False.
        """
        self._add_nan_mask = add_nan_mask
        self._compute = compute
        self._interpolation = interpolation

    @abc.abstractmethod
    def _load_chunk_from_source(
        self,
        init_times: np.ndarray,
        lead_times: Optional[Union[np.ndarray, slice]] = None,
    ) -> Mapping[Hashable, xr.DataArray]:
        """Loads the raw chunk; must be implemented by each data loader."""

    def load_chunk(
        self,
        init_times: np.ndarray,
        lead_times: Optional[Union[np.ndarray, slice]] = None,
        reference: Optional[Mapping[Hashable, xr.DataArray]] = None,
    ) -> Mapping[Hashable, xr.DataArray]:
        """Loads a data chunk through the shared post-processing pipeline.

        The raw chunk from `_load_chunk_from_source` is, depending on how the
        loader was configured, interpolated, loaded into memory and given a
        NaN mask, in that order.

        Args:
            init_times: List of init_times.
            lead_times: (Optional) List of exact lead times or lead_time
                interval. The exact behavior for each option depends on the
                data loader and should be documented there.
            reference: (Optional) A reference dataset to be used inside the
                data loader, e.g. for interpolation.

        Returns:
            data_chunk: Xarray Dataset or dictionary of DataArrays containing
            data for given times.
        """

        def _materialize(array):
            return array.compute()

        data_chunk = self._load_chunk_from_source(init_times, lead_times)

        interpolator = self._interpolation
        if interpolator is not None:
            # TODO(srasp): Potentially implement consistency check between
            # lead_times and lead_time coordinate on reference.
            data_chunk = interpolator.interpolate(data_chunk, reference)

        # Compute after interpolation avoids loading unnecessary data.
        if self._compute:
            data_chunk = xarray_tree.map_structure(_materialize, data_chunk)

        if self._add_nan_mask:
            data_chunk = add_nan_mask_to_data(data_chunk)

        return data_chunk