Implement a new data loader

WB-X can, in theory, deal with any kind of data. All that is required is an appropriate data loader that, for a given init/lead time combination, returns an xr.Dataset.

Let’s go through the building blocks of a data loader. The __init__ of the data loader base class accepts three arguments:

Args:
  interpolation: (Optional) Interpolation to be applied to the data.
  compute: Load chunk into memory. Default: True.
  add_nan_mask: Adds a boolean coordinate named 'mask' to each variable
    (variables will be split into DataArrays if they are not already), with
    False indicating NaN values. To be used for masked aggregation. Default:
    False.
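To make the mask convention concrete, here is a minimal sketch in plain NumPy rather than WB-X itself. The values and variable names are made up for illustration, and the way WB-X attaches the mask as a coordinate to each variable is not reproduced here; only the "False means NaN" convention and its use in a masked mean are shown.

```python
import numpy as np

# Hypothetical observations with two missing values.
values = np.array([1.0, np.nan, 3.0, np.nan, 5.0])

# Same convention as the 'mask' coordinate: False marks NaN entries.
mask = ~np.isnan(values)

# Masked aggregation: zero out the NaNs, then average over valid entries only.
masked_mean = np.where(mask, values, 0.0).sum() / mask.sum()
```

Averaging over the valid entries only keeps a single missing value from turning the whole aggregate into NaN.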

The only other thing you need to implement is _load_chunk_from_source(), which takes the init and lead times and returns the corresponding data as a mapping from variable names to xr.DataArray objects (an xr.Dataset qualifies).

Note that lead_times can be either an array or a slice. A slice is used when target or prediction data comes at arbitrary times (e.g. weather station data). In most cases, however, you will probably want exact lead times.
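The difference between the two request types can be sketched with NumPy alone. The helper select_lead_times and the sample lead times below are hypothetical and not part of WB-X. Treating the slice as inclusive on both ends is an assumption of this sketch, chosen to mirror label-based slicing in xarray.

```python
import numpy as np


def select_lead_times(available, lead_times):
    """Returns the indices of `available` lead times matching the request."""
    if isinstance(lead_times, slice):
        # Slice request: keep everything inside the window, inclusive on both
        # ends (an assumption of this sketch).
        start = available.min() if lead_times.start is None else lead_times.start
        stop = available.max() if lead_times.stop is None else lead_times.stop
        keep = (available >= start) & (available <= stop)
    else:
        # Array request: keep exact matches only.
        keep = np.isin(available, lead_times)
    return np.flatnonzero(keep)


# Hypothetical, irregularly spaced lead times (in minutes), as might occur
# with station observations.
available = np.array([0, 90, 360, 400, 720], dtype='timedelta64[m]')

exact = select_lead_times(
    available, np.array([0, 360, 720], dtype='timedelta64[m]')
)
window = select_lead_times(
    available, slice(np.timedelta64(1, 'h'), np.timedelta64(7, 'h'))
)
```

Here exact picks only the entries at 0, 360 and 720 minutes, while window picks every entry between one and seven hours, whatever its exact time.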

Depending on how the data is stored, all init/lead times can be read in a single call (e.g. for Zarr data), or each combination has to be read separately. The example below loops over init and lead times for the case where each combination is read separately.

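from typing import Hashable, Mapping, Optional, Union

import numpy as np
import xarray as xr

# Assumed import paths; adjust them to your WB-X installation.
from weatherbenchX import interpolations
from weatherbenchX.data_loaders import base as data_loader_base


class MyNewDataLoader(data_loader_base.DataLoader):
  def __init__(
      self,
      *args,  # Placeholder for loader-specific arguments, e.g. a path.
      interpolation: Optional[interpolations.Interpolation] = None,
      compute: bool = True,
      add_nan_mask: bool = False,
  ):
    super().__init__(
        interpolation=interpolation,
        compute=compute,
        add_nan_mask=add_nan_mask,
    )

  def _load_chunk_from_source(
      self,
      init_times: np.ndarray,
      lead_times: Optional[Union[np.ndarray, slice]] = None,
  ) -> Mapping[Hashable, xr.DataArray]:
    # This loader only handles exact lead times, so slices (and None) are
    # rejected.
    if not isinstance(lead_times, np.ndarray):
      raise ValueError('Only exact lead times are supported.')

    datasets = []
    for init_time in init_times:
      for lead_time in lead_times:
        # some_data_loading_function is a placeholder for your own reader. It
        # should return an xr.Dataset for one init/lead time combination.
        ds = some_data_loading_function(init_time, lead_time)
        datasets.append(ds)
    # Combine the per-combination datasets into a single chunk. This assumes
    # each dataset carries its own init/lead time coordinates.
    chunk = xr.merge(datasets)
    return chunk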