Skip to content

Commit

Permalink
Decoding cftime_variables (#122)
Browse files Browse the repository at this point in the history
* Try to get cftime_variables decoding

* Include encoding in zattrs

* Add comment to test

* Move functions into utils, parametrize test

* Improve docstrings/errors

* Make a hacky assert_allclose

* Use us precision so that Python 3.10 works

* Add note to release doc

* Use xarray Encoder and don't allow implicit decode_times

* Fix up existing tests

* Removing print statements
  • Loading branch information
jsignell committed Jun 28, 2024
1 parent 412c23c commit 10e7863
Show file tree
Hide file tree
Showing 5 changed files with 80 additions and 17 deletions.
2 changes: 2 additions & 0 deletions docs/releases.rst
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ New Features

- Added a `.rename_paths` convenience method to rename paths in a manifest according to a function.
(:pull:`152`) By `Tom Nicholas <https://github.com/TomNicholas>`_.
- New ``cftime_variables`` option on ``open_virtual_dataset`` enables encoding/decoding time.
(:pull:`122`) By `Julia Signell <https://github.com/jsignell>`_.

Breaking changes
~~~~~~~~~~~~~~~~
Expand Down
9 changes: 5 additions & 4 deletions virtualizarr/kerchunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
import numpy as np
import ujson # type: ignore
import xarray as xr
from xarray.coding.times import CFDatetimeCoder

from virtualizarr.manifests.manifest import join
from virtualizarr.utils import _fsspec_openfile_from_filepath
Expand Down Expand Up @@ -49,6 +50,8 @@ def default(self, obj):
return obj.tolist() # Convert NumPy array to Python list
elif isinstance(obj, np.generic):
return obj.item() # Convert NumPy scalar to Python scalar
elif isinstance(obj, np.dtype):
return str(obj)
return json.JSONEncoder.default(self, obj)


Expand Down Expand Up @@ -274,9 +277,7 @@ def variable_to_kerchunk_arr_refs(var: xr.Variable, var_name: str) -> KerchunkAr
f"Cannot serialize loaded variable {var_name}, as it is encoded with an offset"
)
if "calendar" in var.encoding:
raise NotImplementedError(
f"Cannot serialize loaded variable {var_name}, as it is encoded with a calendar"
)
np_arr = CFDatetimeCoder().encode(var.copy(), name=var_name).values

# This encoding is what kerchunk does when it "inlines" data, see https://github.com/fsspec/kerchunk/blob/a0c4f3b828d37f6d07995925b324595af68c4a19/kerchunk/hdf.py#L472
byte_data = np_arr.tobytes()
Expand All @@ -297,7 +298,7 @@ def variable_to_kerchunk_arr_refs(var: xr.Variable, var_name: str) -> KerchunkAr
zarray_dict = zarray.to_kerchunk_json()
arr_refs[".zarray"] = zarray_dict

zattrs = var.attrs
zattrs = {**var.attrs, **var.encoding}
zattrs["_ARRAY_DIMENSIONS"] = list(var.dims)
arr_refs[".zattrs"] = json.dumps(zattrs, separators=(",", ":"), cls=NumpyEncoder)

Expand Down
44 changes: 35 additions & 9 deletions virtualizarr/tests/test_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,9 +74,10 @@ def test_kerchunk_roundtrip_no_concat(self, tmpdir, format):
# assert identical to original dataset
xrt.assert_identical(roundtrip, ds)

def test_kerchunk_roundtrip_concat(self, tmpdir, format):
@pytest.mark.parametrize("decode_times,time_vars", [(False, []), (True, ["time"])])
def test_kerchunk_roundtrip_concat(self, tmpdir, format, decode_times, time_vars):
# set up example xarray dataset
ds = xr.tutorial.open_dataset("air_temperature", decode_times=False)
ds = xr.tutorial.open_dataset("air_temperature", decode_times=decode_times)

# split into two datasets
ds1, ds2 = ds.isel(time=slice(None, 1460)), ds.isel(time=slice(1460, None))
Expand All @@ -86,8 +87,25 @@ def test_kerchunk_roundtrip_concat(self, tmpdir, format):
ds2.to_netcdf(f"{tmpdir}/air2.nc")

# use open_dataset_via_kerchunk to read it as references
vds1 = open_virtual_dataset(f"{tmpdir}/air1.nc", indexes={})
vds2 = open_virtual_dataset(f"{tmpdir}/air2.nc", indexes={})
vds1 = open_virtual_dataset(
f"{tmpdir}/air1.nc",
indexes={},
loadable_variables=time_vars,
cftime_variables=time_vars,
)
vds2 = open_virtual_dataset(
f"{tmpdir}/air2.nc",
indexes={},
loadable_variables=time_vars,
cftime_variables=time_vars,
)

if decode_times is False:
assert vds1.time.dtype == np.dtype("float32")
else:
assert vds1.time.dtype == np.dtype("<M8[ns]")
assert "units" in vds1.time.encoding
assert "calendar" in vds1.time.encoding

# concatenate virtually along time
vds = xr.concat([vds1, vds2], dim="time", coords="minimal", compat="override")
Expand All @@ -97,18 +115,26 @@ def test_kerchunk_roundtrip_concat(self, tmpdir, format):
ds_refs = vds.virtualize.to_kerchunk(format=format)

# use fsspec to read the dataset from the kerchunk references dict
roundtrip = xr.open_dataset(ds_refs, engine="kerchunk", decode_times=False)
roundtrip = xr.open_dataset(
ds_refs, engine="kerchunk", decode_times=decode_times
)
else:
# write those references to disk as kerchunk references format
vds.virtualize.to_kerchunk(f"{tmpdir}/refs.{format}", format=format)

# use fsspec to read the dataset from disk via the kerchunk references
roundtrip = xr.open_dataset(
f"{tmpdir}/refs.{format}", engine="kerchunk", decode_times=False
f"{tmpdir}/refs.{format}", engine="kerchunk", decode_times=decode_times
)

# assert identical to original dataset
xrt.assert_identical(roundtrip, ds)
if decode_times is False:
# assert identical to original dataset
xrt.assert_identical(roundtrip, ds)
else:
# they are very very close! But assert_allclose doesn't seem to work on datetimes
assert (roundtrip.time - ds.time).sum() == 0
assert roundtrip.time.dtype == ds.time.dtype
assert roundtrip.time.encoding["units"] == ds.time.encoding["units"]
assert roundtrip.time.encoding["calendar"] == ds.time.encoding["calendar"]

def test_non_dimension_coordinates(self, tmpdir, format):
# regression test for GH issue #105
Expand Down
11 changes: 9 additions & 2 deletions virtualizarr/tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ def test_no_indexes(self, netcdf4_file):

def test_create_default_indexes(self, netcdf4_file):
vds = open_virtual_dataset(netcdf4_file, indexes=None)
ds = xr.open_dataset(netcdf4_file)
ds = xr.open_dataset(netcdf4_file, decode_times=False)

# TODO use xr.testing.assert_identical(vds.indexes, ds.indexes) instead once class supported by assertion comparison, see https://github.com/pydata/xarray/issues/5812
assert index_mappings_equal(vds.xindexes, ds.xindexes)
Expand Down Expand Up @@ -344,7 +344,7 @@ def test_loadable_variables(self, netcdf4_file):
else:
assert isinstance(vds[name].data, ManifestArray), name

full_ds = xr.open_dataset(netcdf4_file)
full_ds = xr.open_dataset(netcdf4_file, decode_times=False)

for name in full_ds.variables:
if name in vars_to_load:
Expand Down Expand Up @@ -413,3 +413,10 @@ def test_mixture_of_manifestarrays_and_numpy_arrays(self, netcdf4_file):
== "s3://bucket/air.nc"
)
assert isinstance(renamed_vds["lat"].data, np.ndarray)


def test_cftime_variables_must_be_in_loadable_variables(tmpdir):
ds = xr.Dataset(data_vars={"time": ["2024-06-21"]})
ds.to_netcdf(f"{tmpdir}/scalar.nc")
with pytest.raises(ValueError, match="'time' not in"):
open_virtual_dataset(f"{tmpdir}/scalar.nc", cftime_variables=["time"])
31 changes: 29 additions & 2 deletions virtualizarr/xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import xarray as xr
from xarray import register_dataset_accessor
from xarray.backends import BackendArray
from xarray.coding.times import CFDatetimeCoder
from xarray.core.indexes import Index, PandasIndex
from xarray.core.variable import IndexVariable

Expand All @@ -33,9 +34,11 @@ class ManifestBackendArray(ManifestArray, BackendArray):

def open_virtual_dataset(
filepath: str,
*,
filetype: FileType | None = None,
drop_variables: Iterable[str] | None = None,
loadable_variables: Iterable[str] | None = None,
cftime_variables: Iterable[str] | None = None,
indexes: Mapping[str, Index] | None = None,
virtual_array_class=ManifestArray,
reader_options: Optional[dict] = {
Expand Down Expand Up @@ -69,6 +72,10 @@ def open_virtual_dataset(
virtual_array_class
Virtual array class to use to represent the references to the chunks in each on-disk array.
Currently can only be ManifestArray, but once VirtualZarrArray is implemented the default should be changed to that.
cftime_variables : list[str], default is None
Interpret the value of specified vars using cftime, returning a datetime.
These will be automatically re-encoded with cftime. This list must be a subset
of ``loadable_variables``.
reader_options: dict, default {'storage_options':{'key':'', 'secret':'', 'anon':True}}
Dict passed into Kerchunk file readers. Note: Each Kerchunk file reader has distinct arguments,
so ensure reader_options match selected Kerchunk reader arguments.
Expand All @@ -95,6 +102,20 @@ def open_virtual_dataset(
if common:
raise ValueError(f"Cannot both load and drop variables {common}")

if cftime_variables is None:
cftime_variables = []
elif isinstance(cftime_variables, str):
cftime_variables = [cftime_variables]
else:
cftime_variables = list(cftime_variables)

if diff := (set(cftime_variables) - set(loadable_variables)):
missing_str = ", ".join([f"'{v}'" for v in diff])
raise ValueError(
"All ``cftime_variables`` must be included in ``loadable_variables`` "
f"({missing_str} not in ``loadable_variables``)"
)

if virtual_array_class is not ManifestArray:
raise NotImplementedError()

Expand Down Expand Up @@ -127,7 +148,9 @@ def open_virtual_dataset(
filepath=filepath, reader_options=reader_options
)

ds = xr.open_dataset(fpath, drop_variables=drop_variables)
ds = xr.open_dataset(
fpath, drop_variables=drop_variables, decode_times=False
)

if indexes is None:
# add default indexes by reading data from file
Expand All @@ -144,6 +167,10 @@ def open_virtual_dataset(
if name in loadable_variables
}

for name in cftime_variables:
var = loadable_vars[name]
loadable_vars[name] = CFDatetimeCoder().decode(var, name=name)

# if we only read the indexes we can just close the file right away as nothing is lazy
if loadable_vars == {}:
ds.close()
Expand Down Expand Up @@ -326,7 +353,7 @@ def separate_coords(
# use workaround to avoid creating IndexVariables described here https://github.com/pydata/xarray/pull/8107#discussion_r1311214263
if len(var.dims) == 1:
dim1d, *_ = var.dims
coord_vars[name] = (dim1d, var.data, var.attrs)
coord_vars[name] = (dim1d, var.data, var.attrs, var.encoding)

if isinstance(var, IndexVariable):
# unless variable actually already is a loaded IndexVariable,
Expand Down

0 comments on commit 10e7863

Please sign in to comment.