Integrating Dask, Kerchunk, Zarr and Xarray

imported on: 2024-10-30

This notebook is from a different repository in NASA’s PO.DAAC, the-coding-club.

The original source for this document is https://github.com/podaac/the-coding-club/blob/main/notebooks/SWOT_SSH_dashboard.ipynb

SWOT KaRIn Sea Surface Height (SSH) Cloud Data Loading, Subsetting, and Visualization

PO.DAAC, Jet Propulsion Laboratory, California Institute of Technology
Author: Ayush Nag

Background

  • This notebook allows you to visualize large amounts of SWOT data in an interactive dashboard
  • Connects to the PO.DAAC S3 SWOT collection
  • Reads the data as a Zarr cloud-optimized data store
  • References terabytes of data lazily, so the dataset opens in milliseconds and only the subsets you plot are downloaded
  • To run this notebook, the SWOT_L2_LR_SSH_Expert_2.0.json file is needed (see Instructions below)

Instructions

  1. Load this notebook into a JupyterHub cloud environment
  2. Create a .netrc file containing your NASA Earthdata account credentials
cd ~
touch .netrc
echo "machine urs.earthdata.nasa.gov login [USERNAME] password [PASSWORD]" > .netrc
chmod 0600 .netrc
  3. Get the SWOT swaths shapefile for fast spatial and temporal subsetting
wget https://www.aviso.altimetry.fr/fileadmin/documents/missions/Swot/sph_science_swath.zip
unzip sph_science_swath.zip
  4. Get the combined metadata file SWOT_L2_LR_SSH_Expert_2.0.json
  • NASA VEDA JupyterHub users: the notebook already has code to read the JSON from VEDA S3
import s3fs
import glob
import dask
import ujson
import cartopy
import requests
import earthaccess
import hvplot.dask
import hvplot.pandas
import hvplot.xarray

import numpy as np
import panel as pn
import pandas as pd
import xarray as xr
import holoviews as hv
import geopandas as gpd
import cartopy.crs as ccrs
import panel.widgets as pnw
import matplotlib.pyplot as plt

from tqdm.notebook import tqdm
from dask.distributed import Client
from shapely.geometry import Polygon

hv.extension('bokeh')
hv.renderer('bokeh').webgl = True
pn.param.ParamMethod.loading_indicator = True

Start Dask cluster

  • Data can be read from s3 in parallel and used by hvplot
client = Client()
client

Client: Client-a98b3e17-0bf6-11ef-a03f-026f296dd6e0
Connection method: Cluster object    Cluster type: distributed.LocalCluster
Dashboard: /user/ayushnag/proxy/8787/status

Get SWOT PO.DAAC s3 credentials (authenticate via NASA Earthdata login)

%%time
url = "https://archive.swot.podaac.earthdata.nasa.gov/s3credentials"
creds = requests.get(url).json()
CPU times: user 285 ms, sys: 59.5 ms, total: 344 ms
Wall time: 2.29 s
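These credentials are temporary (they typically expire after about an hour), so long-running sessions need to re-request them. A small helper sketch for deciding when to refresh, assuming the payload carries an ISO-8601 `expiration` field as the PO.DAAC S3 credentials endpoint generally returns:

```python
from datetime import datetime, timezone

def creds_expired(creds: dict, skew_seconds: int = 60) -> bool:
    """Return True if the temporary S3 credentials are (nearly) expired.

    Assumes the response carries an ISO-8601 'expiration' field, e.g.
    '2024-05-06 20:30:00+00:00' -- check your actual payload.
    """
    expiry = datetime.fromisoformat(creds["expiration"])
    if expiry.tzinfo is None:
        # treat naive timestamps as UTC
        expiry = expiry.replace(tzinfo=timezone.utc)
    remaining = (expiry - datetime.now(timezone.utc)).total_seconds()
    return remaining < skew_seconds

# hypothetical payload for illustration only
demo = {"expiration": "2099-01-01T00:00:00+00:00"}
print(creds_expired(demo))
```

Calling `creds_expired(creds)` before each batch of reads and re-fetching from the endpoint when it returns True avoids mid-computation authentication failures.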

Open collection using Kerchunk/Zarr

  • This tutorial uses the PO.DAAC SWOT_L2_LR_SSH_EXPERT_2.0 collection
  • Dimensions
    • cycle_number: one cycle represents one period of near-global SWOT coverage (21 days of data)
    • pass_number: one swath represents one pass; each granule/file in PO.DAAC is exactly one SWOT pass. The same pass number should match up spatially across cycles
    • num_lines and num_pixels represent the along-track “length” and cross-track “width” of a satellite swath
    • num_sides represents the two swaths that the SWOT KaRIn instrument records on either side of the nadir track
  • Here, we use a JSON reference file covering the entire dataset, made with Kerchunk, for efficiently processing large amounts of data in the cloud
  • We read the chunked JSON as a Zarr cloud-optimized data store
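The combined JSON is a Kerchunk "reference filesystem": it stores the Zarr metadata inline and maps each chunk key to a byte range inside one of the original NetCDF granules on S3, so no data is copied or reformatted. A minimal, hypothetical sketch of the version-1 layout (the granule path, offset, and array shown here are invented for illustration):

```python
import json

# A minimal, hypothetical Kerchunk reference file (version 1 layout).
# "refs" maps Zarr keys either to inline metadata (.zgroup/.zarray)
# or to [url, byte_offset, byte_length] triples pointing into the
# original NetCDF granules on S3.
refs = {
    "version": 1,
    "refs": {
        ".zgroup": json.dumps({"zarr_format": 2}),
        "ssh_karin/.zarray": json.dumps({
            "shape": [9866, 69], "chunks": [9866, 69],
            "dtype": "<f8", "compressor": None,
            "fill_value": None, "filters": None,
            "order": "C", "zarr_format": 2,
        }),
        # chunk 0.0 lives at a byte range inside one source granule
        # (9866 * 69 float64 values = 5,446,032 bytes)
        "ssh_karin/0.0": ["s3://example-bucket/example_granule.nc", 20480, 5446032],
    },
}
print(len(refs["refs"]))
```

When xarray opens such a file with `engine="zarr"` and the `reference://` protocol, fsspec resolves each chunk key to an S3 range request against the original granule.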

Read JSON from VEDA and save locally

veda = s3fs.S3FileSystem()
with veda.open("s3://veda-data-store-staging/SWOT_L2_LR_SSH_2.0/SWOT_L2_LR_SSH_Expert_2.0.json", 'r') as infile:
    with open("SWOT_L2_LR_SSH_Expert_2.0.json", 'w') as outfile:
        ujson.dump(ujson.load(infile), outfile)
%%time
data = xr.open_dataset(
    "reference://", engine="zarr", chunks={},
    backend_kwargs={
        "storage_options": 
        {"fo": "SWOT_L2_LR_SSH_Expert_2.0.json",
         "remote_protocol": "s3",
         "remote_options": {"anon": False, "key": creds['accessKeyId'], "secret": creds['secretAccessKey'], "token": creds["sessionToken"]}},
         "consolidated": False
    }
)
data
CPU times: user 148 ms, sys: 17.1 ms, total: 165 ms
Wall time: 231 ms
<xarray.Dataset> Size: 2TB
Dimensions:                                (cycle_num: 8, pass_num: 584,
                                            num_lines: 9866, num_pixels: 69,
                                            num_sides: 2)
Coordinates:
  * cycle_num                              (cycle_num) int64 64B 7 8 9 ... 13 14
    latitude                               (cycle_num, pass_num, num_lines, num_pixels) float64 25GB dask.array<chunksize=(1, 1, 9866, 69), meta=np.ndarray>
    latitude_nadir                         (cycle_num, pass_num, num_lines) float64 369MB dask.array<chunksize=(1, 1, 9866), meta=np.ndarray>
    longitude                              (cycle_num, pass_num, num_lines, num_pixels) float64 25GB dask.array<chunksize=(1, 1, 9866, 69), meta=np.ndarray>
    longitude_nadir                        (cycle_num, pass_num, num_lines) float64 369MB dask.array<chunksize=(1, 1, 9866), meta=np.ndarray>
  * pass_num                               (pass_num) int64 5kB 1 2 ... 583 584
Dimensions without coordinates: num_lines, num_pixels, num_sides
Data variables: (12/98)
    ancillary_surface_classification_flag  (cycle_num, pass_num, num_lines, num_pixels) float32 13GB dask.array<chunksize=(1, 1, 9866, 69), meta=np.ndarray>
    cross_track_angle                      (cycle_num, pass_num, num_lines) float64 369MB dask.array<chunksize=(1, 1, 9866), meta=np.ndarray>
    cross_track_distance                   (cycle_num, pass_num, num_lines, num_pixels) float32 13GB dask.array<chunksize=(1, 1, 9866, 69), meta=np.ndarray>
    dac                                    (cycle_num, pass_num, num_lines, num_pixels) float64 25GB dask.array<chunksize=(1, 1, 9866, 69), meta=np.ndarray>
    depth_or_elevation                     (cycle_num, pass_num, num_lines, num_pixels) float64 25GB dask.array<chunksize=(1, 1, 9866, 69), meta=np.ndarray>
    distance_to_coast                      (cycle_num, pass_num, num_lines, num_pixels) float64 25GB dask.array<chunksize=(1, 1, 9866, 69), meta=np.ndarray>
    ...                                     ...
    wind_speed_model_u                     (cycle_num, pass_num, num_lines, num_pixels) float64 25GB dask.array<chunksize=(1, 1, 9866, 69), meta=np.ndarray>
    wind_speed_model_v                     (cycle_num, pass_num, num_lines, num_pixels) float64 25GB dask.array<chunksize=(1, 1, 9866, 69), meta=np.ndarray>
    wind_speed_rad                         (cycle_num, pass_num, num_lines, num_sides) float64 738MB dask.array<chunksize=(1, 1, 9866, 2), meta=np.ndarray>
    wind_speed_ssb_cor_source              (cycle_num, pass_num, num_lines, num_pixels) float32 13GB dask.array<chunksize=(1, 1, 9866, 69), meta=np.ndarray>
    wind_speed_ssb_cor_source_2            (cycle_num, pass_num, num_lines, num_pixels) float32 13GB dask.array<chunksize=(1, 1, 9866, 69), meta=np.ndarray>
    x_factor                               (cycle_num, pass_num, num_lines, num_pixels) float32 13GB dask.array<chunksize=(1, 1, 9866, 69), meta=np.ndarray>
Attributes: (12/62)
    Conventions:                                   CF-1.7
    contact:                                       podaac@jpl.nasa.gov
    crid:                                          PIC0
    cycle_number:                                  7
    ellipsoid_flattening:                          0.0033528106647474805
    ellipsoid_semi_major_axis:                     6378137.0
    ...                                            ...
    xref_pole_location_file:                       SMM_PO1_AXXCNE20231125_020...
    xref_precipitation_files:                      SMM_CRR_AXFCNE20231123_065...
    xref_reforbittrack_files:                      SWOT_RefOrbitTrack125mPass...
    xref_sea_ice_mask_files:                       SMM_ICS_AXFCNE20231124_052...
    xref_statickarincal_files:                     SWOT_StaticKaRInCalAdjusta...
    xref_wave_model_files:                         SMM_WMA_AXPCNE20231124_072...
print(f'{data.nbytes / 1e12} TB')
1.867173812352 TB
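The ~2 TB figure is the sum of all lazily referenced variables; nothing is downloaded until a computation asks for it. A quick back-of-envelope check against the sizes shown in the Dataset repr above:

```python
# Back-of-envelope check of the sizes shown in the Dataset repr above.
cycles, passes, lines, pixels = 8, 584, 9866, 69

chunk_bytes = lines * pixels * 8           # one (1, 1, 9866, 69) float64 chunk
var_bytes = cycles * passes * chunk_bytes  # one full float64 swath variable

print(f"chunk: {chunk_bytes / 1e6:.1f} MB")   # ~5.4 MB per chunk
print(f"variable: {var_bytes / 1e9:.1f} GB")  # ~25 GB, matching the repr
```

At ~5.4 MB per chunk, selecting a single cycle and pass pulls only a few chunks over the network, which is why interactive plotting stays responsive.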

Precompute pass timings for fast temporal subsetting (optional)

  • Note: the calculated start_time coordinate is an approximation based on the known pass and cycle timings
  • For second-level precision, use the data.time variable
swaths_gdf = gpd.read_file("swot_science_orbit_sept2015-v2_swath.shp", crs="4326")
swath_times = pd.to_datetime(swaths_gdf.START_TIME, format='Day %d %H:%M:%S').drop_duplicates(keep='first')
# Set swath_times to match dataset timing
swath_times = swath_times + (data.time.isel(cycle_num=0, pass_num=0, num_lines=0).values - swath_times.iloc[0])
# cycle_delta = duration of one full SWOT cycle = (span of the pass start times) + (one pass interval)
cycle_delta = (swath_times.iloc[-1] - swath_times.iloc[0]) + (swath_times.iloc[1] - swath_times.iloc[0])
# Extend swath_times to cover all len(data.cycle_num) cycles
start_times = np.tile(swath_times, len(data.cycle_num)) + np.repeat((np.arange(len(data.cycle_num)) * cycle_delta), len(data.pass_num))
start_times = start_times.reshape(len(data.cycle_num), len(data.pass_num))
data = data.assign_coords(start_time=(("cycle_num", "pass_num"), start_times))
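The np.tile/np.repeat pattern above replicates the per-pass start times once per cycle and shifts each copy by a whole cycle_delta. A toy version with invented numbers (3 passes, 2 cycles, each cycle exactly one day) shows the mechanics:

```python
import numpy as np

# Toy version of the tile/repeat pattern above: 3 passes per cycle,
# 2 cycles, each cycle exactly one day long.
pass_starts = np.array(
    ["2024-01-01T00:00", "2024-01-01T08:00", "2024-01-01T16:00"],
    dtype="datetime64[m]",
)
cycle_delta = np.timedelta64(1, "D")
n_cycles = 2

# tile repeats the whole pass sequence per cycle; repeat shifts each
# cycle's copy by a whole number of cycle_deltas
start_times = np.tile(pass_starts, n_cycles) + np.repeat(
    np.arange(n_cycles) * cycle_delta, len(pass_starts)
)
start_times = start_times.reshape(n_cycles, len(pass_starts))
print(start_times[1, 0])  # first pass of the second cycle, one cycle_delta later
```

The reshape to (n_cycles, n_passes) is what lets the result be attached as a 2-D ("cycle_num", "pass_num") coordinate.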

Visualize with interactive dashboard

def dashboard(data: xr.Dataset, x="longitude", y="latitude", z="ssh_karin", groupby=["cycle_num", "pass_num"], aggregator="mean", cnorm="eq_hist", cmap="plasma", projection="PlateCarree", **kwargs):
    """
    Creates dashboard that can visualize SWOT collection
    """
    # --- WIDGETS --- #
    # only keep the variables that have swath dimensions
    plotting_vars = [var for var in data.data_vars if set(data[var].dims) == {'cycle_num', 'pass_num', 'num_lines', 'num_pixels'}]
    var_select = pnw.Select(value=z, options=plotting_vars, name="variable")
    
    cmap_select = pnw.Select(value=cmap, options=sorted(plt.colormaps, key=lambda x: x.lower()), name="colormap")
    agg_fn_select = pnw.Select(value=aggregator, options=['count','sum','min','max','mean','var','std'], name="aggregation function")
    cnorm_select = pnw.Select(value=cnorm, options=['linear', 'eq_hist'], name="normalization")
    proj_select = pnw.Select(value=projection, options=['PlateCarree', 'SouthPolarStereo', 'NorthPolarStereo'], name="projection")
    # cycle_num_slider = pnw.IntSlider(name="cycle_num", start=int(data.cycle_num.isel(cycle_num=0)), end=int(data.cycle_num.isel(cycle_num=-1)), step=1, value=int(data.cycle_num.isel(cycle_num=0)), width=500)
    cycle_num_select = pnw.Select(name="cycle_num", options=list(data.cycle_num.values), width=500)
    # pass_num_slider = pnw.IntSlider(name="pass_num", start=int(data.pass_num.isel(pass_num=0)), end=int(data.pass_num.isel(pass_num=-1)), step=1, value=int(data.pass_num.isel(pass_num=0)), width=500)
    pass_num_select = pnw.Select(name="pass_num", options=list(data.pass_num.values), width=500)
    desc_text = pnw.StaticText(name="title", value=data.title, width=300)
    sensor_text = pnw.StaticText(name="source", value=data.source, width=300)
    
    def time_boxes(c, p):
        pass_times = data.time.sel(cycle_num=c, pass_num=p).values
        return pn.Column(
            pnw.StaticText(name='pass start: ', value=str(pass_times[0])),
            pnw.StaticText(name='pass end: ', value=str(pass_times[-1])),
        )

    def var_attr_box(var):
        # use .get() since not every variable carries a "comment" attribute
        value = data[var].attrs.get("comment", "")
        return pnw.StaticText(name="variable info", value=value, width=300)

    def data_plot(var, proj, c, p):
        default_params = {
            "x": x, "xlabel": x,
            "y": y, "ylabel": y,
            "z": var, "groupby": groupby,
            "widget_location": "bottom",
            "widgets": {'cycle_num': cycle_num_select, "pass_num": pass_num_select},
            "aggregator": agg_fn_select, "cnorm": cnorm_select, "cmap": cmap_select,
            "colorbar": True,
            "features": ['coastline', 'ocean'],
            "geo": True,
            "projection": proj,
            "rasterize": True,
            "project": True,
            "frame_width": 650,
            "frame_height": 500,
            "title": f"PO.DAAC SWOT L2 Dashboard ({str(var)}, cycle_num={c}, pass_num={p})"
        }
        plot_params = default_params.copy()
        plot_params.update(kwargs)
        plot = data.hvplot.quadmesh(**plot_params)
        return plot

    # bind functions to widgets (like update handlers)
    idata_plot = pn.bind(data_plot, var=var_select, proj=proj_select, c=cycle_num_select, p=pass_num_select)
    itime_boxes = pn.bind(time_boxes, c=cycle_num_select, p=pass_num_select)
    ivar_attr_box = pn.bind(var_attr_box, var=var_select)

    # --- DEFINE LAYOUT --- #
    return pn.Row(pn.Column(idata_plot), pn.Column(var_select, agg_fn_select, cnorm_select, cmap_select, proj_select, itime_boxes, desc_text, sensor_text, ivar_attr_box))
dashboard(data, z="sig0_karin", cmap="bone")