# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of interpolation classes."""
import abc
from collections.abc import Iterable
import dataclasses
from typing import Hashable, Mapping, Optional, Sequence, Union
import numpy as np
from weatherbenchX import xarray_tree
from weatherbenchX.metrics import spatial
from weatherbenchX.metrics import wrappers
import xarray as xr
class Interpolation(abc.ABC):
  """Abstract base for per-variable interpolation / regridding steps.

  Subclasses implement `interpolate_data_array` for a single DataArray;
  `interpolate` applies it to every variable of a mapping.
  """

  @abc.abstractmethod
  def interpolate_data_array(
      self,
      da: xr.DataArray,
      reference: Optional[xr.DataArray] = None,
  ) -> xr.DataArray:
    """Interpolates a single variable, optionally guided by `reference`."""

  def interpolate(
      self,
      ds: Mapping[Hashable, xr.DataArray],
      reference: Optional[Mapping[Hashable, xr.DataArray]] = None,
  ) -> Mapping[Hashable, xr.DataArray]:
    """Interpolates every variable of `ds`.

    Args:
      ds: Mapping of variable names to DataArrays to be interpolated.
      reference: Optional reference mapping, e.g. the targets. When given, it
        is passed alongside `ds` to `xarray_tree.map_structure`, so each
        variable is interpolated together with its counterpart in
        `reference`.

    Returns:
      The result of mapping `interpolate_data_array` over `ds` (and
      `reference`, if given).
    """
    # Only pass `reference` as an extra structure when one was provided;
    # otherwise `interpolate_data_array` is called with the data alone.
    structures = (ds,) if reference is None else (ds, reference)
    return xarray_tree.map_structure(self.interpolate_data_array, *structures)
@dataclasses.dataclass
class MultipleInterpolation(Interpolation):
  """Chains several interpolations, applying them one after another.

  Attributes:
    interpolations: Interpolations to apply, in order. The output of each step
      is the input of the next; every step receives the same `reference`.
  """

  interpolations: Sequence[Interpolation]

  def interpolate_data_array(
      self,
      da: xr.DataArray,
      reference: Optional[xr.DataArray] = None,
  ) -> xr.DataArray:
    result = da
    for step in self.interpolations:
      result = step.interpolate_data_array(result, reference)
    return result
def pad_longitude(da: xr.DataArray) -> xr.DataArray:
  """Adds one wrapped longitude column on each side of `da`.

  The last longitude column is prepended with its coordinate shifted by -360,
  and the first column is appended with its coordinate shifted by +360, so
  interpolation can cross the longitude seam.

  NOTE(review): this assumes the longitude coordinate is sorted ascending and
  spans the full circle; neither is checked here.

  Args:
    da: DataArray with a 'longitude' dimension.

  Returns:
    `da` with two extra longitude entries, one at each end.
  """
  first_column = da.isel(longitude=[0])
  last_column = da.isel(longitude=[-1])
  wrapped_before = last_column.assign_coords(
      longitude=last_column.longitude.values - 360
  )
  wrapped_after = first_column.assign_coords(
      longitude=first_column.longitude.values + 360
  )
  return xr.concat([wrapped_before, da, wrapped_after], 'longitude')
def interpolate_to_coords(
    da: xr.DataArray,
    dim_args: Mapping[str, Union[xr.DataArray, np.ndarray]],
    method: str,
    extrapolate_out_of_bounds: bool = True,
) -> xr.DataArray:
  """Interpolates `da` onto the coordinate values given in `dim_args`.

  Args:
    da: DataArray to interpolate.
    dim_args: Mapping from dimension name to target coordinate values, passed
      as keyword arguments to `xr.DataArray.interp`.
    method: Interpolation method forwarded to `xr.DataArray.interp`.
    extrapolate_out_of_bounds: If True, targets outside the source range are
      extrapolated; otherwise xarray's default out-of-bounds handling is used.

  Returns:
    The interpolated DataArray.
  """
  interp_kwargs = None
  if extrapolate_out_of_bounds:
    # See xarray documentation for interpolation behaviour.
    # https://docs.xarray.dev/en/latest/generated/xarray.DataArray.interp.html
    # The SciPy backend differs by dimensionality and so does the way to ask
    # for extrapolation:
    #  - several dims: scipy.interpolate.interpn extrapolates when
    #    fill_value is None.
    #    https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.interpn.html
    #  - a single dim: scipy.interpolate.interp1d extrapolates when
    #    fill_value is 'extrapolate'.
    #    https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.interp1d.html
    fill_value = None if len(dim_args) > 1 else 'extrapolate'
    interp_kwargs = {'fill_value': fill_value}
  return da.interp(
      **dim_args,
      method=method,
      kwargs=interp_kwargs,
  )  # pytype: disable=wrong-arg-types
class CropToBox(Interpolation):
  """Crops the data to a latitude/longitude bounding box.

  Since interpolation is called before compute(), this can be useful to reduce
  the amount of data that is read into memory when you are only interested in
  a particular area.

  This is essentially a wrapper around a label-based `xarray.DataArray.sel()`
  call with slices, so both ends of each range are inclusive. The box is not
  wrapped around the longitude seam: `lon_min`/`lon_max` must use the same
  longitude convention as the data (e.g. both in [0, 360) or both in
  [-180, 180)).
  """

  def __init__(
      self,
      lat_min: float,
      lat_max: float,
      lon_min: float,
      lon_max: float,
  ):
    """Init.

    Args:
      lat_min: Minimum latitude to crop to (inclusive).
      lat_max: Maximum latitude to crop to (inclusive).
      lon_min: Minimum longitude to crop to (inclusive).
      lon_max: Maximum longitude to crop to (inclusive).

    Raises:
      ValueError: If `lat_min > lat_max` or `lon_min > lon_max`.
    """
    # f-strings so the offending values actually appear in the message.
    if lat_min > lat_max:
      raise ValueError(f'Invalid latitudes: {lat_min} and {lat_max}')
    if lon_min > lon_max:
      raise ValueError(f'Invalid longitudes: {lon_min} and {lon_max}')
    self._lat_min = lat_min
    self._lat_max = lat_max
    self._lon_min = lon_min
    self._lon_max = lon_max

  def interpolate_data_array(
      self,
      da: xr.DataArray,
      reference: Optional[xr.DataArray] = None,
  ) -> xr.DataArray:
    """Returns the part of `da` inside the box; `reference` is ignored."""
    # Some datasets have latitude in descending order, or longitude that
    # wraps around. Sort both ascending so the label slices below select the
    # intended (non-empty) range.
    da = da.sortby('longitude', ascending=True)
    da = da.sortby('latitude', ascending=True)
    return da.sel(
        latitude=slice(self._lat_min, self._lat_max),
        longitude=slice(self._lon_min, self._lon_max),
    )
class InterpolateToFixedCoords(Interpolation):
  """Interpolates data onto a fixed, pre-defined set of coordinates.

  The targets are given once at construction time; any `reference` passed at
  call time is ignored. Interpolation is done using xarray's built-in interp
  method:
  https://docs.xarray.dev/en/latest/generated/xarray.DataArray.interp.html
  """

  def __init__(
      self,
      method: str,
      coords: Mapping[str, Union[xr.DataArray, np.ndarray]],
      wrap_longitude: bool = False,
      extrapolate_out_of_bounds: bool = True,
  ):
    """Init.

    Args:
      method: Interpolation method forwarded to xarray's interpolation API.
      coords: Mapping from coordinate name to the values to interpolate to.
      wrap_longitude: If True, pad the longitude dimension with wrapped values
        before interpolating so the seam is handled. Default: False
      extrapolate_out_of_bounds: If True, targets outside the data's range are
        extrapolated with the chosen method. Default: True
    """
    self._method, self._coords = method, coords
    self._wrap_longitude = wrap_longitude
    self._extrapolate_out_of_bounds = extrapolate_out_of_bounds

  def interpolate_data_array(
      self,
      da: xr.DataArray,
      reference: Optional[xr.DataArray] = None,
  ) -> xr.DataArray:
    # TODO(srasp): Raise error if this isn't True but seems like it should be.
    source = pad_longitude(da) if self._wrap_longitude else da
    return interpolate_to_coords(
        source,
        self._coords,
        self._method,
        self._extrapolate_out_of_bounds,
    )
class InterpolateToReferenceCoords(Interpolation):
  """Interpolates data onto the coordinates of a reference DataArray.

  Interpolation is done using xarray's built-in interp method:
  https://docs.xarray.dev/en/latest/generated/xarray.DataArray.interp.html
  """

  def __init__(
      self,
      method: str,
      dims: Optional[Sequence[str]] = None,
      wrap_longitude: bool = False,
      clip_reference_coords: Optional[Iterable[str]] = None,
      extrapolate_out_of_bounds: bool = True,
  ):
    """Init.

    Args:
      method: Interpolation method to be passed to xarray's interpolation API.
      dims: (Optional) Dimensions over which to interpolate. If None (default),
        infer dimensions from the intersection of the DataArray's dimensions
        and the reference's coordinates.
      wrap_longitude: If True, perform a wrapped interpolation in the longitude
        dimension. Default: False
      clip_reference_coords: Clip the reference to the extent of the data to be
        interpolated in the given dimensions, e.g. ['latitude', 'longitude'].
        Note that this can let errors in the reference go unnoticed. It is
        preferred to use a fixed interpolation instead or to ensure that the
        reference extent matches beforehand. Default: None.
      extrapolate_out_of_bounds: If True, extrapolate to out of bounds values
        using the chosen interpolation method. Default: True
    """
    self._method = method
    self._dims = dims
    self._wrap_longitude = wrap_longitude
    # Materialize once: the coordinates are iterated on every call, and a
    # one-shot iterator would otherwise be exhausted after the first variable,
    # silently disabling clipping for all later ones.
    self._clip_reference_coords = (
        None if clip_reference_coords is None else tuple(clip_reference_coords)
    )
    self._extrapolate_out_of_bounds = extrapolate_out_of_bounds

  def interpolate_data_array(
      self,
      da: xr.DataArray,
      reference: xr.DataArray,  # pytype: disable=signature-mismatch
  ) -> xr.DataArray:
    """Interpolates `da` onto the coordinates of `reference`.

    Args:
      da: DataArray to interpolate.
      reference: DataArray whose coordinates define the interpolation targets.
        Unlike in the base class, it is required here.

    Returns:
      `da` interpolated to `reference`'s coordinates along the interpolated
      dimensions. If `reference` is empty, an empty copy of `reference`
      expanded with the non-interpolated dimensions of `da` is returned.
    """
    if self._wrap_longitude:
      da = pad_longitude(da)
    if self._clip_reference_coords is not None:
      # Restrict the reference to the [min, max] label range of `da` in each
      # listed coordinate.
      # NOTE(review): a label slice on a descending reference coordinate
      # selects nothing; this presumably expects ascending coords.
      for coord in self._clip_reference_coords:
        reference = reference.sel(
            {coord: slice(da[coord].min(), da[coord].max())}
        )
    # If dims not explicit, interpolate all dims that have a corresponding
    # coordinate in the reference.
    if self._dims is None:
      dims = [d for d in da.dims if d in reference.coords]
    else:
      dims = self._dims
    # Catch case where reference doesn't contain any data.
    if reference.size == 0:
      # Retain every dimension of `da` that is not being interpolated. Iterate
      # `da.dims` (rather than a set difference) so the resulting dimension
      # order is deterministic and independent of string hash randomization.
      da_dims_to_retain = [d for d in da.dims if d not in dims]
      return reference.copy().expand_dims(
          {d: da[d] for d in da_dims_to_retain}
      )
    dim_args = {dim: reference[dim] for dim in dims}
    return interpolate_to_coords(
        da,
        dim_args,
        self._method,
        self._extrapolate_out_of_bounds,
    )
# Standard-atmosphere temperature lapse rate: -6.5 K per km of ascent,
# expressed in kelvin per metre.
LAPSE_RATE_K_PER_M = -6.5e-3
class GridToSparseWithAltitudeAdjustment(InterpolateToReferenceCoords):
  """Interpolates gridded data to sparse points with an altitude adjustment.

  Applies altitude adjustments to 2m_temperature and 10m_wind_speed based on
  the difference between the grid elevation and the station elevation.
  Reference: https://rmets.onlinelibrary.wiley.com/doi/10.1002/qj.2372,
  Section 3.3.

  Assumes that elevations are in meters and that an 'elevation' coordinate
  exists on the reference dataset. Requires passing a DataArray with the grid
  elevation corresponding to the dataset to be interpolated. Variables must be
  named '2m_temperature' and '10m_wind_speed' to be adjusted; other variables
  are interpolated but left otherwise unchanged.

  With dz = station elevation - grid elevation (set to 0 when
  |dz| >= max_alititude_diff_in_m):
    * 2m_temperature: adds LAPSE_RATE_K_PER_M * dz.
    * 10m_wind_speed: multiplies by 1 for dz < 100 m, by
      1 + 0.002 * (dz - 100) for 100 m <= dz < 1100 m, and by 3 above that.
  """

  def __init__(
      self,
      method: str,
      grid_elevation: xr.DataArray,
      dims: Optional[Sequence[str]] = None,
      wrap_longitude: bool = False,
      extrapolate_out_of_bounds: bool = True,
      max_alititude_diff_in_m: float = 1500,
  ):
    """Init.

    Args:
      method: Interpolation method to be passed to xarray's interpolation API.
      grid_elevation: DataArray matching the dataset coordinates specifying the
        grid box elevation in m.
      dims: (Optional) Dimensions over which to interpolate. If None (default),
        infer dimensions from the intersection of DataArray dimensions and
        reference coordinates.
      wrap_longitude: If True, perform a wrapped interpolation in the longitude
        dimension. Default: False
      extrapolate_out_of_bounds: If True, extrapolate to out of bounds values
        using the chosen interpolation method. Default: True
      max_alititude_diff_in_m: No adjustment is applied where the absolute
        elevation difference is at least this value. Large values can appear
        because of errors in the station dataset, e.g. elevation reported in ft
        instead of m. Default: 1500.
    """
    self._grid_elevation = grid_elevation
    self._max_alititude_diff_in_m = max_alititude_diff_in_m
    super().__init__(
        method=method,
        dims=dims,
        wrap_longitude=wrap_longitude,
        extrapolate_out_of_bounds=extrapolate_out_of_bounds,
    )

  def interpolate_data_array(
      self,
      da: xr.DataArray,
      reference: xr.DataArray,  # pytype: disable=signature-mismatch
  ) -> xr.DataArray:
    """Interpolates `da` to `reference` and applies the altitude adjustment."""
    needs_adjustment = da.name in ['2m_temperature', '10m_wind_speed']
    if needs_adjustment:
      # Attach the grid elevation as a coordinate so it is interpolated to the
      # reference points together with the data. `assign_coords` returns a new
      # object; assigning via `da.coords[...] = ...` would silently modify the
      # caller's DataArray.
      da = da.assign_coords(grid_elevation=self._grid_elevation.compute())
    da_like_reference = super().interpolate_data_array(da, reference)
    if not needs_adjustment:
      return da_like_reference

    # Positive if station is higher than grid.
    sparse_higher_than_grid_m = (
        da_like_reference['elevation'] - da_like_reference['grid_elevation']
    )
    # Set "unrealistic" differences to 0, i.e. apply no adjustment there.
    sparse_higher_than_grid_m = sparse_higher_than_grid_m.where(
        np.abs(sparse_higher_than_grid_m) < self._max_alititude_diff_in_m, 0
    )
    # In-place updates below keep the variable's name and attributes; the
    # array is a fresh interpolation result, not the caller's data.
    if da.name == '2m_temperature':
      adjustment = sparse_higher_than_grid_m * LAPSE_RATE_K_PER_M
      da_like_reference += adjustment
    elif da.name == '10m_wind_speed':
      # Only adjust stations > 100m above model orography.
      adjustment_factor = xr.ones_like(sparse_higher_than_grid_m)
      # Subtract 100m from the difference so the regimes join continuously:
      # the factor goes from 1 at 100 m to 3 at 1100 m, then stays at 3.
      dz = sparse_higher_than_grid_m - 100
      adjustment_factor = adjustment_factor.where(
          sparse_higher_than_grid_m < 100,
          1 + 0.002 * dz,
      )
      adjustment_factor = adjustment_factor.where(
          sparse_higher_than_grid_m < 1100, 3
      )
      da_like_reference *= adjustment_factor
    return da_like_reference
class NeighborhoodThresholdProbabilities(Interpolation):
  """Turns a deterministic forecast into probabilities by neighborhood averaging.

  For a given threshold, the probability is defined as the fraction of pixels
  in a square neighborhood that exceed the threshold. This is the same
  computation as in the Fraction Skill Score.
  """

  def __init__(
      self,
      neighborhood_sizes,
      thresholds,
      threshold_dim='threshold_value',
      wrap_longitude: bool = False,
  ):
    """Init.

    Args:
      neighborhood_sizes: Neighborhood sizes to use, in pixels. Must be odd.
      thresholds: Thresholds used to binarize the data.
      threshold_dim: Name of the thresholds dimension. Default:
        'threshold_value'
      wrap_longitude: If True, perform a wrapped convolution in the longitude
        dimension. Default: False
    """
    self._neighborhood_sizes = neighborhood_sizes
    self._thresholds = thresholds
    self._threshold_dim = threshold_dim
    self._wrap_longitude = wrap_longitude

  def interpolate_data_array(
      self,
      da: xr.DataArray,
      reference: Optional[xr.DataArray] = None,
  ) -> xr.DataArray:
    """Binarizes `da` per threshold, then averages over each neighborhood size.

    The per-size results are stacked along a new 'smoothing_neighborhood'
    dimension whose coordinate values are the neighborhood sizes.
    `reference` is ignored.
    """
    binarized = wrappers.binarize_thresholds(
        da, thresholds=self._thresholds, threshold_dim=self._threshold_dim
    )
    per_size = [
        spatial.neighborhood_averaging_for_single_size(
            binarized, size, wrap_longitude=self._wrap_longitude
        )
        for size in self._neighborhood_sizes
    ]
    size_coord = xr.DataArray(
        self._neighborhood_sizes, dims=['smoothing_neighborhood']
    )
    return xr.concat(per_size, dim=size_coord)
class Subsample(Interpolation):
  """Keeps every `stride`-th element along selected dimensions.

  This reduces resolution without any interpolation, e.g. for faster
  evaluation at a coarser resolution.
  """

  def __init__(
      self,
      dims: Sequence[str],
      stride: int,
  ):
    """Init.

    Args:
      dims: Dimensions to subsample. Dimensions missing from a given
        DataArray are skipped for that array.
      stride: Step between kept elements (indices 0, stride, 2 * stride, ...).
        Must be a positive integer.

    Raises:
      ValueError: If `stride` is smaller than 1.
    """
    if stride < 1:
      raise ValueError(f'stride must be >= 1, got {stride}')
    self._dims = dims
    self._stride = stride

  def interpolate_data_array(
      self,
      da: xr.DataArray,
      reference: Optional[xr.DataArray] = None,
  ) -> xr.DataArray:
    """Returns `da` strided along the configured dims; `reference` is ignored."""
    every_nth = slice(None, None, self._stride)
    present_dims = [name for name in self._dims if name in da.dims]
    return da.isel(**{name: every_nth for name in present_dims})