Skip to content

Commit

Permalink
Implement pydantic models as dataclasses (zarr-developers#210)
Browse files Browse the repository at this point in the history
* Implement pydantic models as dataclasses

This removes our pydantic dependency by reimplementing them with
dataclasses. There are a few breaking changes:

1. The classes won't automatically cast the argumetns to the declared
   type. IMO, that's the preferable behavior. Some backwards
   compatability shims have been added for np.dtype, but perhpas we
   want to remove that too.
2. The models won't have any of the methods they previously inherited
   from pydantic.BaseModel. This is probably good for user-facing
   objects, we now have full control over the public API.
3. We had to reorder some of the fields on ZArray, since dataclasses
   is stricter about positional arguments. I've aligned the order
   with `zarr.create`.
  • Loading branch information
Tom Augspurger committed Aug 8, 2024
1 parent a28b210 commit fdab54c
Show file tree
Hide file tree
Showing 8 changed files with 89 additions and 89 deletions.
1 change: 0 additions & 1 deletion ci/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ dependencies:
- netcdf4
- xarray>=2024.6.0
- kerchunk>=0.2.5
- pydantic
- numpy>=2.0.0
- ujson
- packaging
Expand Down
4 changes: 4 additions & 0 deletions docs/releases.rst
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ Breaking changes

- Serialize valid ZarrV3 metadata and require full compressor numcodec config (for :pull:`193`)
By `Gustavo Hidalgo <https://github.com/ghidalgo3>`_.
- VirtualiZarr's `ZArray`, `ChunkEntry`, and `Codec` no longer subclass
`pydantic.BaseModel` (:pull:`210`)
- `ZArray`'s `__init__` signature has changed to match `zarr.Array`'s (:pull:`xxx`)


Deprecations
~~~~~~~~~~~~
Expand Down
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ dependencies = [
"xarray>=2024.06.0",
"kerchunk>=0.2.5",
"h5netcdf",
"pydantic",
"numpy>=2.0.0",
"ujson",
"packaging",
Expand Down
18 changes: 9 additions & 9 deletions virtualizarr/manifests/manifest.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
import dataclasses
import json
import re
from collections.abc import Iterable, Iterator
from typing import Any, Callable, Dict, NewType, Tuple, TypedDict, cast

import numpy as np
from pydantic import BaseModel, ConfigDict
from upath import UPath

from virtualizarr.types import ChunkKey
Expand All @@ -25,22 +25,18 @@ class ChunkDictEntry(TypedDict):
ChunkDict = NewType("ChunkDict", dict[ChunkKey, ChunkDictEntry])


class ChunkEntry(BaseModel):
@dataclasses.dataclass(frozen=True)
class ChunkEntry:
"""
Information for a single chunk in the manifest.
Stored in the form `{"path": "s3://bucket/foo.nc", "offset": 100, "length": 100}`.
"""

model_config = ConfigDict(frozen=True)

path: str # TODO stricter typing/validation of possible local / remote paths?
offset: int
length: int

def __repr__(self) -> str:
return f"ChunkEntry(path='{self.path}', offset={self.offset}, length={self.length})"

@classmethod
def from_kerchunk(
cls, path_and_byte_range_info: tuple[str] | tuple[str, int, int]
Expand All @@ -57,8 +53,12 @@ def to_kerchunk(self) -> tuple[str, int, int]:
"""Write out in the format that kerchunk uses for chunk entries."""
return (self.path, self.offset, self.length)

def dict(self) -> ChunkDictEntry: # type: ignore[override]
return ChunkDictEntry(path=self.path, offset=self.offset, length=self.length)
def dict(self) -> ChunkDictEntry:
return ChunkDictEntry(
path=self.path,
offset=self.offset,
length=self.length,
)


class ChunkManifest:
Expand Down
4 changes: 2 additions & 2 deletions virtualizarr/tests/test_kerchunk.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ def test_accessor_to_kerchunk_dict(self):
"refs": {
".zgroup": '{"zarr_format":2}',
".zattrs": "{}",
"a/.zarray": '{"chunks":[2,3],"compressor":null,"dtype":"<i8","fill_value":null,"filters":null,"order":"C","shape":[2,3],"zarr_format":2}',
"a/.zarray": '{"shape":[2,3],"chunks":[2,3],"dtype":"<i8","fill_value":null,"order":"C","compressor":null,"filters":null,"zarr_format":2}',
"a/.zattrs": '{"_ARRAY_DIMENSIONS":["x","y"]}',
"a/0.0": ["test.nc", 6144, 48],
},
Expand Down Expand Up @@ -133,7 +133,7 @@ def test_accessor_to_kerchunk_json(self, tmp_path):
"refs": {
".zgroup": '{"zarr_format":2}',
".zattrs": "{}",
"a/.zarray": '{"chunks":[2,3],"compressor":null,"dtype":"<i8","fill_value":null,"filters":null,"order":"C","shape":[2,3],"zarr_format":2}',
"a/.zarray": '{"shape":[2,3],"chunks":[2,3],"dtype":"<i8","fill_value":null,"order":"C","compressor":null,"filters":null,"zarr_format":2}',
"a/.zattrs": '{"_ARRAY_DIMENSIONS":["x","y"]}',
"a/0.0": ["test.nc", 6144, 48],
},
Expand Down
10 changes: 0 additions & 10 deletions virtualizarr/tests/test_manifests/test_manifest.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,16 +27,6 @@ def test_invalid_chunk_entries(self):
with pytest.raises(ValueError, match="must be of the form"):
ChunkManifest(entries=chunks)

chunks = {
"0.0.0": {
"path": "s3://bucket/foo.nc",
"offset": "some nonsense",
"length": 100,
},
}
with pytest.raises(ValueError, match="must be of the form"):
ChunkManifest(entries=chunks)

def test_invalid_chunk_keys(self):
chunks = {
"0.0.": {"path": "s3://bucket/foo.nc", "offset": 100, "length": 100},
Expand Down
28 changes: 27 additions & 1 deletion virtualizarr/tests/test_zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from virtualizarr import ManifestArray, open_virtual_dataset
from virtualizarr.kerchunk import FileType
from virtualizarr.manifests.manifest import ChunkManifest
from virtualizarr.zarr import dataset_to_zarr, metadata_from_zarr_json
from virtualizarr.zarr import ZArray, dataset_to_zarr, metadata_from_zarr_json


@pytest.fixture
Expand Down Expand Up @@ -79,3 +79,29 @@ def test_zarr_v3_metadata_conformance(tmpdir, vds_with_manifest_arrays: xr.Datas
and len(metadata["codecs"]) > 1
and all(isconfigurable(codec) for codec in metadata["codecs"])
)


def test_replace_partial():
arr = ZArray(shape=(2, 3), chunks=(1, 1), dtype=np.dtype("<i8"))
result = arr.replace(chunks=(2, 3))
expected = ZArray(shape=(2, 3), chunks=(2, 3), dtype=np.dtype("<i8"))
assert result == expected
assert result.shape == (2, 3)
assert result.chunks == (2, 3)


def test_replace_total():
arr = ZArray(shape=(2, 3), chunks=(1, 1), dtype=np.dtype("<i8"))
kwargs = dict(
shape=(4, 4),
chunks=(2, 2),
dtype=np.dtype("<f8"),
fill_value=-1.0,
order="F",
compressor={"id": "zlib", "level": 1},
filters=[{"id": "blosc", "clevel": 5}],
zarr_format=3,
)
result = arr.replace(**kwargs)
expected = ZArray(**kwargs)
assert result == expected
112 changes: 47 additions & 65 deletions virtualizarr/zarr.py
Original file line number Diff line number Diff line change
@@ -1,26 +1,12 @@
import dataclasses
import json
from pathlib import Path
from typing import (
TYPE_CHECKING,
Any,
Literal,
NewType,
Optional,
cast,
)
from typing import TYPE_CHECKING, Any, Literal, NewType, cast

import numcodecs
import numpy as np
import ujson # type: ignore
import xarray as xr
from pydantic import (
BaseModel,
ConfigDict,
Field,
field_validator,
model_validator,
)
from typing_extensions import Self

from virtualizarr.vendor.zarr.utils import json_dumps

Expand Down Expand Up @@ -50,38 +36,26 @@
"""


class Codec(BaseModel):
@dataclasses.dataclass
class Codec:
compressor: dict | None = None
filters: list[dict] | None = None

def __repr__(self) -> str:
return f"Codec(compressor={self.compressor}, filters={self.filters})"


class ZArray(BaseModel):
@dataclasses.dataclass
class ZArray:
"""Just the .zarray information"""

# TODO will this work for V3?

model_config = ConfigDict(
arbitrary_types_allowed=True, # only here so pydantic doesn't complain about the numpy dtype field
)

shape: tuple[int, ...]
chunks: tuple[int, ...]
compressor: dict | None = None
dtype: np.dtype
fill_value: FillValueT = Field(None, validate_default=True)
fill_value: FillValueT = dataclasses.field(default=None)
order: Literal["C", "F"] = "C"
compressor: dict | None = None
filters: list[dict] | None = None
order: Literal["C", "F"]
shape: tuple[int, ...]
zarr_format: ZARR_FORMAT = 2

@field_validator("dtype")
@classmethod
def validate_dtype(cls, dtype) -> np.dtype:
# Your custom validation logic here
# Convert numpy.dtype to a format suitable for Pydantic
return np.dtype(dtype)
zarr_format: Literal[2, 3] = 2

def __post_init__(self) -> None:
if len(self.shape) != len(self.chunks):
Expand All @@ -90,20 +64,18 @@ def __post_init__(self) -> None:
f"Array shape {self.shape} has ndim={self.shape} but chunk shape {self.chunks} has ndim={len(self.chunks)}"
)

@model_validator(mode="after")
def _check_fill_value(self) -> Self:
if isinstance(self.dtype, str):
# Convert dtype string to numpy.dtype
self.dtype = np.dtype(self.dtype)

if self.fill_value is None:
self.fill_value = ZARR_DEFAULT_FILL_VALUE.get(self.dtype.kind, 0.0)
return self

@property
def codec(self) -> Codec:
"""For comparison against other arrays."""
return Codec(compressor=self.compressor, filters=self.filters)

def __repr__(self) -> str:
return f"ZArray(shape={self.shape}, chunks={self.chunks}, dtype={self.dtype}, compressor={self.compressor}, filters={self.filters}, fill_value={self.fill_value})"

@classmethod
def from_kerchunk_refs(cls, decoded_arr_refs_zarray) -> "ZArray":
# coerce type of fill_value as kerchunk can be inconsistent with this
Expand All @@ -127,8 +99,8 @@ def from_kerchunk_refs(cls, decoded_arr_refs_zarray) -> "ZArray":
zarr_format=cast(ZARR_FORMAT, zarr_format),
)

def dict(self) -> dict[str, Any]: # type: ignore
zarray_dict = dict(self)
def dict(self) -> dict[str, Any]:
zarray_dict = dataclasses.asdict(self)
zarray_dict["dtype"] = encode_dtype(zarray_dict["dtype"])
return zarray_dict

Expand All @@ -138,30 +110,40 @@ def to_kerchunk_json(self) -> str:
zarray_dict["fill_value"] = None
return ujson.dumps(zarray_dict)

# ZArray.dict seems to shadow "dict", so we need the type ignore in
# the signature below.
def replace(
self,
chunks: Optional[tuple[int, ...]] = None,
compressor: Optional[dict] = None, # type: ignore[valid-type]
dtype: Optional[np.dtype] = None,
fill_value: Optional[float] = None, # float or int?
filters: Optional[list[dict]] = None, # type: ignore[valid-type]
order: Optional[Literal["C"] | Literal["F"]] = None,
shape: Optional[tuple[int, ...]] = None,
zarr_format: Optional[Literal[2] | Literal[3]] = None,
shape: tuple[int, ...] | None = None,
chunks: tuple[int, ...] | None = None,
dtype: np.dtype | str | None = None,
fill_value: FillValueT = None,
order: Literal["C", "F"] | None = None,
compressor: "dict | None" = None, # type: ignore[valid-type]
filters: list[dict] | None = None, # type: ignore[valid-type]
zarr_format: Literal[2, 3] | None = None,
) -> "ZArray":
"""
Convenience method to create a new ZArray from an existing one by altering only certain attributes.
"""
return ZArray(
chunks=chunks if chunks is not None else self.chunks,
compressor=compressor if compressor is not None else self.compressor,
dtype=dtype if dtype is not None else self.dtype,
fill_value=fill_value if fill_value is not None else self.fill_value,
filters=filters if filters is not None else self.filters,
shape=shape if shape is not None else self.shape,
order=order if order is not None else self.order,
zarr_format=zarr_format if zarr_format is not None else self.zarr_format,
)
replacements: dict[str, Any] = {}
if shape is not None:
replacements["shape"] = shape
if chunks is not None:
replacements["chunks"] = chunks
if dtype is not None:
replacements["dtype"] = dtype
if fill_value is not None:
replacements["fill_value"] = fill_value
if order is not None:
replacements["order"] = order
if compressor is not None:
replacements["compressor"] = compressor
if filters is not None:
replacements["filters"] = filters
if zarr_format is not None:
replacements["zarr_format"] = zarr_format
return dataclasses.replace(self, **replacements)

def _v3_codec_pipeline(self) -> list:
"""
Expand Down Expand Up @@ -361,8 +343,8 @@ def metadata_from_zarr_json(filepath: Path) -> tuple[ZArray, list[str], dict]:
attrs = metadata.pop("attributes")
dim_names = metadata.pop("dimension_names")

chunk_shape = metadata["chunk_grid"]["configuration"]["chunk_shape"]
shape = metadata["shape"]
chunk_shape = tuple(metadata["chunk_grid"]["configuration"]["chunk_shape"])
shape = tuple(metadata["shape"])
zarr_format = metadata["zarr_format"]

if metadata["fill_value"] is None:
Expand Down

0 comments on commit fdab54c

Please sign in to comment.