# Source code for weatherbenchX.data_loaders.sparse_parquet

# Copyright 2025 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Data loaders for tabular data stored in Parquet format."""

from collections.abc import Hashable
import functools
import os
from typing import Callable, Mapping, Optional, Sequence, Union
import numpy as np
import pandas as pd
import pyarrow
from weatherbenchX.data_loaders import base
import xarray as xr


def get_parquet_files_subset(
    path: str,
    time_start: np.datetime64,
    time_end: np.datetime64,
    partition_by: str,
) -> list[str]:
  """Get subset of parquet files for a given time interval.

  Both interval bounds are truncated to the partition resolution and one
  filename is generated for every partition from the one containing
  `time_start` up to and including the one containing `time_end`. Filenames
  are generated, not discovered: no check is made that the files exist.

  Args:
    path: Root path of the partitioned Parquet dataset.
    time_start: Start of the time interval.
    time_end: End of the time interval. Its partition is included.
    partition_by: Partition resolution: 'month', 'day' or 'hour'.

  Returns:
    List of partition filenames in chronological order.

  Raises:
    NotImplementedError: If `partition_by` is not a supported resolution.
  """
  units = {'month': 'M', 'day': 'D', 'hour': 'h'}
  if partition_by not in units:
    raise NotImplementedError(f'{partition_by} not implemented.')
  unit = units[partition_by]
  # Truncate both bounds to the partition resolution so that the range below
  # steps through whole partitions.
  time_start = np.datetime64(time_start, unit)
  time_end = np.datetime64(time_end, unit)
  td = np.timedelta64(1, unit)
  # Adding one step makes the half-open np.arange include the end partition.
  times = np.arange(time_start, time_end + td, td)
  return [parquet_filename_for_time(path, time, unit) for time in times]


def parquet_filename_for_time(path: str, time: np.datetime64, unit: str) -> str:
  """Return parquet partition filename for a given time.

  The layout is hive-style directories with unpadded numbers followed by a
  zero-padded ISO-like file name, e.g. for a daily partition:
  <path>/year=2020/month=1/day=1/2020-01-01.parquet

  Args:
    path: Root path of the partitioned Parquet dataset.
    time: Time whose partition file is requested. Any datetime64 precision is
      accepted; it is truncated to `unit`.
    unit: NumPy datetime unit of the partitioning: 'M' (month), 'D' (day) or
      'h' (hour).

  Returns:
    Full path of the partition file containing `time`.

  Raises:
    NotImplementedError: If `unit` is not one of 'M', 'D' or 'h'.
  """
  if unit not in ('M', 'D', 'h'):
    raise NotImplementedError(f'Partition unit {unit!r} not implemented.')
  # Cast to the partition unit first: `.item()` on a nanosecond-precision
  # datetime64 returns a plain int, and on a day-precision value returns a
  # `date` without an `hour` attribute.
  stamp = np.datetime64(time).astype(f'datetime64[{unit}]').item()
  year, month = stamp.year, stamp.month
  directory = f'year={year}/month={month}'
  basename = f'{year}-{month:02d}'
  if unit in ('D', 'h'):
    day = stamp.day
    directory += f'/day={day}'
    basename += f'-{day:02d}'
  if unit == 'h':
    hour = stamp.hour
    directory += f'/hour={hour}'
    basename += f'T{hour:02d}'
  return os.path.join(path, f'{directory}/{basename}.parquet')


class SparseObservationsFromParquet(base.DataLoader):
  """Reads general sparse observation data stored in Parquet format.

  It is assumed that the data is partitioned by month, day or hour. A daily
  partition would follow the following directory structure:
  <PATH>/year=2020/month=1/day=1/2020-01-01.parquet

  Since auto-discovery of files can take a long time, this data loader assumes
  this format to quickly query the desired sub-files for a given time interval.
  Currently, this assumes there are no missing files.
  """

  def __init__(
      self,
      path: str,
      partitioned_by: str,
      time_dim: str,
      variables: Sequence[str],
      coordinate_variables: Sequence[str] = (),
      split_variables: bool = False,
      dropna: bool = False,
      tolerance: Optional[
          np.timedelta64 | tuple[np.timedelta64, np.timedelta64]
      ] = None,
      rename_variables: Optional[Mapping[str, str]] = None,
      include_slice_end_time: bool = False,
      remove_duplicates: bool = False,
      pick_closest_duplicate_by: Optional[str] = None,
      observation_dim: Optional[str] = None,
      file_tolerance: np.timedelta64 = np.timedelta64(1, 'h'),
      preprocessing_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
      **kwargs,
  ):
    """Init.

    Args:
      path: Path to Parquet dataset.
      partitioned_by: How the Parquet file is partitioned. 'hour', 'day' or
        'month'.
      time_dim: Time dimension on Parquet files (before renaming) to use for
        time filtering.
      variables: Variables to load (after renaming).
      coordinate_variables: Coordinate variables to load. These will be
        converted to an xarray coordinates. 'valid_time' is always a coordinate
        and represents the original value of the time_dim coordinate. Default:
        ()
      split_variables: Whether to return the loaded data as a dictionary of
        DataArrays. Default: False.
      dropna: Whether to drop missing values. If split_variables is True,
        values will be dropped for each variable separately. Otherwise, only
        indices where all variables are non-NaN will be returned.
      tolerance: (Optional) Tolerance around the given valid time. If tolerance
        is a single timedelta, data within valid_time +/- tolerance will be
        returned. If tolerance is a 2-tuple of timedeltas, data within
        [valid_time + tolerance[0], valid_time + tolerance[1]] will be
        returned. This is only supported for exact lead_times. The resulting
        init and lead time coordinates will be those requested. The valid_time
        dimension will reflect the original time for each observation.
      rename_variables: (Optional) Renaming dictionary.
      include_slice_end_time: Whether slice end time is included. Default:
        False
      remove_duplicates: For exact lead times, whether duplicate stations
        (specified by `observation_dim`) for the same valid time are removed.
        If True, this will pick the closest time specified by
        `pick_closest_duplicate_by` to the valid_time and keep it. Default:
        False
      pick_closest_duplicate_by: (Optional) Time dimension to use to pick the
        closest duplicate.
      observation_dim: (Optional) Dimension identifying e.g. station names.
        This is used to remove duplicate observations.
      file_tolerance: 'timeObs' does not always align with the time on the
        partition. To make sure all required times are read, open the files
        with +/- file_tolerance. The 'timeObs' of most observations are within
        a one hour window of the nominal time. 'timeNominal' will be equal to
        the partition time and would therefore not require a file_tolerance.
        Default: 1h
      preprocessing_fn: (Optional) Function to apply to the dataframe after
        reading.
      **kwargs: Additional keyword arguments passed to the base DataLoader.

    Raises:
      ValueError: If `partitioned_by` is unsupported, if `tolerance` is not a
        single timedelta or a non-empty 2-tuple range, or if
        `remove_duplicates` is set without `observation_dim`.
    """
    super().__init__(
        compute=False,  # Data is already loaded.
        **kwargs
    )
    self._path = path
    if partitioned_by not in ['hour', 'day', 'month']:
      raise ValueError(f'Unsupported partitioned_by: {partitioned_by}')
    self._partitioned_by = partitioned_by
    self._time_dim = time_dim
    # Stored as a list so that it can be concatenated with the coordinate list
    # when selecting columns, whatever sequence type the caller passed.
    self._variables = list(variables)
    self._coordinate_variables = list(coordinate_variables) + ['valid_time']
    self._split_variables = split_variables
    self._dropna = dropna
    if tolerance is not None:
      if isinstance(tolerance, np.timedelta64):
        # A single timedelta is shorthand for a symmetric window.
        tolerance = (-tolerance, tolerance)
      if len(tolerance) != 2:
        raise ValueError(
            'Tolerance must be a single np.timedelta64 or a 2-tuple.'
        )
      if (tolerance[1] - tolerance[0]) <= np.timedelta64(0, 'h'):
        raise ValueError(
            'Tolerance range should be non-empty. This will always return an'
            ' empty array.'
        )
    self._tolerance = tolerance
    self._rename_variables = rename_variables
    self._include_slice_end_time = include_slice_end_time
    self._remove_duplicates = remove_duplicates
    self._pick_closest_duplicate_by = pick_closest_duplicate_by
    if remove_duplicates:
      if observation_dim is None:
        raise ValueError(
            'observation_dim must be specified if remove_duplicates is True.'
        )
    self._observation_dim = observation_dim
    self._file_tolerance = file_tolerance
    self._preprocessing_fn = preprocessing_fn

  def _pick_closest_from_duplicates(
      self, df: pd.DataFrame, valid_time: np.datetime64
  ) -> pd.DataFrame:
    """Pick row where `_pick_closest_duplicate_by` is closest to valid_time.

    This runs before columns are renamed, so `_pick_closest_duplicate_by` and
    `_observation_dim` refer to column names as present in `df` at that point.

    Args:
      df: Dataframe possibly containing several rows per observation_dim value.
      valid_time: Reference time used to rank duplicates.

    Returns:
      Dataframe with one row per `_observation_dim` value. If
      `_pick_closest_duplicate_by` is None, the first occurrence in the
      existing row order is kept; otherwise a helper 'time_diff' column is
      added and the row with the smallest time difference is kept.
    """
    if self._pick_closest_duplicate_by is not None:
      # `assign` returns a new frame, leaving the caller's dataframe untouched.
      df = df.assign(
          time_diff=np.abs(df[self._pick_closest_duplicate_by] - valid_time)
      )
      # Stable sort so that ties keep their original relative order.
      df = df.sort_values('time_diff', ascending=True, kind='stable')
    non_duplicated = df[~df[self._observation_dim].duplicated(keep='first')]
    return non_duplicated

  def _load_data_for_single_time(
      self,
      valid_time: np.datetime64,
      lead_time_slice: Optional[slice] = None,
  ) -> pd.DataFrame:
    """Load data for some valid time.

    The time window is determined as follows:
      - With a tolerance: [valid_time + tolerance[0], valid_time + tolerance[1]].
      - Without tolerance but with `lead_time_slice`:
        [valid_time - lead_time_slice.start, valid_time + lead_time_slice.stop].
      - Otherwise: rows whose time equals valid_time exactly.
    The upper bound is only inclusive if `include_slice_end_time` was set.
    Files are opened for the window widened by +/- file_tolerance.

    Args:
      valid_time: Base time to load data for.
      lead_time_slice: (Optional) Lead time slice defining the window around
        valid_time when no tolerance is set.

    Returns:
      Dataframe with the requested variable and coordinate columns, with the
      time column renamed to 'valid_time'.

    Raises:
      AssertionError: If a file cannot be read with filters although it is not
        empty, or if duplicates are to be removed for a lead time slice.
    """
    if self._tolerance is not None:
      start_time = valid_time + self._tolerance[0]
      stop_time = valid_time + self._tolerance[1]
    elif lead_time_slice is None:
      start_time = valid_time
      stop_time = None
    else:
      # NOTE(review): the lower bound *subtracts* slice.start. For a slice
      # starting at zero this is irrelevant, but for a non-zero start this
      # looks like it should be an addition -- confirm the intended semantics.
      start_time = valid_time - lead_time_slice.start
      stop_time = valid_time + lead_time_slice.stop

    # Get subset of files since filtering can take a very long time.
    # Also create additional filters to exactly get required times.
    file_start_time = start_time - self._file_tolerance
    if stop_time is None:
      file_stop_time = start_time + self._file_tolerance
      filters = [(self._time_dim, '=', pd.Timestamp(start_time))]
    else:
      file_stop_time = stop_time + self._file_tolerance
      upper_op = '<=' if self._include_slice_end_time else '<'
      filters = [
          (self._time_dim, '>=', pd.Timestamp(start_time)),
          (self._time_dim, upper_op, pd.Timestamp(stop_time)),
      ]

    files = get_parquet_files_subset(
        self._path, file_start_time, file_stop_time, self._partitioned_by
    )

    def _read_single_file(fn):
      # Filters don't work for empty files. Catch this error but make sure the
      # file is empty.
      try:
        return pd.read_parquet(fn, filters=filters)
      except pyarrow.lib.ArrowTypeError as err:
        unfiltered = pd.read_parquet(fn)
        # Raised explicitly (rather than via `assert`) so that unfiltered data
        # is never silently returned when assertions are disabled.
        if len(unfiltered) != 0:  # pylint: disable=g-explicit-length-test
          raise AssertionError(
              'This should only happen if the file is empty.'
          ) from err
        return unfiltered

    df = pd.concat([_read_single_file(fn) for fn in files], ignore_index=True)

    if self._preprocessing_fn is not None:
      df = self._preprocessing_fn(df)

    if self._remove_duplicates:
      assert (
          lead_time_slice is None
      ), 'Removing duplicates not compatible with slice lead_time.'
      df = self._pick_closest_from_duplicates(df, valid_time)

    # Rename after preprocessing and de-duplication, which both operate on the
    # column names as stored in the files.
    if self._rename_variables is not None:
      df = df.rename(columns=self._rename_variables)
    df = df.rename(columns={self._time_dim: 'valid_time'})

    columns = list(self._variables) + list(self._coordinate_variables)
    return df.loc[:, columns]

  def _load_chunk_from_source(
      self,
      init_times: np.ndarray,
      lead_times: Optional[Union[np.ndarray, slice]] = None,
  ) -> Mapping[Hashable, xr.DataArray]:
    """Load observations for all requested init/lead time combinations.

    Args:
      init_times: Init times to load. If `lead_times` is None these are used
        directly as valid times.
      lead_times: (Optional) Either exact lead times, in which case data is
        loaded for every init_time + lead_time, or a slice, in which case one
        time window per init time is loaded and the lead time of each row is
        computed as valid_time - init_time.

    Returns:
      Data indexed by a flat 'index' dimension, either as a single Dataset or,
      if `split_variables` is set, as a dictionary with one DataArray per
      variable. 'init_time' and 'lead_time' are added as coordinates whenever
      `lead_times` is given.
    """
    dfs = []
    if isinstance(lead_times, slice):
      # Lead time slice: one window per init time.
      assert (
          self._tolerance is None
      ), 'Tolerance not compatible with lead_time slice.'
      for init_time in init_times:
        df = self._load_data_for_single_time(
            init_time, lead_time_slice=lead_times
        )
        df['init_time'] = init_time
        df['lead_time'] = df.valid_time - df.init_time
        dfs.append(df)
    elif lead_times is None:
      # No lead times, i.e. init_time = valid_time.
      for init_time in init_times:
        dfs.append(self._load_data_for_single_time(init_time))
    else:
      # Exact lead times: load each valid time separately.
      for init_time in init_times:
        for lead_time in lead_times:
          df = self._load_data_for_single_time(init_time + lead_time)
          df['init_time'] = init_time
          df['lead_time'] = lead_time
          dfs.append(df)

    # Combine dataframes under a fresh 0..N-1 index, which becomes the 'index'
    # dimension of the xarray objects.
    combined_df = pd.concat(dfs, ignore_index=True)
    time_coords = [] if lead_times is None else ['init_time', 'lead_time']
    ds = combined_df.to_xarray().set_coords(
        self._coordinate_variables + time_coords
    )

    if not self._split_variables:
      return ds.dropna('index') if self._dropna else ds

    per_variable = dict(ds)
    if self._dropna:
      # Drop missing values for each variable separately.
      per_variable = {
          name: da.dropna('index') for name, da in per_variable.items()
      }
    return per_variable
# METAR constants
METAR_TO_ERA5_NAMES = {
    'seaLevelPress': 'mean_sea_level_pressure',
    'temperature': '2m_temperature',
    'dewpoint': '2m_dewpoint_temperature',
    'windSpeed': '10m_wind_speed',
    'windGust': '10m_wind_gust',
    'windDir': '10m_wind_direction',
    'minTemp24Hour': 'min_2m_temperature_24hr',
    'maxTemp24Hour': 'max_2m_temperature_24hr',
    'precip1Hour': 'total_precipitation_1hr',
    'precip3Hour': 'total_precipitation_3hr',
    'precip6Hour': 'total_precipitation_6hr',
    'precip24Hour': 'total_precipitation_24hr',
    'precipRate': 'precipitation_rate',
}
ERA5_TO_METAR_NAMES = {v: k for k, v in METAR_TO_ERA5_NAMES.items()}
METAR_QC_SUFFIX = 'DD'
METAR_BAD_QUALITY_FLAGS = ('Z', 'B', 'X', 'Q', 'k')
METAR_COORDINATE_VARIABLES = (
    'latitude',
    'longitude',
    'elevation',
    'stationName',
)


# METAR preprocessing functions
def set_bad_quality_to_nan(
    df: pd.DataFrame,
    variables: Sequence[str],
    qc_suffix: str,
    bad_quality_flags: Sequence[str],
) -> pd.DataFrame:
  """Set values whose quality-control flag marks them as bad to NaN.

  For each variable, the flag is read from the column named
  `variable + qc_suffix`. The dataframe is modified in place.

  Args:
    df: Dataframe holding the variables and their quality-control columns.
    variables: Names of the columns to mask.
    qc_suffix: Suffix appended to a variable name to get its flag column.
    bad_quality_flags: Flag values for which the value is replaced by NaN.

  Returns:
    The same dataframe, with flagged values replaced by NaN.
  """
  for variable in variables:
    df[variable] = df[variable].where(
        ~np.isin(df[variable + qc_suffix], bad_quality_flags), np.nan
    )
  return df


def convert_longitude_to_0_to_360(
    df: pd.DataFrame, longitude_dim: str = 'longitude'
) -> pd.DataFrame:
  """Wrap the longitude column into the [0, 360) range.

  The dataframe is modified in place.

  Args:
    df: Dataframe containing the longitude column.
    longitude_dim: Name of the longitude column.

  Returns:
    The same dataframe, with the longitude column taken modulo 360.
  """
  df[longitude_dim] = np.mod(df[longitude_dim], 360)
  return df


def _metar_preprocessing_fn(
    df: pd.DataFrame,
    raw_variables: Sequence[str],
    preprocessing_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
) -> pd.DataFrame:
  """Apply the optional user preprocessing, then the METAR clean-up steps."""
  if preprocessing_fn is not None:
    df = preprocessing_fn(df)
  df = set_bad_quality_to_nan(
      df, raw_variables, METAR_QC_SUFFIX, METAR_BAD_QUALITY_FLAGS
  )
  df = convert_longitude_to_0_to_360(df)
  # Set elevation with fill values 9.999e+03 to NaN.
  df['elevation'] = df['elevation'].where(df['elevation'] < 9.999e03, np.nan)
  return df


class METARFromParquet(SparseObservationsFromParquet):
  """Reads METAR data stored in Parquet format.

  This implementation of SparseObservationsFromParquet sets all the default
  values for METAR and adds METAR-specific preprocessing functions.

  - Bad quality flags are set to NaN: ('Z', 'B', 'X', 'Q', 'k')
  - Longitude is converted to 0 to 360.
  - Elevation with fill values 9.999e+03 is set to NaN.

  Example:

    >>> init_times, lead_times
    (array(['2020-01-01T00:00:00.000000000', '2020-01-01T12:00:00.000000000'],
       dtype='datetime64[ns]'),
    array([ 6, 12], dtype='timedelta64[h]'))
    >>> target_data_loader = sparse_parquet.METARFromParquet(
    >>>     path=<PATH>,
    >>>     variables=['2m_temperature', '10m_wind_speed'],
    >>>     split_variables=False,
    >>>     partitioned_by='month',
    >>>     dropna=True,
    >>>     time_dim='timeNominal',
    >>> )
    >>> target_data_loader.load_chunk(init_times, lead_times)
    <xarray.Dataset>
    Dimensions:           (index: 31478)
    Coordinates:
      * index             (index) int64 0 1 2 3 4 ... 33021 33022 33023 33024 33025
        2m_temperatureDD  (index) object 'S' 'S' 'S' 'S' 'S' ... 'V' 'S' 'S' 'S' 'S'
        10m_wind_speedDD  (index) object 'S' 'S' 'S' 'S' 'S' ... 'V' 'S' 'S' 'S' 'S'
        latitude          (index) float32 -77.87 -53.8 -33.38 ... 46.55 49.82 49.83
        longitude         (index) float32 167.0 292.2 289.2 ... 299.0 285.0 295.7
        elevation         (index) float32 8.0 22.0 476.0 141.0 ... 13.0 381.0 53.0
        valid_time        (index) datetime64[ns] 2020-01-01T06:00:00 ... 2020-01-02
        stationName       (index) object 'NZCM' 'SAWE' 'SCEL' ... 'CWUK' 'CWBY'
        init_time         (index) datetime64[ns] 2020-01-01 ... 2020-01-01T12:00:00
        lead_time         (index) timedelta64[ns] 06:00:00 06:00:00 ... 12:00:00
    Data variables:
        2m_temperature    (index) float32 273.1 282.1 291.1 ... 274.1 268.1 272.9
        10m_wind_speed    (index) float32 4.1 5.1 2.1 2.1 1.0 ... 12.4 9.3 2.1 2.1
  """

  def __init__(
      self,
      path: str,
      variables: Sequence[str],
      time_dim: str,
      split_variables: bool = False,
      dropna: bool = False,
      tolerance: Optional[
          np.timedelta64 | tuple[np.timedelta64, np.timedelta64]
      ] = None,
      partitioned_by: str = 'month',
      rename_variables: Optional[Mapping[str, str]] = None,
      include_slice_end_time: bool = False,
      remove_duplicates: bool = False,
      pick_closest_duplicate_by: Optional[str] = None,
      file_tolerance: np.timedelta64 = np.timedelta64(1, 'h'),
      preprocessing_fn: Optional[Callable[[pd.DataFrame], pd.DataFrame]] = None,
      **kwargs,
  ):
    """Init.

    Arguments not listed here are forwarded unchanged to
    SparseObservationsFromParquet. The coordinate variables are fixed to
    METAR_COORDINATE_VARIABLES and the observation dimension to 'stationName'.

    Args:
      path: Path to Parquet dataset.
      variables: Variables to load, given by their renamed names (by default
        the ERA5-style names in METAR_TO_ERA5_NAMES).
      time_dim: Time column of the Parquet files used for time filtering.
      rename_variables: (Optional) Additional raw-name to new-name renames.
        They are applied on top of METAR_TO_ERA5_NAMES and take precedence
        over it for raw names present in both.
      preprocessing_fn: (Optional) Function applied to the dataframe before
        the METAR-specific preprocessing.
      **kwargs: Additional keyword arguments passed to the base DataLoader.

    Raises:
      KeyError: If a requested variable is not the target of any rename, so
        its raw column and quality-control column cannot be determined.
    """
    # Start from the METAR defaults and layer caller-provided renames on top.
    renames = dict(METAR_TO_ERA5_NAMES)
    if rename_variables is not None:
      renames.update(rename_variables)
    # Quality control happens before renaming, so map the requested (renamed)
    # variables back to the raw column names found in the files.
    renamed_to_raw = {new: raw for raw, new in renames.items()}
    raw_variables = [renamed_to_raw[v] for v in variables]

    super().__init__(
        path=path,
        variables=variables,
        time_dim=time_dim,
        coordinate_variables=METAR_COORDINATE_VARIABLES,
        observation_dim='stationName',
        split_variables=split_variables,
        dropna=dropna,
        tolerance=tolerance,
        partitioned_by=partitioned_by,
        rename_variables=renames,
        include_slice_end_time=include_slice_end_time,
        remove_duplicates=remove_duplicates,
        pick_closest_duplicate_by=pick_closest_duplicate_by,
        file_tolerance=file_tolerance,
        preprocessing_fn=functools.partial(
            _metar_preprocessing_fn,
            raw_variables=raw_variables,
            preprocessing_fn=preprocessing_fn,
        ),
        **kwargs,
    )