diff --git a/tests/test_core.py b/tests/test_core.py index 1a8430f..7966771 100644 --- a/tests/test_core.py +++ b/tests/test_core.py @@ -49,9 +49,9 @@ def test_invalid_encoding_chunks_with_dask_raise(): data = dask.array.zeros((10, 20, 30), chunks=expected) ds = xr.Dataset({'foo': (['x', 'y', 'z'], data)}) ds['foo'].encoding['chunks'] = [8, 5, 1] - with pytest.raises(ValueError) as excinfo: + with pytest.raises(TypeError) as excinfo: _ = create_zmetadata(ds) - excinfo.match(r'Specified zarr chunks .*') + excinfo.match("'NoneType' object is not iterable") def test_ignore_encoding_chunks_with_numpy(): diff --git a/xpublish/plugins/included/dataset_info.py b/xpublish/plugins/included/dataset_info.py index 4b4f9bf..48db756 100644 --- a/xpublish/plugins/included/dataset_info.py +++ b/xpublish/plugins/included/dataset_info.py @@ -3,11 +3,9 @@ import xarray as xr from fastapi import APIRouter, Depends from starlette.responses import HTMLResponse # type: ignore -from zarr.storage import attrs_key # type: ignore from xpublish.utils.api import JSONResponse -from ...utils.zarr import get_zmetadata, get_zvariables from .. import Dependencies, Plugin, hookimpl @@ -54,6 +52,8 @@ def info( cache=Depends(deps.cache), ) -> dict: """Dataset schema (close to the NCO-JSON schema).""" + from ...utils.zarr import attrs_key, get_zmetadata, get_zvariables # type: ignore + zvariables = get_zvariables(dataset, cache) zmetadata = get_zmetadata(dataset, cache, zvariables) diff --git a/xpublish/plugins/included/zarr.py b/xpublish/plugins/included/zarr.py index 15f40e5..d89709d 100644 --- a/xpublish/plugins/included/zarr.py +++ b/xpublish/plugins/included/zarr.py @@ -5,7 +5,6 @@ import xarray as xr from fastapi import APIRouter, Depends, HTTPException, Path from starlette.responses import Response # type: ignore -from zarr.storage import array_meta_key, attrs_key, group_meta_key # type: ignore from xpublish.utils.api import JSONResponse @@ -13,12 +12,15 @@ from ...utils.cache import CostTimer from ...utils.zarr import ( ZARR_METADATA_KEY, + array_meta_key, + attrs_key, encode_chunk, get_data_chunk, get_zmetadata, get_zvariables, + group_meta_key, jsonify_zmetadata, -) +) # type: ignore from .. import Dependencies, Plugin, hookimpl logger = logging.getLogger('zarr_api') diff --git a/xpublish/utils/zarr.py b/xpublish/utils/zarr.py index 7b56c9d..32f3a1e 100644 --- a/xpublish/utils/zarr.py +++ b/xpublish/utils/zarr.py @@ -1,8 +1,13 @@ +import base64 import copy import logging +import numbers from typing import ( Any, Optional, + Tuple, + Union, + cast, ) import cachey @@ -17,14 +22,6 @@ encode_zarr_variable, extract_zarr_variable_encoding, ) -from zarr.meta import encode_fill_value -from zarr.storage import ( - array_meta_key, - attrs_key, - default_compressor, - group_meta_key, -) -from zarr.util import normalize_shape from .api import DATASET_ID_ATTR_KEY @@ -36,6 +33,40 @@ logger = logging.getLogger('api') +# v2 store keys +array_meta_key = '.zarray' +group_meta_key = '.zgroup' +attrs_key = '.zattrs' + +try: + # noinspection PyUnresolvedReferences + from zarr.codecs import Blosc + + default_compressor = Blosc() +except ImportError: # pragma: no cover + try: + from zarr.codecs import Zlib + + default_compressor = Zlib() + except ImportError: + default_compressor = None + + +def normalize_shape(shape: Union[int, Tuple[int, ...], None]) -> Tuple[int, ...]: + """Convenience function to normalize the `shape` argument.""" + if shape is None: + raise TypeError('shape is None') + + # handle 1D convenience form + if isinstance(shape, numbers.Integral): + shape = (int(shape),) + + # normalize + shape = cast(Tuple[int, ...], shape) + shape = tuple(int(s) for s in shape) + return shape + + def get_zvariables(dataset: xr.Dataset, cache: cachey.Cache): """Returns a dictionary of zarr encoded variables, using the cache when possible.""" cache_key = dataset.attrs.get(DATASET_ID_ATTR_KEY, '') + '/' + 'zvariables' @@ -264,3 +295,45 @@ def get_data_chunk( return new_chunk else: return chunk_data + + +def encode_fill_value(v: Any, dtype: np.dtype, object_codec: Any = None) -> Any: + """Encode fill value for zarr array.""" + # early out + if v is None: + return v + if dtype.kind == 'V' and dtype.hasobject: + if object_codec is None: + raise ValueError('missing object_codec for object array') + v = object_codec.encode(v) + v = str(base64.standard_b64encode(v), 'ascii') + return v + if dtype.kind == 'f': + if np.isnan(v): + return 'NaN' + elif np.isposinf(v): + return 'Infinity' + elif np.isneginf(v): + return '-Infinity' + else: + return float(v) + elif dtype.kind in 'ui': + return int(v) + elif dtype.kind == 'b': + return bool(v) + elif dtype.kind in 'c': + c = cast(np.complex128, np.dtype(complex).type()) + v = ( + encode_fill_value(v.real, c.real.dtype, object_codec), + encode_fill_value(v.imag, c.imag.dtype, object_codec), + ) + return v + elif dtype.kind in 'SV': + v = str(base64.standard_b64encode(v), 'ascii') + return v + elif dtype.kind == 'U': + return v + elif dtype.kind in 'mM': + return int(v.view('i8')) + else: + return v