Kerchunk JSON Generation

imported on: 2024-04-26

This notebook is from a different repository in NASA’s PO.DAAC, the-coding-club.

The original source for this document is https://github.com/podaac/the-coding-club/blob/main/notebooks/SWOT_to_kerchunk.ipynb

SWOT NetCDFs to Kerchunk

PO.DAAC, Jet Propulsion Laboratory, California Institute of Technology
Author: Ayush Nag

Background

Kerchunk is a library that provides a unified way to represent a variety of chunked, compressed data formats (e.g. NetCDF/HDF5, GRIB2, TIFF, …), allowing efficient access to the data from traditional file systems or cloud object storage. It also provides a flexible way to create virtual datasets from multiple files. It does this by extracting the byte ranges, compression information and other information about the data and storing this metadata in a new, separate object. This means that you can create a virtual aggregate dataset over potentially many source files, for efficient, parallel and cloud-friendly in-situ access without having to copy or translate the originals. It is a gateway to in-the-cloud massive data processing for legacy data archival formats. (description from Kerchunk documentation)
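For orientation, a single-file Kerchunk reference set is just a mapping from Zarr keys to byte ranges inside the original file. A minimal illustrative sketch is below; the variable name, path, offsets, and lengths are made up, not taken from a real granule.

# Illustrative structure of a Kerchunk reference set (all values are made up).
# ".zgroup"/".zarray" entries hold Zarr metadata inline; data keys point to
# [url, byte_offset, byte_length] ranges inside the original netCDF/HDF5 file.
example_refs = {
    "version": 1,
    "refs": {
        ".zgroup": '{"zarr_format": 2}',
        "ssha_karin/.zarray": '{"chunks": [9866, 71], "dtype": "<f4", ...}',
        "ssha_karin/0.0": [
            "s3://podaac-ops-cumulus-protected/.../granule.nc",  # source file (illustrative)
            123456,   # byte offset of this chunk (illustrative)
            789012,   # compressed chunk length in bytes (illustrative)
        ],
    },
}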

Objectives

  • Generate a Kerchunk JSON file for a PO.DAAC SWOT collection
  • Create individual JSONs for each netCDF, then combine them into one file using MultiZarrToZarr
  • Allows data to be read from PO.DAAC s3 as a combined Zarr data store
  • Output file: SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1_kerchunk_DEMO.json

import os
import re
import s3fs
import glob
import zarr
import dask
import ujson
import fsspec
import requests
import warnings
import numpy as np
import xarray as xr
import xml.etree.ElementTree as ET

from tqdm.notebook import tqdm
from dask import delayed, compute
from dask.distributed import Client
# from kerchunk.combine import auto_dask
from kerchunk.hdf import SingleHdf5ToZarr
from kerchunk.combine import MultiZarrToZarr

Start Dask cluster

  • Data can be read via s3 in parallel and used by Kerchunk
  • NOTE: Performance improves greatly with more CPUs/threads: 4 hours can become 1 hour, 90 minutes can become 20 minutes, etc. (a cluster-sizing sketch follows the cluster info below)
  • For over 10,000 granules, at least 4 threads are recommended
client = Client()
client

Client

Client-28f1d48e-41d3-11ee-8aca-c62ddac59d68

Connection method: Cluster object Cluster type: distributed.LocalCluster
Dashboard: http://127.0.0.1:8787/status

Cluster Info

Worker: 1

Comm: tcp://127.0.0.1:36915 Total threads: 1
Dashboard: http://127.0.0.1:41499/status Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:39355
Local directory: /tmp/dask-worker-space/worker-ivjiyl7u

Worker: 2

Comm: tcp://127.0.0.1:35757 Total threads: 1
Dashboard: http://127.0.0.1:33523/status Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:44603
Local directory: /tmp/dask-worker-space/worker-qggrwd45

Worker: 3

Comm: tcp://127.0.0.1:36783 Total threads: 1
Dashboard: http://127.0.0.1:39605/status Memory: 2.00 GiB
Nanny: tcp://127.0.0.1:34001
Local directory: /tmp/dask-worker-space/worker-qgfewoy7
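If more CPUs are available, the cluster size can be set explicitly rather than relying on the Client() defaults above. A minimal sketch; the worker count, threads per worker, and memory limit are illustrative and should be tuned to your machine.

# Illustrative cluster sizing (values are assumptions; tune to your machine).
# More workers speeds up the per-granule JSON generation below.
client = Client(
    n_workers=4,            # one kerchunk task per worker at a time
    threads_per_worker=1,   # SingleHdf5ToZarr work is mostly single-threaded per task
    memory_limit="4GiB",    # per-worker memory limit
)
client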

Request PO.DAAC s3 credentials

%%time
collection = "SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1"
url = "https://archive.podaac.earthdata.nasa.gov/s3credentials"
creds = requests.get(url).json()
CPU times: user 205 ms, sys: 31.1 ms, total: 236 ms
Wall time: 4.98 s
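These s3 credentials are temporary (Earthdata s3 credentials are typically valid for about an hour), so long runs may need to refresh them partway through. A small hypothetical helper like get_creds() below can be re-run whenever s3fs/fsspec starts raising permission errors; it simply repeats the request above.

# Hypothetical helper: re-request temporary PO.DAAC s3 credentials when they expire.
def get_creds(endpoint: str = "https://archive.podaac.earthdata.nasa.gov/s3credentials") -> dict:
    """Return a fresh set of temporary PO.DAAC s3 credentials (uses Earthdata Login via ~/.netrc)."""
    return requests.get(endpoint).json()

# creds = get_creds()  # call again if s3 reads begin failing with permission errors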

Option 1: Get list of remote netCDF files by collection (s3 endpoints)

%%time
s3 = s3fs.S3FileSystem(anon=False, key=creds["accessKeyId"], secret=creds["secretAccessKey"], token=creds["sessionToken"])
s3path = f"s3://podaac-ops-cumulus-protected/{collection}/*.nc"
remote_urls = s3.glob(s3path)
remote_urls = ['s3://' + f for f in remote_urls]
print(len(remote_urls))
17564
CPU times: user 6.94 s, sys: 214 ms, total: 7.15 s
Wall time: 9.96 s

Open files with s3fs

%%time
remote_files = [s3.open(file) for file in tqdm(remote_urls)]

Large amount of s3 netCDFs to individual JSONs (dask)

  • The PO.DAAC SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1 collection has 17564 granules!
  • Recommended to test with a smaller subset first, such as remote_files[:20]
  • Single threaded: 625 granules = ~25 mins, 13680 granules = ~540 mins
  • 4 threads/workers: 625 granules = ~8 mins, 13680 granules = ~180 mins

def remaining_jsons(remote_urls: list, directory: str):
    """
    Find the set difference between remote_urls and the JSONs already written to directory.
    Extracts the granule name from each remote URL path and JSON file name.
    """
    # Extract granule names from s3 endpoints list and jsons list
    remote_granules = np.asarray([path.split('/')[-1] for path in remote_urls])
    done_granules = np.asarray([path[:-5] for path in os.listdir(directory)])
    # Find remaining files (set difference: s3 endpoints - json directory)
    remaining_indices = np.where(~np.isin(remote_granules, done_granules))[0]
    new_remote_urls = [remote_urls[idx] for idx in remaining_indices]
    print(f"{len(new_remote_urls)}/{len(remote_urls)} files remaining")
    return new_remote_urls
%%time
out_dir = f"{collection}_kerchunk_DEMO"
if not os.path.exists(out_dir):
    os.makedirs(out_dir)

# If session expired, only re-run on remote_urls we don't have a JSON for
remaining_urls = remaining_jsons(remote_urls, f"{collection}_kerchunk_DEMO")
    
@dask.delayed
def gen_json(u: str):
    "Generate JSON reference file for one netCDF"
    with fsspec.open(u, mode="rb", anon=False, key=creds['accessKeyId'], secret=creds['secretAccessKey'], token=creds["sessionToken"]) as inf:
        single = SingleHdf5ToZarr(inf, u, inline_threshold=0)
        # Extract granule name from s3 URL
        granule = re.search(r'[^/]+$', u).group(0)
        full_filename = f'{out_dir}/{granule}.json'
        # Generate single kerchunk reference
        s = single.translate()
        if len(s) == 0:
            warnings.warn(f"{granule} JSON generation failed")
        # Write JSON to folder
        with open(full_filename, 'w') as f:
            ujson.dump(s, f)
        return full_filename
            
# Define delayed list of jsons (gen_json is already decorated with dask.delayed)
jsons = [gen_json(file) for file in remaining_urls]

# Run kerchunk single netcdf to JSON conversion in parallel
full_filenames = dask.compute(jsons)
1066/1368 files remaining
CPU times: user 3min 33s, sys: 22.4 s, total: 3min 55s
Wall time: 19min 36s
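Before combining, it can be worth spot-checking one of the single-granule references. A minimal sketch that opens the first JSON written above, reusing creds and out_dir from earlier cells (assumes at least one reference was generated):

# Sketch: open a single-granule kerchunk reference to confirm it is readable.
single_json = sorted(f"{out_dir}/{f}" for f in os.listdir(out_dir) if f.endswith(".json"))[0]
ds_single = xr.open_dataset(
    "reference://", engine="zarr", chunks={}, decode_times=False,
    backend_kwargs={
        "storage_options": {
            "fo": single_json,
            "remote_protocol": "s3",
            "remote_options": {"anon": False, "key": creds["accessKeyId"],
                               "secret": creds["secretAccessKey"], "token": creds["sessionToken"]},
        },
        "consolidated": False,
    },
)
ds_single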

Create combined kerchunk

%%time
out_dir = f"{collection}_kerchunk_DEMO"
json_files = os.listdir(out_dir)
json_files = [f"{out_dir}/{f}" for f in json_files if f.endswith(".json")]
mzz = MultiZarrToZarr(json_files,
    remote_protocol='s3',
    remote_options={"anon": False, "key": creds['accessKeyId'], "secret": creds['secretAccessKey'], "token": creds["sessionToken"]},
    coo_map={"cycle_num": "attr:cycle_number", "pass_num": "attr:pass_number"},
    concat_dims=["cycle_num", "pass_num"]
)

out = mzz.translate()
CPU times: user 1min 50s, sys: 14.3 s, total: 2min 4s
Wall time: 15min 47s

The arguments passed to SingleHdf5ToZarr and MultiZarrToZarr define how the data is read and concatenated. These parameters will need to be modified depending on how you want to concatenate the datasets (e.g., along time or by cycle/pass). The documentation for those functions is linked above, and more resources for understanding and using Kerchunk are linked in this cell.
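For example, to combine along a single time axis instead of new cycle/pass coordinates, MultiZarrToZarr can derive the coordinate from each granule's CF time variable via coo_map. This is a sketch only: it assumes each granule exposes a CF-convention variable named "time", which should be checked against the actual collection before use.

# Sketch: alternative combine along one "time" dimension (variable name is an assumption).
mzz_time = MultiZarrToZarr(
    json_files,
    remote_protocol="s3",
    remote_options={"anon": False, "key": creds["accessKeyId"],
                    "secret": creds["secretAccessKey"], "token": creds["sessionToken"]},
    coo_map={"time": "cf:time"},   # decode the time coordinate from each granule
    concat_dims=["time"],
)
# out_time = mzz_time.translate()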

%%time
json_file = f'{collection}_kerchunk_DEMO.json'
with open(json_file, 'wb') as f:
    f.write(ujson.dumps(out).encode())
CPU times: user 141 ms, sys: 44.4 ms, total: 185 ms
Wall time: 181 ms

Check by opening dataset

data = xr.open_dataset(
    "reference://", engine="zarr", chunks={}, decode_times=False,
    backend_kwargs={
        "storage_options": {
            "fo": f'{collection}_kerchunk_DEMO.json',
            "remote_protocol": "s3",
            "remote_options": {"anon": False, "key": creds['accessKeyId'], "secret": creds['secretAccessKey'], "token": creds["sessionToken"]}
        },
        "consolidated": False
    }
)
data
<xarray.Dataset>
Dimensions:                                (cycle_num: 3, pass_num: 456,
                                            num_lines: 9866, num_pixels: 71,
                                            num_sides: 2)
Coordinates:
  * cycle_num                              (cycle_num) float64 3.0 4.0 5.0
    latitude                               (cycle_num, pass_num, num_lines, num_pixels) float64 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    latitude_nadir                         (cycle_num, pass_num, num_lines) float64 dask.array<chunksize=(1, 1, 9866), meta=np.ndarray>
    longitude                              (cycle_num, pass_num, num_lines, num_pixels) float64 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    longitude_nadir                        (cycle_num, pass_num, num_lines) float64 dask.array<chunksize=(1, 1, 9866), meta=np.ndarray>
  * pass_num                               (pass_num) float64 1.0 2.0 ... 583.0
Dimensions without coordinates: num_lines, num_pixels, num_sides
Data variables: (12/91)
    ancillary_surface_classification_flag  (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    correction_flag                        (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    cross_track_angle                      (cycle_num, pass_num, num_lines) float64 dask.array<chunksize=(1, 1, 9866), meta=np.ndarray>
    cross_track_distance                   (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    dac                                    (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    depth_or_elevation                     (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    ...                                     ...
    wind_speed_karin                       (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    wind_speed_karin_2                     (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    wind_speed_model_u                     (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    wind_speed_model_v                     (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
    wind_speed_rad                         (cycle_num, pass_num, num_lines, num_sides) float32 dask.array<chunksize=(1, 1, 9866, 2), meta=np.ndarray>
    x_factor                               (cycle_num, pass_num, num_lines, num_pixels) float32 dask.array<chunksize=(1, 1, 9866, 71), meta=np.ndarray>
Attributes: (12/32)
    Conventions:                CF-1.7
    contact:                    CNES aviso@altimetry.fr, JPL podaac@podaac.jp...
    cycle_number:               5
    ellipsoid_flattening:       0.003352810664781205
    ellipsoid_semi_major_axis:  6378137.0
    equator_longitude:          206.06188772087626
    ...                         ...
    right_last_longitude:       289.6585533138908
    source:                     Simulate product
    time_coverage_end:          2014-07-24T10:18:18.533147Z
    time_coverage_start:        2014-07-24T09:26:52.109265Z
    title:                      Level 2 Low Rate Sea Surface Height Data Prod...
    wavelength:                 0.008385803020979
  • Conventions: CF-1.7
  • contact: CNES aviso@altimetry.fr, JPL podaac@podaac.jpl.nasa.gov
  • cycle_number: 5
  • ellipsoid_flattening: 0.003352810664781205
  • ellipsoid_semi_major_axis: 6378137.0
  • equator_longitude: 206.06188772087626
  • equator_time: 2014-07-24T09:52:36.962185Z
  • geospatial_lat_max: 78.29189230598696
  • geospatial_lat_min: -78.29203183241484
  • geospatial_lon_max: 289.6585533138908
  • geospatial_lon_min: 122.7028213028531
  • history: 2021-09-10 10:00:06Z : Creation
  • institution: CNES/JPL
  • left_first_latitude: -77.032982418125
  • left_first_longitude: 122.7028213028531
  • left_last_latitude: 78.29189230598696
  • left_last_longitude: 289.65746162390826
  • orbit_solution: POE
  • pass_number: 545
  • platform: SWOT
  • product_version: 1.1.0.dev33
  • reference_document: D-56407_SWOT_Product_Description_L2_LR_SSH
  • references: Gaultier, L., C. Ubelmann, and L.-L. Fu, 2016: The Challenge of Using Future SWOT Data for Oceanic Field Reconstruction. J. Atmos. Oceanic Technol., 33, 119-126, doi:10.1175/jtech-d-15-0160.1. http://dx.doi.org/10.1175/JTECH-D-15-0160.1.
  • right_first_latitude: -78.29203183241484
  • right_first_longitude: 122.70935482261133
  • right_last_latitude: 77.03284214129418
  • right_last_longitude: 289.6585533138908
  • source: Simulate product
  • time_coverage_end: 2014-07-24T10:18:18.533147Z
  • time_coverage_start: 2014-07-24T09:26:52.109265Z
  • title: Level 2 Low Rate Sea Surface Height Data Product - Expert SSH with Wind and Wave
  • wavelength: 0.008385803020979
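Once the combined store opens, computations stay lazy until explicitly triggered, and only the chunks needed for the selected cycle/pass are fetched from s3. A small usage sketch against one of the variables shown above:

# Sketch: lazy selection/reduction on the combined virtual dataset.
subset = data["wind_speed_karin"].sel(cycle_num=3.0, pass_num=1.0)
mean_wind = subset.mean(skipna=True)   # still lazy; nothing downloaded yet
# mean_wind.compute()                  # triggers the s3 reads for just these chunks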
Alternate: Large amount of s3 netCDFs to combined kerchunk JSON (auto_dask)

%%time
# Create combined kerchunk/zarr reference. Reads cycle_number and pass_number from the attributes of each netCDF
# Concatenates along new dimensions cycle_num and pass_num
out_dir = f"{collection}_kerchunk_DEMO"
json_files = os.listdir(out_dir)
json_files = [f"{out_dir}/{f}" for f in json_files if f.endswith(".json")]
auto_dask(urls=json_files,
          n_batches=4,
          single_driver=JustLoad,
          single_kwargs={},
          mzz_kwargs={"coo_map": {"cycle_num": "attr:cycle_number", "pass_num": "attr:pass_number"}, "concat_dims": ["cycle_num", "pass_num"]},
          remote_protocol="s3",
          remote_options={"anon": False, "key": creds['accessKeyId'], "secret": creds['secretAccessKey'], "token": creds["sessionToken"]},
          filename=f'{collection}_kerchunk_CYCLE_3.json'
         )

Appendix/Extras

Generate list of matching size granules (faster metadata approach)

s3 = s3fs.S3FileSystem(anon=False)
s3path = 's3://podaac-ops-cumulus-protected/SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1/*.dmrpp'
remote_files = s3.glob(s3path)
print(len(remote_files))
17564

Read DMR++ metadata to get granule dimensions

matches = []
pbar = tqdm(remote_files, desc="Found 0 matches")
for file in pbar:
    with s3.open(file, 'r') as f:
        xml_str: str = f.read()
    xml = ET.fromstring(xml_str)
    if xml[0].attrib['size'] == "9866" and xml[1].attrib['size'] == "71":
        matches.append(file)
        pbar.set_description(f"Found {len(matches)} matches")
flist = ['s3://' + f.removesuffix(".dmrpp") for f in tqdm(matches)]
print(flist[0])
s3://podaac-ops-cumulus-protected/SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1/SWOT_L2_LR_SSH_Expert_001_001_20140412T120000_20140412T125126_DG10_01.nc
file = open('SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1_9866_paths.txt','w')
for item in tqdm(flist):
    file.write(item + "\n")
file.close()

Find the cycles with the most passes (size 9866 swath)

cycles = {}
z = glob.glob("SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1_kerchunk_CYCLE_3/*Expert_003_*")
len(z)
456
file = open('SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1_9866_CYCLE_3.txt','w')
for item in tqdm(z):  # z holds the cycle-3 paths found above
    file.write("%s\n" % item)
file.close()

Generate list of matching size granules

s3 = s3fs.S3FileSystem(anon=False)
s3path = 's3://podaac-ops-cumulus-protected/SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1/*.nc'
remote_files = s3.glob(s3path)
print(len(remote_files))
17564

Connect to s3 netCDFs

%%time
fileset = [s3.open(file) for file in remote_files]

Find matching size granules (dask)

import dask.bag          # dask.bag is not loaded by "import dask" alone
import dask.diagnostics

# Define a function to check the condition and return the file path if the condition is met
def is_match(f):
    ds = xr.open_dataset(f, engine="h5netcdf")
    if ds['simulated_true_ssh_karin'].encoding['chunksizes'] == (9866, 71):
        return f.path
    return None

# Create a Dask bag from the fileset
fileset_bag = dask.bag.from_sequence(fileset)

# Use Dask to parallelize the processing and filter matches
with dask.diagnostics.ProgressBar():
    matches_bag = fileset_bag.map(is_match).filter(lambda x: x is not None)
    matches = matches_bag.compute()

print(f"{len(matches)} matches found")

Find matching size granules

matches = []
pbar = tqdm(fileset, desc="Found 0 matches")
for i, f in enumerate(pbar):
    ds = xr.open_dataset(f, engine="h5netcdf")
    if ds['ancillary_surface_classification_flag'].encoding['chunksizes'] == (9866, 71):
        matches.append(remote_files[i])
        pbar.set_description(f"Found {len(matches)} matches")
print(matches[0])
podaac-ops-cumulus-protected/SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1/SWOT_L2_LR_SSH_Expert_001_001_20140412T120000_20140412T125126_DG10_01.nc
flist = ['s3://' + f for f in tqdm(matches)]
print(flist[0])
s3://podaac-ops-cumulus-protected/SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1/SWOT_L2_LR_SSH_Expert_001_001_20140412T120000_20140412T125126_DG10_01.nc

Write file

file = open('SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1_9866_paths.txt','w')
for item in tqdm(flist):
    file.write(item + "\n")
file.close()

Read file

f = open("SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1_9866_same_passes_paths.txt", "r")
remote_urls = f.read().splitlines()
# remote_urls = remote_urls.split("\n")
# remote_urls.pop() # remove extra '' for EOF
f.close()
print(len(remote_urls))
remote_urls[0]
13680
's3://podaac-ops-cumulus-protected/SWOT_SIMULATED_L2_KARIN_SSH_GLORYS_SCIENCE_V1/SWOT_L2_LR_SSH_Expert_001_140_20140417T111107_20140417T120233_DG10_01.nc'

Find all granules in cycles 3, 4, 5

cycles_3 = []
for url in tqdm(remote_urls):
    if "Expert_003_" in url or "Expert_004_" in url or "Expert_005_" in url:
        cycles_3.append(url)
len(cycles_3)
1368
%%time
s3 = s3fs.S3FileSystem(anon=False, key=creds["accessKeyId"], secret=creds["secretAccessKey"], token=creds["sessionToken"])
remote_files = [s3.open(file) for file in tqdm(cycles_3)]
CPU times: user 1.35 s, sys: 139 ms, total: 1.49 s
Wall time: 8.25 s

s3 NetCDFs to Kerchunk (loop)

%%time
singles = []
for i, u in enumerate(tqdm(remote_urls[:20])):
    with fsspec.open(u, mode="rb", anon=False, key=creds['accessKeyId'], secret=creds['secretAccessKey'], token=creds["sessionToken"]) as inf:
        single = SingleHdf5ToZarr(inf, u, inline_threshold=0)
        filename = re.sub(r'.*/', '', u)
        singles.append(single.translate())

kerchunk auto_dask definition

Note: this is built in to kerchunk as of v0.1.0 via from kerchunk.combine import auto_dask

# Author: Martin Durant
# https://fsspec.github.io/kerchunk/_modules/kerchunk/combine.html#auto_dask
from typing import List

def auto_dask(
    urls: List[str],
    single_driver: str,
    single_kwargs: dict,
    mzz_kwargs: dict,
    n_batches: int,
    remote_protocol=None,
    remote_options=None,
    filename=None,
    output_options=None,
):
    """Batched tree combine using dask.

    If you wish to run on a distributed cluster (recommended), create
    a client before calling this function.

    Parameters
    ----------
    urls: list[str]
        input dataset URLs
    single_driver: class
        class with ``translate()`` method
    single_kwargs: to pass to single-input driver
    mzz_kwargs: passed to ``MultiZarrToZarr`` for each batch
    n_batches: int
        Number of MZZ instances in the first combine stage. May be set equal
        to the number of dask workers, or a multiple thereof.
    remote_protocol: str | None
    remote_options: dict
        To fsspec for opening the remote files
    filename: str | None
        Output filename, if writing
    output_options
        If ``filename`` is not None, open it with these options

    Returns
    -------
    reference set
    """
    import dask

    # make delayed functions
    single_task = dask.delayed(lambda x: single_driver(x, **single_kwargs).translate())
    post = mzz_kwargs.pop("postprocess", None)
    inline = mzz_kwargs.pop("inline_threshold", None)
    # TODO: if single files produce list of reference sets (e.g., grib2)
    batch_task = dask.delayed(
        lambda u, x: MultiZarrToZarr(
            u,
            indicts=x,
            remote_protocol=remote_protocol,
            remote_options=remote_options,
            **mzz_kwargs,
        ).translate()
    )

    # sort out kwargs
    dims = mzz_kwargs.get("concat_dims", [])
    dims += [k for k in mzz_kwargs.get("coo_map", []) if k not in dims]
    kwargs = {"concat_dims": dims}
    if post:
        kwargs["postprocess"] = post
    if inline:
        kwargs["inline_threshold"] = inline
    for field in ["remote_protocol", "remote_options", "coo_dtypes", "identical_dims"]:
        if field in mzz_kwargs:
            kwargs[field] = mzz_kwargs[field]
    final_task = dask.delayed(
        lambda x: MultiZarrToZarr(
            x, remote_options=remote_options, remote_protocol=remote_protocol, **kwargs
        ).translate(filename, output_options)
    )

    # make delayed calls
    tasks = [single_task(u) for u in urls]
    tasks_per_batch = -(-len(tasks) // n_batches)
    tasks2 = []
    for batch in range(n_batches):
        in_tasks = tasks[batch * tasks_per_batch : (batch + 1) * tasks_per_batch]
        u = urls[batch * tasks_per_batch : (batch + 1) * tasks_per_batch]
        if in_tasks:
            # skip if on last iteration and no remaining tasks
            tasks2.append(batch_task(u, in_tasks))
    return dask.compute(final_task(tasks2))[0]
    
    
    
class JustLoad:
    """For auto_dask, in the case that single file references already exist"""

    def __init__(self, url, storage_options=None):
        self.url = url
        self.storage_options = storage_options or {}

    def translate(self):
        with fsspec.open(self.url, mode="rt", **self.storage_options) as f:
            return ujson.load(f)