# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data loaders for reading gridded Zarr files."""
from collections.abc import Hashable
from typing import Any, Callable, Iterable, Mapping, Optional, Union
from absl import logging
import numpy as np
from weatherbenchX.data_loaders import base
import xarray as xr
def _rename_dataset(
    ds: xr.Dataset,
    rename_dimensions: Optional[Union[Mapping[str, str], str]] = 'ecmwf',
    rename_variables: Optional[Mapping[str, str]] = None,
    convert_lat_lon_to_latitude_longitude: bool = True,
) -> xr.Dataset:
  """Applies the naming conventions expected by the data loaders.

  Three renaming passes run in a fixed order: horizontal coordinates first,
  then the (time) dimensions, and finally the user-supplied variable renames.

  Args:
    ds: Dataset to rename.
    rename_dimensions: How to rename dimensions. The string 'ecmwf' maps
      {'time': 'init_time', 'prediction_timedelta': 'lead_time'} when the
      dataset has a 'prediction_timedelta' coordinate and {'time':
      'valid_time'} otherwise. A mapping is passed to `ds.rename` as is. None
      leaves the dimensions untouched.
    rename_variables: Optional mapping passed to `ds.rename` after the
      dimension renaming.
    convert_lat_lon_to_latitude_longitude: If True, and both 'lat' and 'lon'
      are coordinates of `ds`, they are renamed to 'latitude' and 'longitude'.

  Returns:
    The renamed dataset.

  Raises:
    ValueError: If `rename_dimensions` is neither 'ecmwf', a mapping nor None.
  """
  # Pass 1: horizontal coordinates. Only applies when *both* short names exist.
  if convert_lat_lon_to_latitude_longitude and all(
      short_name in ds.coords for short_name in ('lat', 'lon')
  ):
    ds = ds.rename({'lat': 'latitude', 'lon': 'longitude'})

  # Pass 2: dimensions. None falls through all branches and changes nothing.
  if isinstance(rename_dimensions, Mapping):
    ds = ds.rename(rename_dimensions)
  elif rename_dimensions == 'ecmwf':
    # A 'prediction_timedelta' coordinate marks a forecast dataset; anything
    # else is treated as a (re-)analysis dataset.
    if 'prediction_timedelta' in ds.coords:
      ecmwf_names = {'time': 'init_time', 'prediction_timedelta': 'lead_time'}
    else:
      ecmwf_names = {'time': 'valid_time'}
    ds = ds.rename(ecmwf_names)
  elif rename_dimensions is not None:
    raise ValueError(
        'rename_dimensions must be either "ecmwf", a dict or None.'
    )

  # Pass 3: variables.
  if rename_variables is not None:
    ds = ds.rename(rename_variables)
  return ds
[docs]
class XarrayDataLoader(base.DataLoader):
  """Common functionality for data loaders backed by an xarray dataset.

  The dataset is prepared lazily: it is opened (if only a path was given),
  preprocessed, renamed and subset the first time `maybe_prepare_dataset` runs,
  which `load_chunk` triggers automatically. Subclasses implement
  `_load_chunk_from_source` to select the requested times.
  """

  def __init__(
      self,
      path: Optional[str] = None,
      ds: Optional[xr.Dataset] = None,
      variables: Optional[Iterable[str]] = None,
      sel_kwargs: Optional[Mapping[str, Any]] = None,
      rename_dimensions: Optional[Union[Mapping[str, str], str]] = 'ecmwf',
      automatically_convert_lat_lon_to_latitude_longitude: bool = True,
      rename_variables: Optional[Mapping[str, str]] = None,
      preprocessing_fn: Optional[Callable[[xr.Dataset], xr.Dataset]] = None,
      **kwargs,
  ):
    """Init.

    Args:
      path: (Optional) Location of the dataset. Paths ending in '.zarr'
        (ignoring trailing slashes) are opened with xr.open_zarr, everything
        else with xr.open_dataset.
      ds: (Optional) An already opened dataset. Exactly one of `path` and `ds`
        must be given.
      variables: (Optional) Names of the variables to keep, using the names
        after renaming. Default: keep all variables.
      sel_kwargs: (Optional) Keyword arguments for a `.sel()` call that is
        applied after renaming and variable selection.
      rename_dimensions: (Optional) Mapping of dimensions to rename, the string
        'ecmwf', or None for no renaming. The data loaders expect `init_time`
        and `lead_time` on forecast datasets and `valid_time` on target
        datasets (e.g. reanalyses). 'ecmwf' (default) assumes ECMWF standard
        names, i.e. {'time': 'init_time', 'prediction_timedelta': 'lead_time'}
        for prediction datasets and {'time': 'valid_time'} for analysis
        datasets.
      automatically_convert_lat_lon_to_latitude_longitude: (Optional) Whether
        'lat' and 'lon' are renamed to 'latitude' and 'longitude'. Default:
        True.
      rename_variables: (Optional) Mapping of variables to rename.
      preprocessing_fn: (Optional) Function applied to the dataset before any
        renaming or selection takes place.
      **kwargs: Keyword arguments forwarded to base.DataLoader.

    Raises:
      ValueError: If both or neither of `path` and `ds` are given.
    """
    has_path = path is not None
    has_ds = ds is not None
    if has_path and has_ds:
      raise ValueError('Only one of path or ds can be specified, not both.')
    if not (has_path or has_ds):
      raise ValueError('Either path or ds must be specified.')

    # Where the data comes from.
    self._path = path
    self._ds = ds
    # How the data is prepared on first use.
    self._preprocessing_fn = preprocessing_fn
    self._rename_dimensions = rename_dimensions
    self._rename_variables = rename_variables
    self._automatically_convert_lat_lon_to_latitude_longitude = (
        automatically_convert_lat_lon_to_latitude_longitude
    )
    self._variables = variables
    self._sel_kwargs = sel_kwargs
    # Set to True once maybe_prepare_dataset has completed.
    self._preprocessed = False
    super().__init__(**kwargs)

  def maybe_prepare_dataset(self):
    """Opens, preprocesses, renames and subsets the dataset exactly once."""
    if self._preprocessed:
      return

    if self._ds is None:
      logging.info('Opening dataset from path: %s', self._path)
      assert self._path is not None
      is_zarr = self._path.rstrip('/').endswith('.zarr')
      open_fn = xr.open_zarr if is_zarr else xr.open_dataset
      self._ds = open_fn(self._path)

    if self._preprocessing_fn is not None:
      self._ds = self._preprocessing_fn(self._ds)

    self._ds = _rename_dataset(
        self._ds,
        rename_dimensions=self._rename_dimensions,
        rename_variables=self._rename_variables,
        convert_lat_lon_to_latitude_longitude=(
            self._automatically_convert_lat_lon_to_latitude_longitude
        ),
    )

    # Variable names and sel_kwargs refer to the names *after* renaming.
    if self._variables is not None:
      self._ds = self._ds[list(self._variables)]
    if self._sel_kwargs is not None:
      self._ds = self._ds.sel(**self._sel_kwargs)

    self._preprocessed = True

  def _load_chunk_from_source(
      self,
      init_times: np.ndarray,
      lead_times: Optional[Union[np.ndarray, slice]] = None,
  ) -> Mapping[Hashable, xr.DataArray]:
    """Selects the requested times; must be provided by subclasses."""
    raise NotImplementedError()

  def load_chunk(
      self,
      init_times: np.ndarray,
      lead_times: Optional[Union[np.ndarray, slice]] = None,
      reference: Optional[Mapping[Hashable, xr.DataArray]] = None,
  ) -> Mapping[Hashable, xr.DataArray]:
    """Prepares the dataset if needed, then defers to base.DataLoader.

    Args:
      init_times: Initialization times of the chunk.
      lead_times: (Optional) Lead times of the chunk, exact values or a slice.
      reference: (Optional) Passed through unchanged to
        base.DataLoader.load_chunk.

    Returns:
      Whatever base.DataLoader.load_chunk returns for these arguments.
    """
    self.maybe_prepare_dataset()
    return super().load_chunk(init_times, lead_times, reference)
[docs]
class PredictionsFromXarray(XarrayDataLoader):
  """Data loader for reading prediction datasets from Xarray.

  The prepared dataset is indexed by `init_time` and, when lead times are
  requested, by `lead_time`.

  Example:

    >>> init_times, lead_times
    (array(['2020-01-01T00:00:00.000000000', '2020-01-01T12:00:00.000000000'],
           dtype='datetime64[ns]'), array([0, 6], dtype='timedelta64[h]'))
    >>> variables = ['2m_temperature', '10m_wind_speed']
    >>> prediction_data_loader = PredictionsFromXarray(
    ...     path=<PATH>,
    ...     variables=variables,
    ... )
    >>> prediction_data_loader.load_chunk(init_times, lead_times)
    <xarray.Dataset>
    Dimensions:         (latitude: 32, longitude: 64, lead_time: 2, init_time: 2)
    Coordinates:
      * latitude        (latitude) float64 -87.19 -81.56 -75.94 ... 81.56 87.19
      * longitude       (longitude) float64 0.0 5.625 11.25 ... 343.1 348.8 354.4
      * lead_time       (lead_time) timedelta64[ns] 00:00:00 06:00:00
      * init_time       (init_time) datetime64[ns] 2020-01-01 2020-01-01T12:00:00
    Data variables:
        10m_wind_speed  (init_time, lead_time, longitude, latitude) float32 2.29 ...
        2m_temperature  (init_time, lead_time, longitude, latitude) float32 247.4...
  """

  def _load_chunk_from_source(
      self,
      init_times: np.ndarray,
      lead_times: Optional[Union[np.ndarray, slice]] = None,
  ) -> Mapping[Hashable, xr.DataArray]:
    """Selects the given init times and, if specified, lead times.

    Args:
      init_times: Initialization times to select along `init_time`.
      lead_times: (Optional) Exact lead times or a lead time slice to select
        along `lead_time`. If None, all lead times are returned.

    Returns:
      The selected part of the dataset.
    """
    # Dataset should have been read during maybe_prepare_dataset.
    assert self._ds is not None
    indexers = {'init_time': init_times}
    # Arrays and slices are both valid indexers; None means "keep everything".
    if lead_times is not None:
      indexers['lead_time'] = lead_times
    return self._ds.sel(**indexers)
[docs]
class TargetsFromXarray(XarrayDataLoader):
  """Data loader for reading target datasets from Xarray.

  The prepared dataset is indexed by `valid_time`. When lead times are
  requested, the valid times are computed as every combination of init time
  plus lead time, so the result has `init_time` and `lead_time` dimensions.

  Example:

    >>> init_times, lead_times
    (array(['2020-01-01T00:00:00.000000000', '2020-01-01T12:00:00.000000000'],
           dtype='datetime64[ns]'), array([0, 6], dtype='timedelta64[h]'))
    >>> variables = ['2m_temperature', '10m_wind_speed']
    >>> target_data_loader = TargetsFromXarray(
    ...     path=<PATH>,
    ...     variables=variables,
    ... )
    >>> target_data_loader.load_chunk(init_times, lead_times)
    <xarray.Dataset>
    Dimensions:         (latitude: 32, longitude: 64, init_time: 2, lead_time: 2)
    Coordinates:
      * latitude        (latitude) float64 -87.19 -81.56 -75.94 ... 81.56 87.19
      * longitude       (longitude) float64 0.0 5.625 11.25 ... 343.1 348.8 354.4
        valid_time      (init_time, lead_time) datetime64[ns] 2020-01-01 ... 2020...
      * init_time       (init_time) datetime64[ns] 2020-01-01 2020-01-01T12:00:00
      * lead_time       (lead_time) timedelta64[ns] 00:00:00 06:00:00
    Data variables:
        10m_wind_speed  (init_time, lead_time, longitude, latitude) float32 2.221...
        2m_temperature  (init_time, lead_time, longitude, latitude) float32 248.5...
  """

  def _load_chunk_from_source(
      self,
      init_times: np.ndarray,
      lead_times: Optional[Union[np.ndarray, slice]] = None,
  ) -> Mapping[Hashable, xr.DataArray]:
    """Selects the targets valid at the requested init/lead time combinations.

    Args:
      init_times: Initialization times.
      lead_times: (Optional) Exact lead times. If not given, `init_times` are
        used directly as valid times.

    Returns:
      The selected part of the dataset.

    Raises:
      ValueError: If `lead_times` is a slice.
    """
    # Dataset should have been read during maybe_prepare_dataset.
    assert self._ds is not None

    # Lead time slice: not allowed.
    if isinstance(lead_times, slice):
      raise ValueError('Lead time slice not supported for target data loaders.')

    # No lead times, in this case treat the init times as valid times.
    if not isinstance(lead_times, Iterable):
      return self._ds.sel(valid_time=init_times)

    # Exact lead times: construct valid times from every init and lead time
    # combination. The dimension names are given explicitly instead of relying
    # on xarray inferring them from the coords dict.
    init_da = xr.DataArray(
        init_times, dims=['init_time'], coords={'init_time': init_times}
    )
    lead_da = xr.DataArray(
        lead_times, dims=['lead_time'], coords={'lead_time': lead_times}
    )
    # Broadcasting over the two distinct dimensions yields a 2-D
    # (init_time, lead_time) array that is used as a vectorized indexer.
    valid_time = init_da + lead_da
    return self._ds.sel(valid_time=valid_time)
[docs]
class ClimatologyFromXarray(XarrayDataLoader):
  """Reads a climatology dataset as a predictions dataset.

  The climatology is indexed by datetime components (by default `dayofyear` and
  `hour`). For every requested valid time those components are extracted and
  used to select the matching climatological values.
  """

  def __init__(
      self,
      climatology_time_coords: Iterable[str] = ('dayofyear', 'hour'),
      rename_dimensions: Optional[Union[Mapping[str, str], str]] = None,
      **kwargs
  ):
    """Init.

    Args:
      climatology_time_coords: The time coordinates of the climatology dataset
        to select. Each name must be an attribute of the xarray `.dt` accessor.
        Default: ('dayofyear', 'hour').
      rename_dimensions: (Optional) Dictionary of dimensions to rename. Default:
        None.
      **kwargs: Other arguments to pass to XarrayDataLoader.
    """
    super().__init__(rename_dimensions=rename_dimensions, **kwargs)
    # Materialize the iterable: it is traversed once per loaded chunk, so a
    # one-shot iterator would silently yield no indexers after the first chunk.
    self._climatology_time_coords = tuple(climatology_time_coords)

  def _load_chunk_from_source(
      self,
      init_times: np.ndarray,
      lead_times: Optional[Union[np.ndarray, slice]] = None,
  ) -> Mapping[Hashable, xr.DataArray]:
    """Selects the climatology for the requested times.

    Args:
      init_times: Initialization times.
      lead_times: (Optional) Exact lead times. If not given, `init_times` are
        used directly as valid times.

    Returns:
      The selected part of the climatology dataset.

    Raises:
      ValueError: If `lead_times` is a slice.
    """
    # Dataset should have been read during maybe_prepare_dataset.
    assert self._ds is not None

    # Lead time slice: not allowed.
    if isinstance(lead_times, slice):
      raise ValueError(
          'Lead time slice not yet supported for climatology data loaders.'
      )

    # Without lead times, the init times themselves are the valid times. The
    # dimension names are given explicitly instead of relying on xarray
    # inferring them from the coords dict.
    valid_time = xr.DataArray(
        init_times, dims=['init_time'], coords={'init_time': init_times}
    )
    # Exact lead times: valid times are all init and lead time combinations.
    if isinstance(lead_times, Iterable):
      valid_time = valid_time + xr.DataArray(
          lead_times, dims=['lead_time'], coords={'lead_time': lead_times}
      )

    # One indexer per climatology coordinate, e.g. valid_time.dt.dayofyear.
    sel_kwargs = {
        coord: getattr(valid_time.dt, coord)
        for coord in self._climatology_time_coords
    }
    return self._ds.sel(sel_kwargs)
[docs]
class PersistenceFromXarray(XarrayDataLoader):
  """Turns a target dataset into a persistence forecast.

  The values valid at each init time are replicated along the requested lead
  times, and the `valid_time` dimension is relabelled as `init_time`.
  """

  def _load_chunk_from_source(
      self,
      init_times: np.ndarray,
      lead_times: Optional[Union[np.ndarray, slice]] = None,
  ) -> Mapping[Hashable, xr.DataArray]:
    """Builds the persistence forecast for the given init and lead times.

    Args:
      init_times: Initialization times, looked up along `valid_time`.
      lead_times: Exact lead times to replicate the data along.

    Returns:
      The selected data with an added `lead_time` dimension and `valid_time`
      renamed to `init_time`.

    Raises:
      ValueError: If `lead_times` is None or a slice.
    """
    # Dataset should have been read during maybe_prepare_dataset.
    assert self._ds is not None
    if isinstance(lead_times, slice) or lead_times is None:
      raise ValueError(
          'Exact lead times must be specified for persistence data loader.'
      )
    state_at_init = self._ds.sel(valid_time=init_times)
    replicated = state_at_init.expand_dims({'lead_time': lead_times})
    return replicated.rename({'valid_time': 'init_time'})
[docs]
class ProbabilisticClimatologyFromXarray(XarrayDataLoader):
  """Reads a target dataset and treats every year as an ensemble member.

  For each valid_time, take the corresponding value for the same day of the year
  and hour of the day from the target dataset between start and end year and
  treat it as an ensemble member.

  The lookup time in each year is computed as January 1st of that year plus
  (dayofyear - 1) days plus the hour of the day. Consequently, when querying the
  last day of a leap year (day 366), the loader will return the first day of the
  following year for non-leap years.

  This is used as a probabilistic baseline for the WeatherBench website.
  """

  def __init__(
      self,
      start_year: int,
      end_year: int,
      ensemble_dim: str = 'number',
      **kwargs
  ):
    """Init.

    Args:
      start_year: The first year to include in the climatology.
      end_year: The last year (incl.) to include in the climatology.
      ensemble_dim: The dimension to use for the ensemble. Default: 'number'.
      **kwargs: Other arguments to pass to XarrayDataLoader.

    Raises:
      ValueError: If `start_year` is after `end_year`.
    """
    # An empty year range would only fail later, when concatenating zero
    # ensemble members, with a much less helpful message.
    if start_year > end_year:
      raise ValueError(
          f'start_year ({start_year}) must not be after end_year ({end_year}).'
      )
    super().__init__(**kwargs)
    self._start_year = start_year
    self._end_year = end_year
    self._ensemble_dim = ensemble_dim

  def _load_chunk_from_source(
      self,
      init_times: np.ndarray,
      lead_times: Optional[Union[np.ndarray, slice]] = None,
  ) -> Mapping[Hashable, xr.DataArray]:
    """Selects one ensemble member per climatology year.

    Args:
      init_times: Initialization times.
      lead_times: Exact lead times.

    Returns:
      The dataset selected at the per-year lookup times, with dimensions
      `ensemble_dim`, `init_time` and `lead_time` in place of `valid_time`.
      Ensemble members are labelled 0..N-1 in order of increasing year.

    Raises:
      ValueError: If `lead_times` is None or a slice.
    """
    # Dataset should have been read during maybe_prepare_dataset.
    assert self._ds is not None
    if lead_times is None or isinstance(lead_times, slice):
      raise ValueError(
          'Exact lead times must be specified for probabilistic climatology'
          ' data loader.'
      )

    init_times = xr.DataArray(
        init_times, dims=['init_time'], coords={'init_time': init_times}
    )
    lead_times = xr.DataArray(
        lead_times, dims=['lead_time'], coords={'lead_time': lead_times}
    )
    valid_times = init_times + lead_times

    # Offset of each valid time from the start of its year, in whole hours. It
    # does not depend on the climatology year, so compute it only once.
    hours_into_year = (
        (valid_times.dt.dayofyear - 1) * 24 + valid_times.dt.hour
    )
    offset = hours_into_year * np.timedelta64(1, 'h').astype('timedelta64[ns]')

    # np.datetime64(str(year)) is January 1st, 00:00 of that year.
    cat_times = [
        np.datetime64(str(year)) + offset
        for year in range(self._start_year, self._end_year + 1)
    ]
    member_index = range(len(cat_times))
    cat_times = xr.concat(
        cat_times,
        dim=xr.DataArray(
            member_index,
            dims=[self._ensemble_dim],
            coords={self._ensemble_dim: member_index},
        ),
    )
    return self._ds.sel(valid_time=cat_times)
class ConstantLoader(base.DataLoader):
  """Data loader that serves the same dataset for every chunk request."""

  def __init__(self, constant_ds: xr.Dataset):
    """Init.

    Args:
      constant_ds: The dataset to return from every chunk request.
    """
    super().__init__()
    self._constant_ds = constant_ds

  def _load_chunk_from_source(
      self,
      init_times: np.ndarray,
      lead_times: Optional[Union[np.ndarray, slice]] = None,
  ) -> Mapping[Hashable, xr.DataArray]:
    """Returns the stored dataset, regardless of the requested times."""
    del init_times, lead_times  # Unused: the result does not depend on them.
    return self._constant_ds