Skip to content

Commit

Permalink
try using zarr array metadata
Browse files Browse the repository at this point in the history
  • Loading branch information
ghidalgo3 committed Jul 9, 2024
1 parent 92a7e81 commit 42b3f3a
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 192 deletions.
13 changes: 7 additions & 6 deletions virtualizarr/kerchunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import ujson # type: ignore
import xarray as xr
from xarray.coding.times import CFDatetimeCoder
from zarr.array import ArrayMetadata, ArrayV2Metadata

from virtualizarr.manifests.manifest import join
from virtualizarr.utils import _fsspec_openfile_from_filepath
from virtualizarr.zarr import ZArray, ZAttrs
from virtualizarr.zarr import ZAttrs, from_kerchunk_refs, to_kerchunk_json

# Distinguishing these via type hints makes it a lot easier to mentally keep track of what the opaque kerchunk "reference dicts" actually mean
# (idea from https://kobzol.github.io/rust/python/2023/05/20/writing-python-like-its-rust.html)
Expand Down Expand Up @@ -195,8 +196,8 @@ def extract_array_refs(

def parse_array_refs(
arr_refs: KerchunkArrRefs,
) -> tuple[dict, ZArray, ZAttrs]:
zarray = ZArray.from_kerchunk_refs(arr_refs.pop(".zarray"))
) -> tuple[dict, ArrayMetadata, ZAttrs]:
zarray = from_kerchunk_refs(arr_refs.pop(".zarray"))
zattrs = arr_refs.pop(".zattrs", {})
chunk_dict = arr_refs

Expand Down Expand Up @@ -296,15 +297,15 @@ def variable_to_kerchunk_arr_refs(var: xr.Variable, var_name: str) -> KerchunkAr
# TODO can this be generalized to save individual chunks of a dask array?
# TODO will this fail for a scalar?
arr_refs = {join(0 for _ in np_arr.shape): inlined_data}

zarray = ZArray(
zarray = ArrayV2Metadata(
chunks=np_arr.shape,
shape=np_arr.shape,
dtype=np_arr.dtype,
order="C",
fill_value=np.nan,
)

zarray_dict = zarray.to_kerchunk_json()
zarray_dict = to_kerchunk_json(zarray)
arr_refs[".zarray"] = zarray_dict

zattrs = {**var.attrs, **var.encoding}
Expand Down
38 changes: 27 additions & 11 deletions virtualizarr/manifests/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Any, Callable, Union

import numpy as np
from zarr.array import ArrayMetadata, ArrayV2Metadata, ArrayV3Metadata, RegularChunkGrid

from ..kerchunk import KerchunkArrRefs
from ..zarr import ZArray
from .array_api import MANIFESTARRAY_HANDLED_ARRAY_FUNCTIONS
from .manifest import ChunkManifest

Expand All @@ -22,27 +22,32 @@ class ManifestArray:
"""

_manifest: ChunkManifest
_zarray: ZArray
_zarray: ArrayMetadata

def __init__(
self,
zarray: ZArray | dict,
zarray: ArrayMetadata | dict,
chunkmanifest: dict | ChunkManifest,
) -> None:
"""
Create a ManifestArray directly from the .zarray information of a zarr array and the manifest of chunks.
Parameters
----------
zarray : dict or ZArray
zarray : dict or zarr.array.ArrayMetadata
chunkmanifest : dict or ChunkManifest
"""

if isinstance(zarray, ZArray):
_zarray = zarray
else:
# try unpacking the dict
_zarray = ZArray(**zarray)
match zarray:
case ArrayMetadata():
_zarray = zarray
case dict():
zarray = zarray.copy()
zarr_format = zarray.pop("zarr_format", None)
if zarr_format == 3:
_zarray = ArrayV3Metadata(**zarray)
else:
_zarray = ArrayV2Metadata(**zarray)

if isinstance(chunkmanifest, ChunkManifest):
_chunkmanifest = chunkmanifest
Expand Down Expand Up @@ -79,12 +84,20 @@ def manifest(self) -> ChunkManifest:
return self._manifest

@property
def zarray(self) -> ArrayMetadata:
    """
    Array metadata held by this instance.

    Returns the object stored in ``self._zarray`` unchanged.
    """
    # Only one signature may exist; the annotation matches the
    # ``_zarray: ArrayMetadata`` attribute declared on the class.
    return self._zarray

@property
def chunks(self) -> tuple[int, ...]:
    """
    Individual chunk size by number of elements.

    Returns
    -------
    tuple of int
        The ``chunk_shape`` of the metadata's chunk grid.

    Raises
    ------
    NotImplementedError
        If the metadata's chunk grid is not a ``RegularChunkGrid``.
    """
    chunk_grid = self._zarray.chunk_grid
    if not isinstance(chunk_grid, RegularChunkGrid):
        raise NotImplementedError(
            "Only RegularChunkGrid is currently supported for chunk size"
        )
    # Coerce to a tuple so the result always matches the declared return type.
    return tuple(chunk_grid.chunk_shape)

@property
def dtype(self) -> np.dtype:
Expand All @@ -93,6 +106,9 @@ def dtype(self) -> np.dtype:

@property
def shape(self) -> tuple[int, ...]:
    """
    Array shape by number of elements along each dimension.

    Each length read from ``self.zarray.shape`` is converted with ``int``.
    """
    # Iterate the shape directly; an intermediate list copy is not needed.
    return tuple(int(length) for length in self.zarray.shape)

@property
Expand Down
40 changes: 27 additions & 13 deletions virtualizarr/manifests/array_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
from typing import TYPE_CHECKING, Callable, Iterable
from dataclasses import replace
from typing import TYPE_CHECKING, Callable, Iterable, Union

import numpy as np
from zarr.abc.codec import Codec as ZCodec
from zarr.array import ArrayMetadata, ArrayV2Metadata, ArrayV3Metadata, RegularChunkGrid

from virtualizarr.zarr import Codec, ceildiv

Expand Down Expand Up @@ -34,7 +37,7 @@ def _check_combineable_zarr_arrays(arrays: Iterable["ManifestArray"]) -> None:

# Can't combine different codecs in one manifest
# see https://github.com/zarr-developers/zarr-specs/issues/288
_check_same_codecs([arr.zarray.codec for arr in arrays])
_check_same_codecs([arr.zarray for arr in arrays])

# Would require variable-length chunks ZEP
_check_same_chunk_shapes([arr.chunks for arr in arrays])
Expand All @@ -51,7 +54,20 @@ def _check_same_dtypes(dtypes: list[np.dtype]) -> None:
)


def _check_same_codecs(codecs: list[Codec]) -> None:
def _check_same_codecs(zarrays: list[ArrayMetadata]) -> None:
if len({zarry.zarr_format for zarry in zarrays}) > 1:
raise ValueError("Cannot concatenate arrays with different zarr formats.")

def to_codec(zarray: ArrayMetadata) -> Union[Codec | tuple[ZCodec, ...]]:
match zarray:
case ArrayV2Metadata(compressor=compressor, filters=filters):
return Codec(compressor=compressor, filters=filters)
case ArrayV3Metadata(codecs=codecs):
return codecs
case _:
raise ValueError("Unknown ArrayMetadata type")

codecs = [to_codec(zarray) for zarray in zarrays]
first_codec, *other_codecs = codecs
for codec in other_codecs:
if codec != first_codec:
Expand Down Expand Up @@ -144,10 +160,7 @@ def concatenate(
)

# chunk shape has not changed, there are just now more chunks along the concatenation axis
new_zarray = first_arr.zarray.replace(
shape=tuple(new_shape),
)

new_zarray = replace(first_arr.zarray, shape=tuple(new_shape))
return ManifestArray(chunkmanifest=concatenated_manifest, zarray=new_zarray)


Expand Down Expand Up @@ -239,10 +252,10 @@ def stack(
old_chunks = first_arr.chunks
new_chunks = list(old_chunks)
new_chunks.insert(axis, 1)

new_zarray = first_arr.zarray.replace(
chunks=tuple(new_chunks),
new_zarray = replace(
first_arr.zarray,
shape=tuple(new_shape),
chunk_grid=RegularChunkGrid(chunk_shape=tuple(new_chunks)),
)

return ManifestArray(chunkmanifest=stacked_manifest, zarray=new_zarray)
Expand Down Expand Up @@ -314,9 +327,10 @@ def broadcast_to(x: "ManifestArray", /, shape: tuple[int, ...]) -> "ManifestArra
lengths=broadcasted_lengths,
)

new_zarray = x.zarray.replace(
chunks=new_chunk_shape,
shape=new_shape,
new_zarray = replace(
x.zarray,
shape=tuple(new_shape),
chunk_grid=RegularChunkGrid(chunk_shape=tuple(new_chunk_shape)),
)

return ManifestArray(chunkmanifest=broadcasted_manifest, zarray=new_zarray)
Expand Down
8 changes: 4 additions & 4 deletions virtualizarr/tests/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,11 @@
import numpy as np
import pytest
from packaging.version import Version
from zarr.array import ArrayV2Metadata

from virtualizarr.manifests import ChunkManifest, ManifestArray
from virtualizarr.manifests.manifest import join
from virtualizarr.zarr import ZArray, ceildiv
from virtualizarr.zarr import ceildiv

network = pytest.mark.network

Expand Down Expand Up @@ -46,15 +47,14 @@ def create_manifestarray(
The manifest is populated with a (somewhat) unique path, offset, and length for each key.
"""

zarray = ZArray(
zarray = ArrayV2Metadata(
chunks=chunks,
compressor="zlib",
compressor={"id": "zlib"},
dtype=np.dtype("float32"),
fill_value=0.0, # TODO change this to NaN?
filters=None,
order="C",
shape=shape,
zarr_format=2,
)

chunk_grid_shape = tuple(
Expand Down
11 changes: 6 additions & 5 deletions virtualizarr/tests/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@ def test_zarr_v3_roundtrip(tmpdir):
),
zarray=dict(
shape=(2, 3),
dtype=np.dtype("<i8"),
chunks=(2, 3),
compressor=None,
filters=None,
data_type=np.dtype("<i8"),
chunk_grid={"name": "regular", "configuration": {"chunk_shape": [2, 3]}},
chunk_key_encoding={"name": "default", "configuration": {"separator": "."}},
codecs=(),
attributes={},
dimension_names=None,
fill_value=np.nan,
order="C",
zarr_format=3,
),
)
Expand Down
Loading

0 comments on commit 42b3f3a

Please sign in to comment.