Kerchunk JSON Generation

imported on: 2024-10-30

This notebook is from a different repository in NASA’s PO.DAAC, the-coding-club.

The original source for this document is https://github.com/podaac/the-coding-club/blob/main/notebooks/SWOT_to_kerchunk.ipynb

SWOT NetCDFs to Kerchunk

PO.DAAC, Jet Propulsion Laboratory, California Institute of Technology
Author: Ayush Nag

Background

Kerchunk is a library that provides a unified way to represent a variety of chunked, compressed data formats (e.g. NetCDF/HDF5, GRIB2, TIFF, …), allowing efficient access to the data from traditional file systems or cloud object storage. It also provides a flexible way to create virtual datasets from multiple files. It does this by extracting the byte ranges, compression information and other information about the data and storing this metadata in a new, separate object. This means that you can create a virtual aggregate dataset over potentially many source files, for efficient, parallel and cloud-friendly in-situ access without having to copy or translate the originals. It is a gateway to in-the-cloud massive data processing for legacy data archival formats. (description from Kerchunk documentation)
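
For orientation, a version-1 Kerchunk reference is essentially a JSON mapping from Zarr keys to either inline metadata or [url, offset, length] byte ranges in the original file. Below is a minimal, hypothetical sketch; the variable name, paths, and byte values are illustrative, not taken from this collection.

# Hypothetical shape of a version-1 Kerchunk reference (illustrative values only)
reference = {
    "version": 1,
    "refs": {
        ".zgroup": '{"zarr_format": 2}',                           # inline Zarr group metadata
        "ssha/.zarray": '{"chunks": [9866, 71], "dtype": "<f8"}',  # inline array metadata (abridged)
        "ssha/0.0": ["s3://bucket/granule.nc", 20480, 153600],     # [url, byte offset, byte length]
    },
}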

Objectives

  • Generate a Kerchunk JSON file for a PO.DAAC SWOT collection
  • Create individual JSONs for each netCDF, then combine them into one file using MultiZarrToZarr
  • Allow the data to be read from PO.DAAC s3 as a combined Zarr data store
  • Output file: SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1_kerchunk_DEMO.json

import os
import re
import warnings
import s3fs
import glob
import zarr
import dask
import ujson
import fsspec
import requests
import numpy as np
import xarray as xr
import xml.etree.ElementTree as ET

from tqdm.notebook import tqdm
from dask import delayed, compute
from dask.distributed import Client
# from kerchunk.combine import auto_dask
from kerchunk.hdf import SingleHdf5ToZarr
from kerchunk.combine import MultiZarrToZarr

Start Dask cluster

  • Data can be read from s3 in parallel and passed to Kerchunk
  • NOTE: Performance improves greatly with more CPUs/threads: 4 hours can become 1 hour, 90 minutes can become 20 minutes, etc.
  • For over 10,000 granules, at least 4 threads are recommended
client = Client()
client

Client: Client-28f1d48e-41d3-11ee-8aca-c62ddac59d68
Connection method: Cluster object
Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status
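
Related to the thread recommendation above, the cluster size can also be set explicitly rather than letting Dask pick defaults. A minimal sketch; the worker and thread counts are illustrative, so tune them to your machine:

# Illustrative sizing only; adjust n_workers/threads_per_worker to the available CPUs
client = Client(n_workers=4, threads_per_worker=2)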

Request PO.DAAC s3 credentials

%%time
collection = "SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1"
url = "https://archive.podaac.earthdata.nasa.gov/s3credentials"
creds = requests.get(url).json()
CPU times: user 205 ms, sys: 31.1 ms, total: 236 ms
Wall time: 4.98 s
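
These temporary credentials expire after a limited time (the response also includes an expiration timestamp), so long-running jobs may need to re-request them. A minimal helper sketch, assuming the same endpoint as above; get_creds is a hypothetical name, not part of the original notebook:

def get_creds(url: str = "https://archive.podaac.earthdata.nasa.gov/s3credentials") -> dict:
    """Re-request temporary PO.DAAC s3 credentials (hypothetical helper)."""
    return requests.get(url).json()

# creds = get_creds()  # re-run when the previous credentials expire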

Option 1: Get list of remote netCDF files by collection (s3 endpoints)

%%time
s3 = s3fs.S3FileSystem(anon=False, key=creds["accessKeyId"], secret=creds["secretAccessKey"], token=creds["sessionToken"])
s3path = f"s3://podaac-ops-cumulus-protected/{collection}/*.nc"
remote_urls = s3.glob(s3path)
remote_urls = ['s3://' + f for f in remote_urls]
print(len(remote_urls))
17564
CPU times: user 6.94 s, sys: 214 ms, total: 7.15 s
Wall time: 9.96 s

Open files with s3fs

%%time
remote_files = [s3.open(file) for file in tqdm(remote_urls)]

Large amount of s3 netCDFs to individual JSONs (dask)

  • The PO.DAAC SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1 collection has 17564 granules!
  • Recommended to test with a smaller subset such as remote_files[:20]
  • Single threaded
    • 625 granules = ~25 mins
    • 13680 granules = ~540 mins
  • 4 threads/workers
    • 625 granules = ~8 mins
    • 13680 granules = ~180 mins

def remaining_jsons(remote_urls: list, directory: str):
    """
    Find the set difference between remote_urls and the JSONs already written to directory.
    Extracts the granule name from each remote URL path and each JSON file name.
    """
    # Extract granule names from s3 endpoints list and jsons list
    remote_granules = np.asarray([path.split('/')[-1] for path in remote_urls])
    done_granules = np.asarray([path[:-5] for path in os.listdir(directory)])
    # Find remaining files (set difference: s3 endpoints - json directory)
    remaining_indices = np.where(~np.isin(remote_granules, done_granules))[0]
    new_remote_urls = [remote_urls[idx] for idx in remaining_indices]
    print(f"{len(new_remote_urls)}/{len(remote_urls)} files remaining")
    return new_remote_urls
%%time
out_dir = f"{collection}_kerchunk_DEMO"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

# If session expired, only re-run on remote_urls we don't have a JSON for
remaining_urls = remaining_jsons(remote_urls, f"{collection}_kerchunk_DEMO")
    
@dask.delayed
def gen_json(u: str):
    "Generate JSON reference file for one netCDF"
    with fsspec.open(u, mode="rb", anon=False, key=creds['accessKeyId'], secret=creds['secretAccessKey'], token=creds["sessionToken"]) as inf:
        single = SingleHdf5ToZarr(inf, u, inline_threshold=0)
        # Extract granule name from s3 URL
        granule = re.search(r'[^/]+$', u).group(0)
        full_filename = f'{out_dir}/{granule}.json'
        # Generate single kerchunk reference
        s = single.translate()
        if len(s) == 0:
            warnings.warn(f"{granule} JSON generation failed")
        # Write JSON to folder
        with open(full_filename, 'w') as f:
            ujson.dump(s, f)
        return full_filename
            
# Define the delayed list of JSON tasks (gen_json is already dask.delayed)
jsons = [gen_json(file) for file in remaining_urls]

# Run kerchunk single netcdf to JSON conversion in parallel
full_filenames = dask.compute(jsons)
1066/1368 files remaining
CPU times: user 3min 33s, sys: 22.4 s, total: 3min 55s
Wall time: 19min 36s
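
As a quick sanity check (not part of the original notebook), the number of reference JSONs on disk can be compared against the number of granules in the collection:

n_json = len([f for f in os.listdir(out_dir) if f.endswith(".json")])
print(f"{n_json} of {len(remote_urls)} reference JSONs generated")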

Create combined kerchunk

%%time
out_dir = f"{collection}_kerchunk_DEMO"
json_files = os.listdir(out_dir)
json_files = [f"{out_dir}/{f}" for f in json_files if f.endswith(".json")]
mzz = MultiZarrToZarr(json_files,
    remote_protocol='s3',
    remote_options={"anon": False, "key": creds['accessKeyId'], "secret": creds['secretAccessKey'], "token": creds["sessionToken"]},
    coo_map={"cycle_num": "attr:cycle_number", "pass_num": "attr:pass_number"},
    concat_dims=["cycle_num", "pass_num"]
)

out = mzz.translate()
CPU times: user 1min 50s, sys: 14.3 s, total: 2min 4s
Wall time: 15min 47s

The arguments passed to SingleHdf5ToZarr and MultiZarrToZarr define how the data is read and concatenated. These parameters will need to be modified depending on how you want to concatenate the datasets (e.g. by time or by cycle/pass). See the Kerchunk documentation for those functions, and the references listed in the Appendix below, for more background on how to use Kerchunk.
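
For example, a collection whose granules carry a CF-encoded time variable could be concatenated along time instead. The sketch below assumes a time variable named time and is not the configuration used for this SWOT collection:

mzz_time = MultiZarrToZarr(json_files,
    remote_protocol='s3',
    remote_options={"anon": False, "key": creds['accessKeyId'], "secret": creds['secretAccessKey'], "token": creds["sessionToken"]},
    coo_map={"time": "cf:time"},  # decode the CF time coordinate from each granule
    concat_dims=["time"]
)
out_time = mzz_time.translate()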

%%time
json_file = f'{collection}_kerchunk_DEMO.json'
with open(json_file, 'wb') as f:
    f.write(ujson.dumps(out).encode())
CPU times: user 141 ms, sys: 44.4 ms, total: 185 ms
Wall time: 181 ms

Check by opening dataset

data = xr.open_dataset(
    "reference://", engine="zarr", chunks={}, decode_times=False,
    backend_kwargs={
        "storage_options": {
            "fo": f'{collection}_kerchunk_DEMO.json',
            "remote_protocol": "s3",
            "remote_options": {"anon": False, "key": creds['accessKeyId'], "secret": creds['secretAccessKey'], "token": creds["sessionToken"]}
        },
        "consolidated": False
    }
)
data
<xarray.Dataset>
Dimensions:                                (cycle_num: 3, pass_num: 456,
                                            num_lines: 9866, num_pixels: 71,
                                            num_sides: 2)
Coordinates:
  * cycle_num                              (cycle_num) float64 3.0 4.0 5.0
    latitude                               (cycle_num, pass_num, num_lines, num_pixels) float64 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    latitude_nadir                         (cycle_num, pass_num, num_lines) float64 dask.array<chunksize=(1, 1, 9866), meta=np.ndarray>
    longitude                              (cycle_num, pass_num, num_lines, num_pixels) float64 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    longitude_nadir                        (cycle_num, pass_num, num_lines) float64 dask.array<chunksize=(1, 1, 9866), meta=np.ndarray>
  * pass_num                               (pass_num) float64 1.0 2.0 ... 583.0
Dimensions without coordinates: num_lines, num_pixels, num_sides
Data variables: (12/91)
    ancillary_surface_classification_flag  (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    correction_flag                        (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    cross_track_angle                      (cycle_num, pass_num, num_lines) float64 dask.array<chunksize=(1, 1, 9866), meta=np.ndarray>
    cross_track_distance                   (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    dac                                    (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    depth_or_elevation                     (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    ...                                     ...
    wind_speed_karin                       (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    wind_speed_karin_2                     (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    wind_speed_model_u                     (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    wind_speed_model_v                     (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    wind_speed_rad                         (cycle_num, pass_num, num_lines, num_sides) float32 dask.array<chunksize=(1, 1, 9866, 2), meta=np.ndarray>
    x_factor                               (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
Attributes: (12/32)
    Conventions:                CF-1.7
    contact:                    CNES aviso@altimetry.fr, JPL podaac@podaac.jp...
    cycle_number:               5
    ellipsoid_flattening:       0.003352810664781205
    ellipsoid_semi_major_axis:  6378137.0
    equator_longitude:          206.06188772087626
    ...                         ...
    right_last_longitude:       289.6585533138908
    source:                     Simulate product
    time_coverage_end:          2014-07-24T10:18:18.533147Z
    time_coverage_start:        2014-07-24T09:26:52.109265Z
    title:                      Level 2 Low Rate Sea Surface Height Data Prod...
    wavelength:                 0.008385803020979
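
As a quick usage example (not part of the original notebook), a single variable can be lazily subset and computed from the combined store, e.g. the mean wind speed for one cycle/pass:

subset = data["wind_speed_karin"].isel(cycle_num=0, pass_num=0)
print(subset.mean().compute())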

Alternate: Large amount of s3 netCDFs to combined kerchunk JSON (auto_dask)

%%time
# Create combined kerchunk/zarr reference. Reads cycle_number and pass_number from attributes of each netCDF
# Concats along new dimensions cycle_number and pass_number
out_dir = f"{collection}_kerchunk_DEMO"
json_files = os.listdir(out_dir)
json_files = [f"{out_dir}/{f}" for f in json_files if f.endswith(".json")]
auto_dask(urls=json_files,
          n_batches=4,
          single_driver=JustLoad,
          single_kwargs={},
          mzz_kwargs={"coo_map": {"cycle_num": "attr:cycle_number", "pass_num": "attr:pass_number"}, "concat_dims": ["cycle_num", "pass_num"]},
          remote_protocol="s3",
          remote_options={"anon": False, "key": creds['accessKeyId'], "secret": creds['secretAccessKey'], "token": creds["sessionToken"]},
          filename=f'{collection}_kerchunk_CYCLE_3.json'
     )
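
Note: auto_dask and JustLoad as used here are defined in the Appendix at the end of this notebook. With kerchunk v0.1.0 or newer, auto_dask can instead be imported directly (see the commented-out import at the top of this notebook):

from kerchunk.combine import auto_dask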

Appendix/Extras

  • Find granules with matching dimensions (SWOT_SIMULATED granules have uneven dimensions: num_lines could be 9864, 9865, 9866, etc.)
    • New: Read the DMR++ metadata paired with each netCDF to get its dimensions
    • Old: Open each dataset with xarray and check its dimensions
  • Loop/auto_dask implementations of creating Kerchunk references
  • References:
    • https://projectpythia.org/kerchunk-cookbook/notebooks/foundations/03_kerchunk_dask.html
    • https://ncar.github.io/esds/posts/2023/cesm2-le-timeseries-kerchunk/
    • https://nbviewer.org/github/cgentemann/cloud_science/blob/master/zarr_meta/cloud_mur_v41_benchmark.ipynb

Generate list of matching size granules (faster metadata approach)

s3 = s3fs.S3FileSystem(anon=False)
s3path = 's3://podaac-ops-cumulus-protected/SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1/*.dmrpp'
remote_files = s3.glob(s3path)
print(len(remote_files))
17564

Read DMR++ metadata to get granule dimensions

matches = []
pbar = tqdm(remote_files, desc="Found 0 matches")
for file in pbar:
    with s3.open(file, 'r') as f:
        xml_str: str = f.read()
    xml = ET.fromstring(xml_str)
    if xml[0].attrib['size'] == "9866" and xml[1].attrib['size'] == "71":
        matches.append(file)
        pbar.set_description(f"Found {len(matches)} matches")
flist = ['s3://' + f.removesuffix(".dmrpp") for f in tqdm(matches)]
print(flist[0])
s3://podaac-ops-cumulus-protected/SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1/SWOT_L2_LR_SSH_Expert_001_001_20140412T120000_20140412T125126_DG10_01.nc
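
The dimension check above relies on the first two Dimension elements of the DMR++ appearing in a fixed order. A slightly more robust sketch, assuming the standard DAP4 XML namespace, looks the dimensions up by name (granule_dims is a hypothetical helper):

DAP_NS = {"dap": "http://xml.opendap.org/ns/DAP/4.0#"}  # assumed DAP4 namespace

def granule_dims(xml_root) -> dict:
    """Map dimension name -> size from a parsed DMR++ document (illustrative helper)."""
    return {d.attrib["name"]: int(d.attrib["size"]) for d in xml_root.findall("dap:Dimension", DAP_NS)}

# dims = granule_dims(xml)
# match = dims.get("num_lines") == 9866 and dims.get("num_pixels") == 71
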
file = open('SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1_9866_paths.txt','w')
for item in tqdm(flist):
    file.write(item + "\n")
file.close()

Find the cycles with the most passes (size 9866 swath)

cycles = {}
z = glob.glob("SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1_kerchunk_CYCLE_3/*Expert_003_*")
len(z)
456
# all_urls: the s3 paths for the selected cycles (e.g. the cycles_3 list built below)
with open('SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1_9866_CYCLE_3.txt', 'w') as file:
    for item in tqdm(all_urls):
        file.write("%s\n" % item)

Generate list of matching size granules

s3 = s3fs.S3FileSystem(anon=False)
s3path = 's3://podaac-ops-cumulus-protected/SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1/*.nc'
remote_files = s3.glob(s3path)
print(len(remote_files))
17564

Connect to s3 netCDFs

  • 17564 granules takes around 5 mins
%%time
fileset = [s3.open(file) for file in remote_files]

Find matching size granules (dask)

import dask.bag as db
from dask.diagnostics import ProgressBar

# Define a function to check the condition and return the file path if the condition is met
def is_match(f):
    ds = xr.open_dataset(f, engine="h5netcdf")
    if ds['simulated_true_ssh_karin'].encoding['chunksizes'] == (9866, 71):
        return f.path
    return None

# Create a Dask bag from the fileset
fileset_bag = db.from_sequence(fileset)

# Use Dask to parallelize the processing and filter matches
with ProgressBar():
    matches_bag = fileset_bag.map(is_match).filter(lambda x: x is not None)
    matches = matches_bag.compute()

print(f"{len(matches)} matches found")

Find matching size granules

matches = []
pbar = tqdm(fileset, desc="Found 0 matches")
for i, f in enumerate(pbar):
    ds = xr.open_dataset(f, engine="h5netcdf")
    if ds['ancillary_surface_classification_flag'].encoding['chunksizes'] == (9866, 71):
        matches.append(remote_files[i])
        pbar.set_description(f"Found {len(matches)} matches")
print(matches[0])
podaac-ops-cumulus-protected/SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1/SWOT_L2_LR_SSH_Expert_001_001_20140412T120000_20140412T125126_DG10_01.nc
flist = ['s3://' + f for f in tqdm(matches)]
print(flist[0])
s3://podaac-ops-cumulus-protected/SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1/SWOT_L2_LR_SSH_Expert_001_001_20140412T120000_20140412T125126_DG10_01.nc

Write file

file = open('SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1_9866_paths.txt','w')
for item in tqdm(flist):
    file.write(item + "\n")
file.close()

Read file

f = open("SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1_9866_same_passes_paths.txt", "r")
remote_urls = f.read().splitlines()
# remote_urls = remote_urls.split("\n")
# remote_urls.pop() # remove extra '' for EOF
f.close()
print(len(remote_urls))
remote_urls[0]
13680
's3://podaac-ops-cumulus-protected/SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1/SWOT_L2_LR_SSH_Expert_001_140_20140417T111107_20140417T120233_DG10_01.nc'

Find all granules in cycles 3, 4, 5

cycles_3 = []
for url in tqdm(remote_urls):
    if "Expert_003_" in url or "Expert_004_" in url or "Expert_005_" in url:
        cycles_3.append(url)
len(cycles_3)
1368
%%time
s3 = s3fs.S3FileSystem(anon=False, key=creds["accessKeyId"], secret=creds["secretAccessKey"], token=creds["sessionToken"])
remote_files = [s3.open(file) for file in tqdm(cycles_3)]
CPU times: user 1.35 s, sys: 139 ms, total: 1.49 s
Wall time: 8.25 s

s3 NetCDFs to Kerchunk (loop)

%%time
singles = []
for i, u in enumerate(tqdm(remote_urls[:20])):
    with fsspec.open(u, mode="rb", anon=False, key=creds['accessKeyId'], secret=creds['secretAccessKey'], token=creds["sessionToken"]) as inf:
        single = SingleHdf5ToZarr(inf, u, inline_threshold=0)
        filename = re.sub(r'.*/', '', u)
        singles.append(single.translate())
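
The single-file references collected in singles can then be combined in memory, since MultiZarrToZarr also accepts reference dicts rather than file paths. A minimal sketch reusing the same coordinate mapping as the combine step earlier in this notebook:

mzz = MultiZarrToZarr(singles,
    remote_protocol='s3',
    remote_options={"anon": False, "key": creds['accessKeyId'], "secret": creds['secretAccessKey'], "token": creds["sessionToken"]},
    coo_map={"cycle_num": "attr:cycle_number", "pass_num": "attr:pass_number"},
    concat_dims=["cycle_num", "pass_num"]
)
combined = mzz.translate()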

kerchunk auto_dask definition

Note: this is built in to kerchunk as of v0.1.0 via from kerchunk.combine import auto_dask

# Author: Martin Durant
# https://fsspec.github.io/kerchunk/_modules/kerchunk/combine.html#auto_dask
from typing import List
def auto_dask(
    urls: List[str],
    single_driver: str,
    single_kwargs: dict,
    mzz_kwargs: dict,
    n_batches: int,
    remote_protocol=None,
    remote_options=None,
    filename=None,
    output_options=None,
):
    """Batched tree combine using dask.

    If you wish to run on a distributed cluster (recommended), create
    a client before calling this function.

    Parameters
    ----------
    urls: list[str]
        input dataset URLs
    single_driver: class
        class with ``translate()`` method
    single_kwargs: to pass to single-input driver
    mzz_kwargs: passed to ``MultiZarrToZarr`` for each batch
    n_batches: int
        Number of MZZ instances in the first combine stage. May be set equal
        to the number of dask workers, or a multiple thereof.
    remote_protocol: str | None
    remote_options: dict
        To fsspec for opening the remote files
    filename: str | None
        Output filename, if writing
    output_options
        If ``filename`` is not None, open it with these options

    Returns
    -------
    reference set
    """
    import dask

    # make delayed functions
    single_task = dask.delayed(lambda x: single_driver(x, **single_kwargs).translate())
    post = mzz_kwargs.pop("postprocess", None)
    inline = mzz_kwargs.pop("inline_threshold", None)
    # TODO: if single files produce list of reference sets (e.g., grib2)
    batch_task = dask.delayed(
        lambda u, x: MultiZarrToZarr(
            u,
            indicts=x,
            remote_protocol=remote_protocol,
            remote_options=remote_options,
            **mzz_kwargs,
        ).translate()
    )

    # sort out kwargs
    dims = mzz_kwargs.get("concat_dims", [])
    dims += [k for k in mzz_kwargs.get("coo_map", []) if k not in dims]
    kwargs = {"concat_dims": dims}
    if post:
        kwargs["postprocess"] = post
    if inline:
        kwargs["inline_threshold"] = inline
    for field in ["remote_protocol", "remote_options", "coo_dtypes", "identical_dims"]:
        if field in mzz_kwargs:
            kwargs[field] = mzz_kwargs[field]
    final_task = dask.delayed(
        lambda x: MultiZarrToZarr(
            x, remote_options=remote_options, remote_protocol=remote_protocol, **kwargs
        ).translate(filename, output_options)
    )

    # make delayed calls
    tasks = [single_task(u) for u in urls]
    tasks_per_batch = -(-len(tasks) // n_batches)
    tasks2 = []
    for batch in range(n_batches):
        in_tasks = tasks[batch * tasks_per_batch : (batch + 1) * tasks_per_batch]
        u = urls[batch * tasks_per_batch : (batch + 1) * tasks_per_batch]
        if in_tasks:
            # skip if on last iteration and no remaining tasks
            tasks2.append(batch_task(u, in_tasks))
    return dask.compute(final_task(tasks2))[0]



class JustLoad:
    """For auto_dask, in the case that single file references already exist"""

    def __init__(self, url, storage_options=None):
        self.url = url
        self.storage_options = storage_options or {}

    def translate(self):
        with fsspec.open(self.url, mode="rt", **self.storage_options) as f:
            return ujson.load(f)