# Copyright 2019, by the California Institute of Technology.
# ALL RIGHTS RESERVED. United States Government Sponsorship acknowledged.
# Any commercial use must be negotiated with the Office of Technology
# Transfer at the California Institute of Technology.
#
# This software may be subject to U.S. export control laws. By accepting
# this software, the user agrees to comply with all applicable U.S. export
# laws and regulations. User has the responsibility to obtain export
# licenses, or other export authority as may be required before exporting
# such information to foreign countries or providing access to foreign
# persons.
"""
=========
subset.py
=========
Functions related to subsetting a NetCDF file.
"""
import os
from itertools import zip_longest
from typing import List, Union
import geopandas as gpd
import netCDF4 as nc
import numpy as np
import xarray as xr
from shapely.geometry import Point
from podaac.subsetter import (
datatree_subset,
tree_time_converting
)
from podaac.subsetter.utils import mask_utils
from podaac.subsetter.utils import coordinate_utils
from podaac.subsetter.utils import metadata_utils
from podaac.subsetter.utils import spatial_utils
from podaac.subsetter.utils import time_utils
from podaac.subsetter.utils import file_utils
from podaac.subsetter.utils import variables_utils
SERVICE_NAME = 'l2ss-py'
def subset_with_shapefile_multi(dataset: xr.Dataset,
lat_var_names: List[str],
lon_var_names: List[str],
shapefile: str,
cut: bool,
pixel_subset: bool) -> xr.Dataset:
"""
Subset an xarray Dataset using a shapefile for multiple latitude and longitude variable pairs.
Parameters
----------
dataset : xr.Dataset
Dataset to subset
lat_var_names : list
Names of the latitude variables in the given dataset
lon_var_names : list
Names of the longitude variables in the given dataset
shapefile : str
Path to the local shapefile used to subset the given dataset
cut : bool
True if the scanline should be cut
pixel_subset : bool
Subset by whole rows and columns that intersect the shapefile mask; this can
retain lon/lat values that fall outside the shapefile
Returns
-------
xr.Dataset
The subsetted dataset
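Examples
--------
A minimal usage sketch; the file path, shapefile, and variable names are
illustrative placeholders, not real data:
>>> import xarray as xr
>>> dtree = xr.open_datatree("granule.nc", decode_coords=False,
...                          mask_and_scale=False, decode_times=False)
>>> subsetted = subset_with_shapefile_multi(
...     dtree, lat_var_names=["/lat"], lon_var_names=["/lon"],
...     shapefile="/path/to/region.shp", cut=True, pixel_subset=False)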
"""
if len(lat_var_names) != len(lon_var_names):
raise ValueError("Number of latitude variables must match number of longitude variables")
shapefile_df = gpd.read_file(shapefile).to_crs("EPSG:4326")
masks = {}
for lat_var_name, lon_var_name in zip(lat_var_names, lon_var_names):
lat = dataset[lat_var_name]
lon = dataset[lon_var_name]
lat_scale = lat.attrs.get("scale_factor", 1.0)
lon_scale = lon.attrs.get("scale_factor", 1.0)
lat_offset = lat.attrs.get("add_offset", 0.0)
lon_offset = lon.attrs.get("add_offset", 0.0)
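# The datatree is typically opened with mask_and_scale=False (see subset()),
# so packed coordinate values are unpacked manually here.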
# Apply scale and offset
lat_vals = lat.values * lat_scale + lat_offset
lon_vals = lon.values * lon_scale + lon_offset
# Handle 2D or 1D lat/lon
if lat_vals.ndim == 1 and lon_vals.ndim == 1:
lon2d, lat2d = np.meshgrid(lon_vals, lat_vals)
else:
lat2d = lat_vals
lon2d = lon_vals
# Convert shapefile to 0-360 if needed
current_shapefile_df = shapefile_df.copy()
if coordinate_utils.is_360(lon, lon_scale, lon_offset):
current_shapefile_df["geometry"] = current_shapefile_df["geometry"].apply(spatial_utils.translate_longitude)
# Flatten points and convert to GeoDataFrame
flat_points = np.column_stack((lon2d.ravel(), lat2d.ravel()))
point_gdf = gpd.GeoDataFrame(
geometry=[Point(xy) for xy in flat_points],
crs="EPSG:4326"
)
# Spatial join to find points inside shapefile
joined = gpd.sjoin(point_gdf, current_shapefile_df, how="left", predicate="intersects")
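# With how="left", index_right is NaN for points that intersect no polygon,
# so non-null join results mark the points inside the shapefile.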
# A point can intersect more than one polygon; keep a single row per point
# so the flattened mask reshapes back to the original grid shape.
joined = joined[~joined.index.duplicated(keep="first")]
inside_mask_flat = ~joined.index_right.isna().to_numpy()
inside_mask = inside_mask_flat.reshape(lat2d.shape)
# Create DataArray aligned with original dims
mask_da = xr.DataArray(inside_mask, dims=lat.dims, coords=lat.coords)
lat_path = file_utils.get_path(lat_var_name)
masks[lat_path] = mask_da
# Apply the datatree-aware masking logic to produce the subsetted dataset
return_dataset = datatree_subset.where_tree(dataset, masks, cut, pixel_subset)
return return_dataset
def subset_with_bbox(dataset: xr.Dataset, # pylint: disable=too-many-branches
lat_var_names: list,
lon_var_names: list,
time_var_names: list,
bbox: np.ndarray = None,
cut: bool = True,
min_time: str = None,
max_time: str = None,
pixel_subset: bool = False) -> xr.Dataset:
"""
Subset an xarray Dataset using a spatial bounding box.
Parameters
----------
dataset : xr.Dataset
Dataset to subset
lat_var_names : list
Names of the latitude variables in the given dataset
lon_var_names : list
Names of the longitude variables in the given dataset
time_var_names : list
Names of the time variables in the given dataset
bbox : np.ndarray
Spatial bounding box to subset Dataset with.
cut : bool
True if the scanline should be cut.
min_time : str
ISO timestamp of min temporal bound
max_time : str
ISO timestamp of max temporal bound
pixel_subset : bool
Subset by whole rows and columns that intersect the bounding box; this can
retain lon/lat values that fall outside the bounding box
Returns
-------
xr.Dataset
The subsetted dataset
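Examples
--------
A minimal usage sketch; the bounding box, timestamps, and variable names are
illustrative placeholders (``dtree`` is a datatree opened as in ``subset``):
>>> import numpy as np
>>> bbox = np.array(((-170.0, -120.0), (20.0, 60.0)))  # ((west, east), (south, north))
>>> subsetted = subset_with_bbox(
...     dataset=dtree, lat_var_names=["/lat"], lon_var_names=["/lon"],
...     time_var_names=["/time"], bbox=bbox, cut=True,
...     min_time="2019-01-01T00:00:00Z", max_time="2019-01-02T00:00:00Z")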
"""
lon_bounds, lat_bounds = coordinate_utils.convert_bbox(bbox, dataset, lat_var_names[0], lon_var_names[0])
# When the bounding box crosses the antimeridian (lon_min > lon_max), points match on either side of it, so use logical OR instead of AND
oper = np.logical_and
if lon_bounds[0] > lon_bounds[1]:
oper = np.logical_or
subset_dictionary = {}
if not time_var_names: # time_var_names == [] or evaluates to False
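# zip_longest pads the missing time variable with None for each lat/lon pair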
iterator = zip_longest(lat_var_names, lon_var_names, [])
else:
iterator = zip(lat_var_names, lon_var_names, time_var_names)
for lat_var_name, lon_var_name, time_var_name in iterator:
lat_path = file_utils.get_path(lat_var_name)
lon_path = file_utils.get_path(lon_var_name)
lon_data = dataset[lon_var_name]
lat_data = dataset[lat_var_name]
temporal_cond = time_utils.build_temporal_cond(min_time, max_time, dataset, time_var_name)
time_path = None
if time_var_name:
time_path = file_utils.get_path(time_var_name)
time_data = dataset[time_var_name]
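# A 1-D time variable cannot be combined directly with 2-D lon/lat conditions,
# so align the temporal condition to the longitude dimensions first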
if time_data.ndim == 1 and lon_data.ndim == 2 and temporal_cond is not True:
temporal_cond = mask_utils.align_time_to_lon_dim(time_data, lon_data, temporal_cond)
operation = (
oper((lon_data >= lon_bounds[0]), (lon_data <= lon_bounds[1])) &
(lat_data >= lat_bounds[0]) &
(lat_data <= lat_bounds[1]) &
temporal_cond
)
# Apply the condition only when the lon, lat, and time variables share the same group path
# timeMidScan_datetime is a time variable created for GES DISC collections in a ScanTime group
if (
lat_path == lon_path == time_path
or (time_var_name is not None and 'timeMidScan_datetime' in time_var_name)
or (lon_path == lat_path and time_var_name is None)
):
subset_dictionary[lat_path] = operation
elif lat_path == lon_path and len(time_var_names) == 1:
subset_dictionary[lat_path] = operation
return_dataset = datatree_subset.where_tree(dataset, subset_dictionary, cut, pixel_subset)
return return_dataset
def subset(file_to_subset: str, bbox: np.ndarray, output_file: str,
variables: Union[List[str], str, None] = (),
# pylint: disable=too-many-branches, disable=too-many-statements
cut: bool = True, shapefile: str = None, min_time: str = None, max_time: str = None,
origin_source: str = None,
lat_var_names: List[str] = (), lon_var_names: List[str] = (), time_var_names: List[str] = (),
pixel_subset: bool = False, stage_file_name_subsetted_true: str = None,
stage_file_name_subsetted_false: str = None
) -> Union[np.ndarray, None]:
"""
Subset a given NetCDF file using a bounding box or a shapefile, with optional temporal bounds.
Parameters
----------
file_to_subset : string
The location of the file which will be subset
bbox : np.ndarray
The chosen bounding box. This is a tuple of tuples formatted
as such: ((west, east), (south, north)). The assumption is that
the valid range is ((-180, 180), (-90, 90)). This will be
transformed as appropriate if the actual longitude range is
0-360.
output_file : string
The file path for the output of the subsetting operation.
variables : list, str, optional
List of variables to include in the resulting data file.
NOTE: This will remove ALL variables which are not included
in this list, including coordinate variables!
cut : boolean
True if the scanline should be cut, False if the scanline should
not be cut. Defaults to True.
shapefile : str
Name of local shapefile used to subset given file.
min_time : str
ISO timestamp representing the lower bound of the temporal
subset to be performed. If this value is not provided, the
granule will not be subset temporally on the lower bound.
max_time : str
ISO timestamp representing the upper bound of the temporal
subset to be performed. If this value is not provided, the
granule will not be subset temporally on the upper bound.
origin_source : str
Original location or filename of data to be used in "derived from"
history element.
lat_var_names : list
List of variables that represent the latitude coordinate
variables for this granule. This list will only contain more
than one value in the case where there are multiple groups and
different coordinate variables for each group.
lon_var_names : list
List of variables that represent the longitude coordinate
variables for this granule. This list will only contain more
than one value in the case where there are multiple groups and
different coordinate variables for each group.
time_var_names : list
List of variables that represent the time coordinate
variables for this granule. This list will only contain more
than one value in the case where there are multiple groups and
different coordinate variables for each group.
pixel_subset : bool
Subset by whole rows and columns that intersect the bounding box; this can
retain lon/lat values that fall outside the bounding box
stage_file_name_subsetted_true : str
Stage file name to use when the granule was subsetted; which name is used
depends on the result of the subset operation
stage_file_name_subsetted_false : str
Stage file name to use when the granule was not subsetted; which name is
used depends on the result of the subset operation
Returns
-------
np.ndarray or None
Spatial bounds of the dataset after the subset operation
Notes
-----
The time variable in SNDR granules is cleaned up before ``decode_times`` is
applied, because SNDR.AQUA files can have a blank ascending node time::
if any('__asc_node_tai93' in i for i in list(nc_dataset.variables)):
asc_time_var = nc_dataset.variables['__asc_node_tai93']
if not asc_time_var[:] > 0:
del nc_dataset.variables['__asc_node_tai93']
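Examples
--------
A minimal usage sketch; the file names, bounding box, and timestamps are
illustrative placeholders:
>>> import numpy as np
>>> bounds = subset(
...     file_to_subset="input_granule.nc",
...     bbox=np.array(((-180, 180), (-90, 90))),
...     output_file="output_granule_subsetted.nc",
...     min_time="2019-01-01T00:00:00Z",
...     max_time="2019-01-02T00:00:00Z")
>>> # bounds holds the spatial bounds of the subsetted data, or None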
"""
file_extension = os.path.splitext(file_to_subset)[1]
file_utils.override_decode_cf_datetime()
hdf_type = False
args = {
'decode_coords': False,
'mask_and_scale': False,
'decode_times': False
}
with xr.open_datatree(file_to_subset, **args) as dataset:
if '.HDF5' == file_extension:
for group in dataset.groups:
if "ScanTime" in group:
hdf_type = 'GPM'
if min_time or max_time:
fill_value_f8 = nc.default_fillvals.get('f8')
float_dtypes = ['float64', 'float32']
args['decode_times'] = True
# try to open file to see if we can access the time variable
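# Some collections store time as a float with the default f8 fill value, or as
# an SST reference time; those values only decode correctly when mask_and_scale
# is enabled, so probe the file before choosing the open arguments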
try:
with nc.Dataset(file_to_subset, 'r') as nc_dataset:
for time_variable in (v for v in nc_dataset.variables.keys() if 'time' in v):
time_var = nc_dataset[time_variable]
if (getattr(time_var, '_FillValue', None) == fill_value_f8 and time_var.dtype in float_dtypes) or \
(getattr(time_var, 'long_name', None) == "reference time of sst file"):
args['mask_and_scale'] = True
if getattr(time_var, 'long_name', None) == "reference time of sst file":
args['mask_and_scale'] = file_utils.test_access_sst_dtime_values(nc_dataset)
break
except Exception: # pylint: disable=broad-exception-caught
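# Probing is best-effort; fall back to the default open arguments on any failure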
pass
if hdf_type == 'GPM':
args['decode_times'] = False
time_encoding = {}
time_calendar_attributes = {}
if args['decode_times']:
# Get time encoding
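# Capture units, calendar, and dtype for each time variable before decoding so
# the original encoding can be restored when the output file is written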
with xr.open_datatree(file_to_subset, decode_times=False) as dataset:
lat_var_names, lon_var_names, time_var_names = coordinate_utils.get_coordinate_variable_names(
dataset=dataset,
lat_var_names=lat_var_names,
lon_var_names=lon_var_names,
time_var_names=time_var_names
)
for time in time_var_names:
time_var = dataset[time]
var_name = os.path.basename(time)
group_path = os.path.dirname(time)
units = time_var.attrs.get('units')
dtype = time_var.dtype
calendar = time_var.attrs.get('calendar')
if group_path not in time_encoding:
time_encoding[group_path] = {}
time_encoding[group_path][var_name] = {}
if calendar:
time_encoding[group_path][var_name]['calendar'] = calendar
if units:
time_encoding[group_path][var_name]['units'] = units
time_encoding[group_path][var_name]['dtype'] = dtype
if calendar:
time_calendar_attributes[time] = calendar
with xr.open_datatree(file_to_subset, **args) as dataset:
hdf_type = file_utils.get_hdf_type(dataset)
lat_var_names, lon_var_names, time_var_names = coordinate_utils.get_coordinate_variable_names(
dataset=dataset,
lat_var_names=lat_var_names,
lon_var_names=lon_var_names,
time_var_names=time_var_names
)
if '.HDF5' == file_extension:
new_time_var_names = []
for group in dataset.groups:
if "ScanTime" in group:
group_dataset = dataset[group].ds
dataset[group].ds = datatree_subset.update_dataset_with_time(group_dataset, group_path=group)
if 'timeMidScan_datetime' in dataset[group].ds:
new_time_var_names.append(group + '/timeMidScan_datetime')
if new_time_var_names:
time_var_names = new_time_var_names
if not time_var_names and (min_time or max_time):
raise ValueError('Could not determine time variable')
if hdf_type and (min_time or max_time):
dataset, _ = tree_time_converting.convert_to_datetime(dataset, time_var_names, hdf_type)
chunks = file_utils.calculate_chunks(dataset)
all_vars = variables_utils.get_all_variable_names_from_dtree(dataset)
if chunks:
dataset = dataset.chunk(chunks)
if variables:
# Drop variables that aren't explicitly requested, except the lat, lon, and
# time variables, which are needed for subsetting
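# Requested names may use '__' as a group separator; normalize them to
# upper-case '/GROUP/VAR' paths before comparison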
normalized_variables = [f"/{s.replace('__', '/').lstrip('/')}".upper() for s in variables]
keep_variables = normalized_variables + lon_var_names + lat_var_names + time_var_names
keep_variables = variables_utils.normalize_candidate_paths_against_dtree(keep_variables, all_vars)
all_data_variables = datatree_subset.get_vars_with_paths(dataset)
drop_variables = [
var for var in all_data_variables
if var not in keep_variables and var.upper() not in keep_variables
]
dataset = datatree_subset.drop_vars_by_path(dataset, drop_variables)
lon_var_names = variables_utils.normalize_candidate_paths_against_dtree(lon_var_names, all_vars)
lat_var_names = variables_utils.normalize_candidate_paths_against_dtree(lat_var_names, all_vars)
time_var_names = variables_utils.normalize_candidate_paths_against_dtree(time_var_names, all_vars)
if shapefile:
subsetted_dataset = subset_with_shapefile_multi(
dataset,
lat_var_names,
lon_var_names,
shapefile,
cut,
pixel_subset
)
elif bbox is not None:
subsetted_dataset = subset_with_bbox(
dataset=dataset,
lat_var_names=lat_var_names,
lon_var_names=lon_var_names,
time_var_names=time_var_names,
bbox=bbox,
cut=cut,
min_time=min_time,
max_time=max_time,
pixel_subset=pixel_subset
)
else:
raise ValueError('Either bbox or shapefile must be provided')
metadata_utils.set_version_history(subsetted_dataset, cut, bbox, shapefile)
metadata_utils.set_json_history(subsetted_dataset, cut, file_to_subset, bbox, shapefile, origin_source)
if time_calendar_attributes:
for time_var, calendar in time_calendar_attributes.items():
if 'calendar' in subsetted_dataset[time_var].attrs:
subsetted_dataset[time_var].attrs['calendar'] = calendar
# If the calendar attribute is set on the variable, drop calendar from the encoding so it is not written twice
var_name = os.path.basename(time_var)
group_path = os.path.dirname(time_var)
# Safely remove calendar from encoding if it exists
if group_path in time_encoding and var_name in time_encoding[group_path]:
time_encoding[group_path][var_name].pop('calendar', None)
subsetted_dataset = datatree_subset.clean_inherited_coords(subsetted_dataset)
encoding = datatree_subset.prepare_basic_encoding(subsetted_dataset, time_encoding)
spatial_bounds_array = datatree_subset.tree_get_spatial_bounds(
subsetted_dataset,
lat_var_names,
lon_var_names
)
metadata_utils.update_netcdf_attrs(output_file,
subsetted_dataset,
lon_var_names,
lat_var_names,
spatial_bounds_array,
stage_file_name_subsetted_true,
stage_file_name_subsetted_false)
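# Some granules contain attribute names that netCDF considers illegal;
# sanitize those attributes and retry the write once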
try:
subsetted_dataset.to_netcdf(output_file, encoding=encoding)
except AttributeError as e:
if "NetCDF: Name contains illegal characters" in str(e):
metadata_utils.fix_illegal_datatree_attrs(subsetted_dataset)
subsetted_dataset.to_netcdf(output_file, encoding=encoding)
else:
raise
metadata_utils.ensure_time_units(output_file, time_encoding)
# ensure all the dimensions are on the root node when we pixel subset
if pixel_subset:
def add_all_group_dims_to_root_inplace(nc_path):
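# Walk every group, collect each dimension's name and size (None for
# unlimited), and declare any dimension that is missing from the root group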
def collect_dims(group, dims):
for dimname, dim in group.dimensions.items():
if dimname not in dims:
dims[dimname] = len(dim) if not dim.isunlimited() else None
for subgrp in group.groups.values():
collect_dims(subgrp, dims)
with nc.Dataset(nc_path, 'r+') as ds:
all_dims = {}
collect_dims(ds, all_dims)
for dimname, size in all_dims.items():
if dimname not in ds.dimensions:
ds.createDimension(dimname, size)
add_all_group_dims_to_root_inplace(output_file)
return spatial_bounds_array