# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Definition of interpolation classes."""
import abc
from collections.abc import Iterable
import dataclasses
from typing import Hashable, Mapping, Optional, Sequence, Union
import numpy as np
from weatherbenchX import xarray_tree
from weatherbenchX.metrics import spatial
from weatherbenchX.metrics import wrappers
import xarray as xr
[docs]
class Interpolation(abc.ABC):
"""Interpolation base class."""
@abc.abstractmethod
def interpolate_data_array(
self,
da: xr.DataArray,
reference: Optional[xr.DataArray] = None,
) -> xr.DataArray:
"""Implementation of the interpolation function for a single variable."""
def interpolate(
self,
ds: Mapping[Hashable, xr.DataArray],
reference: Optional[Mapping[Hashable, xr.DataArray]] = None,
) -> Mapping[Hashable, xr.DataArray]:
"""Interpolates dataset, potentially according to a reference dataset.
Args:
ds: Xarray dataset to be interpolated.
reference: Optional reference dataset, e.g. target.
Returns:
interpolated_ds: Interpolated dataset.
"""
if reference is None:
return xarray_tree.map_structure(self.interpolate_data_array, ds)
else:
return xarray_tree.map_structure(
self.interpolate_data_array, ds, reference
)
[docs]
@dataclasses.dataclass
class MultipleInterpolation(Interpolation):
"""Applies multiple interpolations to a dataset in sequence.
Attributes:
interpolations: List of interpolations to be applied in sequence.
"""
interpolations: Sequence[Interpolation]
def interpolate_data_array(
self,
da: xr.DataArray,
reference: Optional[xr.DataArray] = None,
) -> xr.DataArray:
for interpolation in self.interpolations:
da = interpolation.interpolate_data_array(da, reference)
return da
def pad_longitude(da: xr.DataArray) -> xr.DataArray:
"""Pad longitude values to allow for wrapped interpolation."""
left = da.isel(longitude=[-1])
left = left.assign_coords(longitude=left.longitude.values - 360)
right = da.isel(longitude=[0])
right = right.assign_coords(longitude=right.longitude.values + 360)
return xr.concat([left, da, right], 'longitude')
def interpolate_to_coords(
da: xr.DataArray,
dim_args: Mapping[str, Union[xr.DataArray, np.ndarray]],
method: str,
extrapolate_out_of_bounds: bool = True,
) -> xr.DataArray:
"""Interpolate to a fixed set of coordinates."""
if extrapolate_out_of_bounds:
# See xarray documentation for interpolation behaviour.
# https://docs.xarray.dev/en/latest/generated/xarray.DataArray.interp.html
if len(dim_args) > 1:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.interpn.html
interp_kwargs = {'fill_value': None}
else:
# https://docs.scipy.org/doc/scipy/reference/generated/scipy.interpolate.interp1d.html
interp_kwargs = {'fill_value': 'extrapolate'}
else:
interp_kwargs = None
out = da.interp(
**dim_args,
method=method,
kwargs=interp_kwargs,
) # pytype: disable=wrong-arg-types
return out
[docs]
class InterpolateToFixedCoords(Interpolation):
"""Interpolate to a fixed set of coordinates.
Interplation is done using xarray's built-in interp method:
https://docs.xarray.dev/en/latest/generated/xarray.DataArray.interp.html
"""
def __init__(
self,
method: str,
coords: Mapping[str, Union[xr.DataArray, np.ndarray]],
wrap_longitude: bool = False,
extrapolate_out_of_bounds: bool = True,
):
"""Init.
Args:
method: Interpolation method to be passed to xarray's interpolation API.
coords: Dictionary of coordinate names and values to interpolate to.
wrap_longitude: If True, perform a wrapped interpolation in the longitude
dimension. Default: False
extrapolate_out_of_bounds: If True, extrapolate to out of bounds values
using the chosen interpolation method. Default: True
"""
self._method = method
self._coords = coords
self._wrap_longitude = wrap_longitude
self._extrapolate_out_of_bounds = extrapolate_out_of_bounds
def interpolate_data_array(
self,
da: xr.DataArray,
reference: Optional[xr.DataArray] = None,
) -> xr.DataArray:
if self._wrap_longitude:
# TODO(srasp): Raise error if this isn't True but seems like it should be.
da = pad_longitude(da)
interpolated_da = interpolate_to_coords(
da,
self._coords,
self._method,
self._extrapolate_out_of_bounds,
)
return interpolated_da
[docs]
class InterpolateToReferenceCoords(Interpolation):
"""Interpolate to a reference dataset.
Interplation is done using xarray's built-in interp method:
https://docs.xarray.dev/en/latest/generated/xarray.DataArray.interp.html
"""
def __init__(
self,
method: str,
dims: Optional[Sequence[str]] = None,
wrap_longitude: bool = False,
clip_reference_coords: Optional[Iterable[str]] = None,
extrapolate_out_of_bounds: bool = True,
):
"""Init.
Args:
method: Interpolation method to be passed to xarray's interpolation API.
dims: (Optional) Dimensions over which to interpolate. If None (default),
infer dimensions from intersect of DataArray dimensions and reference
coordinates.
wrap_longitude: If True, perform a wrapped interpolation in the longitude
dimension. Default: False
clip_reference_coords: Clip the reference dataset to the maximum extent of
the data to be interpolated in the given dimensions, e.g. ['latitude',
'longitude']. Note that this can potentially lead to errors in the
reference go unnoticed. It is preferred to use a fixed interpolation
instead or ensure that the reference extent matches beforehand. Default:
None.
extrapolate_out_of_bounds: If True, extrapolate to out of bounds values
using the chosen interpolation method. Default: True
"""
self._method = method
self._dims = dims
self._wrap_longitude = wrap_longitude
self._clip_reference_coords = clip_reference_coords
self._extrapolate_out_of_bounds = extrapolate_out_of_bounds
def interpolate_data_array(
self,
da: xr.DataArray,
reference: xr.DataArray, # pytype: disable=signature-mismatch
) -> xr.DataArray:
# Catch case where reference doesn't contain any data.
if len(reference) == 0:
return reference.copy()
if self._wrap_longitude:
da = pad_longitude(da)
if self._clip_reference_coords is not None:
for coord in self._clip_reference_coords:
reference = reference.sel(
{coord: slice(da[coord].min(), da[coord].max())}
)
# If dims not explicit, interpolate all dims that have a corresponding
# coordinate in the reference.
if self._dims is None:
dims = [d for d in da.dims if d in reference.coords]
else:
dims = self._dims
dim_args = {dim: reference[dim] for dim in dims}
da_like_reference = interpolate_to_coords(
da,
dim_args,
self._method,
self._extrapolate_out_of_bounds,
)
return da_like_reference
LAPSE_RATE_K_PER_M = -0.0065 # Standard atmosphere lapse rate.
[docs]
class GridToSparseWithAltitudeAdjustment(InterpolateToReferenceCoords):
"""Applies altitude adjustment to 2m_temperature and 10m_wind_speed.
Alititude adjustments are based on the difference of the grid elevation to the
station elevation. Reference:
https://rmets.onlinelibrary.wiley.com/doi/10.1002/qj.2372, Section 3.3.
Assumes that elevations are in meters and an 'elevation' coordinate exists on
the reference dataset. Requires passing a DataArray with the grid elevation
corresponding to the dataset to be interpolated. Variables must be named
'2m_temperature' and '10m_wind_speed'. Other variables will be left unchanged.
"""
def __init__(
self,
method: str,
grid_elevation: xr.DataArray,
dims: Optional[Sequence[str]] = None,
wrap_longitude: bool = False,
extrapolate_out_of_bounds: bool = True,
max_alititude_diff_in_m: float = 1500,
):
"""Init.
Args:
method: Interpolation method to be passed to xarray's interpolation API.
grid_elevation: DataArray matching the dataset coordinates specifying the
grid box elevation in m.
dims: (Optional) Dimensions over which to interpolate. If None (default),
infer dimensions from intersect of DataArray dimensions and reference
coordinates.
wrap_longitude: If True, perform a wrapped interpolation in the longitude
dimension. Default: False
extrapolate_out_of_bounds: If True, extrapolate to out of bounds values
using the chosen interpolation method. Default: True
max_alititude_diff_in_m: No adjustment is applied for elevation
differences greater than this value. Large values can appear because of
errors in the station dataset, e.g. elevation reported in ft instead of
m. Default: 1500.
"""
self._grid_elevation = grid_elevation
self._max_alititude_diff_in_m = max_alititude_diff_in_m
super().__init__(
method=method,
dims=dims,
wrap_longitude=wrap_longitude,
extrapolate_out_of_bounds=extrapolate_out_of_bounds,
)
def interpolate_data_array(
self,
da: xr.DataArray,
reference: xr.DataArray, # pytype: disable=signature-mismatch
) -> xr.DataArray:
if da.name in ['2m_temperature', '10m_wind_speed']:
da.coords['grid_elevation'] = self._grid_elevation.compute()
da_like_reference = super().interpolate_data_array(da, reference)
if da.name in ['2m_temperature', '10m_wind_speed']:
# Positive if station is higher than grid.
sparse_higher_than_grid_m = (
da_like_reference['elevation'] - da_like_reference['grid_elevation']
)
# Set "unrealistic" differences to 0.
sparse_higher_than_grid_m = sparse_higher_than_grid_m.where(
np.abs(sparse_higher_than_grid_m) < self._max_alititude_diff_in_m, 0
)
if da.name == '2m_temperature':
adjustment = sparse_higher_than_grid_m * LAPSE_RATE_K_PER_M
da_like_reference += adjustment
elif da.name == '10m_wind_speed':
# Only adjust stations > 100m above model orography.
adjustment_factor = xr.ones_like(sparse_higher_than_grid_m)
# Subtract 100m from the difference. I couldn't find this in the paper
# but it does make sense so that the different regimes overlap.
dz = sparse_higher_than_grid_m - 100
adjustment_factor = adjustment_factor.where(
sparse_higher_than_grid_m < 100,
1 + 0.002 * dz,
)
adjustment_factor = adjustment_factor.where(
sparse_higher_than_grid_m < 1100, 3
)
da_like_reference *= adjustment_factor
return da_like_reference
[docs]
class NeighborhoodThresholdProbabilities(Interpolation):
"""Converts a deterministic forecast to a probabilistic one by neighborhood averaging.
For a given threshold, the probability is devined as the fraction of the
fraction of pixels in a square neighborhood that exceeds the threshold. This
is the same computation as in the Fraction Skill Score.
"""
def __init__(
self,
neighborhood_sizes,
thresholds,
threshold_dim='threshold_value',
wrap_longitude: bool = False,
):
"""Init.
Args:
neighborhood_sizes: List of neighborhood sizes to be used in pixels. Must
be odd.
thresholds: List of thresholds to be used to binarize data.
threshold_dim: Dimension name of the thresholds. Default:
'threshold_value'
wrap_longitude: If True, perform a wrapped convolution in the longitude
dimension. Default: False
"""
self._neighborhood_sizes = neighborhood_sizes
self._thresholds = thresholds
self._threshold_dim = threshold_dim
self._wrap_longitude = wrap_longitude
def interpolate_data_array(
self,
da: xr.DataArray,
reference: Optional[xr.DataArray] = None,
) -> xr.DataArray:
da = wrappers.binarize_thresholds(
da, thresholds=self._thresholds, threshold_dim=self._threshold_dim
)
out = []
for n in self._neighborhood_sizes:
out.append(
spatial.neighborhood_averaging_for_single_size(
da, n, wrap_longitude=self._wrap_longitude
)
)
out = xr.concat(
out,
dim=xr.DataArray(
self._neighborhood_sizes, dims=['smoothing_neighborhood']
),
)
return out