Skip to content

Commit

Permalink
Try out replacing ZArray
Browse files Browse the repository at this point in the history
  • Loading branch information
ghidalgo3 committed Jul 2, 2024
1 parent 91ebefe commit 5e12b88
Show file tree
Hide file tree
Showing 6 changed files with 55 additions and 59 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ dependencies = [
"ujson",
"packaging",
"universal-pathlib",
"zarr>=3.0.0a0"
]

[project.optional-dependencies]
Expand Down
15 changes: 8 additions & 7 deletions virtualizarr/kerchunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,11 @@
import ujson # type: ignore
import xarray as xr
from xarray.coding.times import CFDatetimeCoder
from zarr.array import Array

from virtualizarr.manifests.manifest import join
from virtualizarr.utils import _fsspec_openfile_from_filepath
from virtualizarr.zarr import ZArray, ZAttrs
from virtualizarr.zarr import ZAttrs

# Distinguishing these via type hints makes it a lot easier to mentally keep track of what the opaque kerchunk "reference dicts" actually mean
# (idea from https://kobzol.github.io/rust/python/2023/05/20/writing-python-like-its-rust.html)
Expand Down Expand Up @@ -195,8 +196,8 @@ def extract_array_refs(

def parse_array_refs(
arr_refs: KerchunkArrRefs,
) -> tuple[dict, ZArray, ZAttrs]:
zarray = ZArray.from_kerchunk_refs(arr_refs.pop(".zarray"))
) -> tuple[dict, Array, ZAttrs]:
zarray = Array.from_kerchunk_refs(arr_refs.pop(".zarray"))
zattrs = arr_refs.pop(".zattrs", {})
chunk_dict = arr_refs

Expand Down Expand Up @@ -297,14 +298,14 @@ def variable_to_kerchunk_arr_refs(var: xr.Variable, var_name: str) -> KerchunkAr
# TODO will this fail for a scalar?
arr_refs = {join(0 for _ in np_arr.shape): inlined_data}

zarray = ZArray(
chunks=np_arr.shape,
zarray = Array.create(
store=None, # type: ignore
shape=np_arr.shape,
dtype=np_arr.dtype,
order="C",
chunk_shape=np_arr.shape,
)

zarray_dict = zarray.to_kerchunk_json()
zarray_dict = ujson.dumps(zarray)
arr_refs[".zarray"] = zarray_dict

zattrs = {**var.attrs, **var.encoding}
Expand Down
2 changes: 1 addition & 1 deletion virtualizarr/manifests/array.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,9 @@
from typing import Any, Callable, Union

import numpy as np
from zarr.array import Array as ZArray

from ..kerchunk import KerchunkArrRefs
from ..zarr import ZArray
from .array_api import MANIFESTARRAY_HANDLED_ARRAY_FUNCTIONS
from .manifest import ChunkManifest

Expand Down
26 changes: 18 additions & 8 deletions virtualizarr/manifests/array_api.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from dataclasses import replace
from typing import TYPE_CHECKING, Callable, Iterable

import numpy as np
from zarr.metadata import ArrayV3Metadata

from virtualizarr.zarr import Codec, ceildiv

Expand Down Expand Up @@ -34,7 +36,15 @@ def _check_combineable_zarr_arrays(arrays: Iterable["ManifestArray"]) -> None:

# Can't combine different codecs in one manifest
# see https://github.com/zarr-developers/zarr-specs/issues/288
_check_same_codecs([arr.zarray.codec for arr in arrays])
# If we want to support Zarr's v2 and v3 metadata, we have to branch here
# based on the type of arr.zarray.metadata
_check_same_codecs(
[
arr.zarray.metadata.codecs # type: ignore
for arr in arrays
if isinstance(arr.zarray.metadata, ArrayV3Metadata)
]
)

# Would require variable-length chunks ZEP
_check_same_chunk_shapes([arr.chunks for arr in arrays])
Expand Down Expand Up @@ -144,9 +154,7 @@ def concatenate(
)

# chunk shape has not changed, there are just now more chunks along the concatenation axis
new_zarray = first_arr.zarray.replace(
shape=tuple(new_shape),
)
new_zarray = replace(first_arr.zarray, shape=tuple(new_shape))

return ManifestArray(chunkmanifest=concatenated_manifest, zarray=new_zarray)

Expand Down Expand Up @@ -240,8 +248,9 @@ def stack(
new_chunks = list(old_chunks)
new_chunks.insert(axis, 1)

new_zarray = first_arr.zarray.replace(
chunks=tuple(new_chunks),
new_zarray = replace(
first_arr.zarray,
chunk_shape=tuple(new_chunks),
shape=tuple(new_shape),
)

Expand Down Expand Up @@ -314,8 +323,9 @@ def broadcast_to(x: "ManifestArray", /, shape: tuple[int, ...]) -> "ManifestArra
lengths=broadcasted_lengths,
)

new_zarray = x.zarray.replace(
chunks=new_chunk_shape,
new_zarray = replace(
x.zarray,
chunk_shape=new_chunk_shape,
shape=new_shape,
)

Expand Down
33 changes: 21 additions & 12 deletions virtualizarr/tests/test_xarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,8 @@
import xarray as xr
import xarray.testing as xrt
from xarray.core.indexes import Index
from zarr.array import Array
from zarr.codecs import BytesCodec, ZstdCodec

from virtualizarr import open_virtual_dataset
from virtualizarr.manifests import ChunkManifest, ManifestArray
Expand All @@ -17,15 +19,16 @@ def test_wrapping():
chunks = (5, 10)
shape = (5, 20)
dtype = np.dtype("int32")
zarray = ZArray(
chunks=chunks,
compressor="zlib",
# This passes for V3
zarray = Array.create(
store=None,
shape=shape,
dtype=dtype,
chunk_shape=chunks,
codecs=[BytesCodec(), ZstdCodec()],
fill_value=0.0,
filters=None,
order="C",
shape=shape,
zarr_format=2,
zarr_format=3,
)

chunks_dict = {
Expand All @@ -47,9 +50,11 @@ class TestEquals:
def test_equals(self):
chunks = (5, 10)
shape = (5, 20)
zarray = ZArray(
chunks=chunks,
compressor="zlib",
# This passes for v2
zarray = Array.create(
store=None,
chunk_shape=chunks,
compressor=dict(id="zlib", level=1),
dtype=np.dtype("int32"),
fill_value=0.0,
filters=None,
Expand Down Expand Up @@ -84,9 +89,13 @@ def test_equals(self):
class TestConcat:
def test_concat_along_existing_dim(self):
# both manifest arrays in this example have the same zarray properties
zarray = ZArray(
chunks=(1, 10),
compressor="zlib",
# Does this need to work for both Zarr v2 and v3?
# Because eventually the zarray.metadata object is different and the
# concatenation check has to branch based on v2 and v3
zarray = Array.create(
store=None,
chunk_shape=(1, 10),
compressor=dict(id="zlib", level=1),
dtype=np.dtype("int32"),
fill_value=0.0,
filters=None,
Expand Down
37 changes: 6 additions & 31 deletions virtualizarr/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,13 @@
Any,
Literal,
NewType,
Optional,
)

import numpy as np
import ujson # type: ignore
import xarray as xr
from pydantic import BaseModel, ConfigDict, field_validator
from zarr.array import Array

from virtualizarr.vendor.zarr.utils import json_dumps

Expand Down Expand Up @@ -106,33 +106,9 @@ def dict(self) -> dict[str, Any]:

return zarray_dict

def to_kerchunk_json(self) -> str:
return ujson.dumps(self.dict())

def replace(
self,
chunks: Optional[tuple[int, ...]] = None,
compressor: Optional[str] = None,
dtype: Optional[np.dtype] = None,
fill_value: Optional[float] = None, # float or int?
filters: Optional[list[dict]] = None, # type: ignore[valid-type]
order: Optional[Literal["C"] | Literal["F"]] = None,
shape: Optional[tuple[int, ...]] = None,
zarr_format: Optional[Literal[2] | Literal[3]] = None,
) -> "ZArray":
"""
Convenience method to create a new ZArray from an existing one by altering only certain attributes.
"""
return ZArray(
chunks=chunks if chunks is not None else self.chunks,
compressor=compressor if compressor is not None else self.compressor,
dtype=dtype if dtype is not None else self.dtype,
fill_value=fill_value if fill_value is not None else self.fill_value,
filters=filters if filters is not None else self.filters,
shape=shape if shape is not None else self.shape,
order=order if order is not None else self.order,
zarr_format=zarr_format if zarr_format is not None else self.zarr_format,
)

def to_kerchunk_json(zarray: Array) -> str:
return ujson.dumps(zarray)


def encode_dtype(dtype: np.dtype) -> str:
Expand Down Expand Up @@ -216,11 +192,10 @@ def to_zarr_json(var: xr.Variable, array_dir: Path) -> None:
metadata_file.write(json_dumps(metadata))


def zarr_v3_array_metadata(zarray: ZArray, dim_names: list[str], attrs: dict) -> dict:
def zarr_v3_array_metadata(zarray: Array, dim_names: list[str], attrs: dict) -> dict:
"""Construct a v3-compliant metadata dict from v2 zarray + information stored on the xarray variable."""
# TODO it would be nice if we could use the zarr-python metadata.ArrayMetadata classes to do this conversion for us

metadata = zarray.dict()
metadata = zarray.metadata.to_dict()

# adjust to match v3 spec
metadata["zarr_format"] = 3
Expand Down

0 comments on commit 5e12b88

Please sign in to comment.