Refactor out utility functions from to_zarr (#9695)
* Refactor out utility functions from to_zarr

* Use `emit_user_level_warning` instead of explicit `stacklevel`

* tiny reordering

* Some more

* comment
dcherian authored Oct 31, 2024
1 parent cdec18f commit 7467b1e
Showing 2 changed files with 97 additions and 87 deletions.
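Before the diffs, a quick illustration of the behavior the extracted helpers preserve: the mode defaults that `to_zarr` now resolves via `_choose_default_mode` are visible through the public `Dataset.to_zarr` API. A hedged sketch (requires xarray and zarr; `demo.zarr` is an illustrative, fresh path, not part of this commit):

import numpy as np
import xarray as xr

ds = xr.Dataset({"t": ("x", np.arange(4))})

ds.to_zarr("demo.zarr")                             # mode=None resolves to "w-" (fails if store exists)
ds.to_zarr("demo.zarr", append_dim="x")             # mode=None resolves to "a" (append)
ds.to_zarr("demo.zarr", region={"x": slice(0, 4)})  # mode=None resolves to "r+" (region write)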
83 changes: 13 additions & 70 deletions xarray/backends/api.py
@@ -33,7 +33,6 @@
     _normalize_path,
 )
 from xarray.backends.locks import _get_scheduler
-from xarray.backends.zarr import _zarr_v3
 from xarray.core import indexing
 from xarray.core.combine import (
     _infer_concat_order_from_positions,
@@ -2131,73 +2130,33 @@ def to_zarr(
     See `Dataset.to_zarr` for full API docs.
     """
+    from xarray.backends.zarr import _choose_default_mode, _get_mappers
+
+    # validate Dataset keys, DataArray names
+    _validate_dataset_names(dataset)

     # Load empty arrays to avoid bug saving zero length dimensions (Issue #5741)
     # TODO: delete when min dask>=2023.12.1
     # https://github.com/dask/dask/pull/10506
     for v in dataset.variables.values():
         if v.size == 0:
             v.load()

-    # expand str and path-like arguments
-    store = _normalize_path(store)
-    chunk_store = _normalize_path(chunk_store)
-
-    kwargs = {}
-    if storage_options is None:
-        mapper = store
-        chunk_mapper = chunk_store
-    else:
-        if not isinstance(store, str):
-            raise ValueError(
-                f"store must be a string to use storage_options. Got {type(store)}"
-            )
-
-        if _zarr_v3():
-            kwargs["storage_options"] = storage_options
-            mapper = store
-            chunk_mapper = chunk_store
-        else:
-            from fsspec import get_mapper
-
-            mapper = get_mapper(store, **storage_options)
-            if chunk_store is not None:
-                chunk_mapper = get_mapper(chunk_store, **storage_options)
-            else:
-                chunk_mapper = chunk_store
-
     if encoding is None:
         encoding = {}

-    if mode is None:
-        if append_dim is not None:
-            mode = "a"
-        elif region is not None:
-            mode = "r+"
-        else:
-            mode = "w-"
-
-    if mode not in ["a", "a-"] and append_dim is not None:
-        raise ValueError("cannot set append_dim unless mode='a' or mode=None")
-
-    if mode not in ["a", "a-", "r+"] and region is not None:
-        raise ValueError(
-            "cannot set region unless mode='a', mode='a-', mode='r+' or mode=None"
-        )
-
-    if mode not in ["w", "w-", "a", "a-", "r+"]:
-        raise ValueError(
-            "The only supported options for mode are 'w', "
-            f"'w-', 'a', 'a-', and 'r+', but mode={mode!r}"
-        )
-
-    # validate Dataset keys, DataArray names
-    _validate_dataset_names(dataset)
+    kwargs, mapper, chunk_mapper = _get_mappers(
+        storage_options=storage_options, store=store, chunk_store=chunk_store
+    )
+    mode = _choose_default_mode(mode=mode, append_dim=append_dim, region=region)

     if mode == "r+":
         already_consolidated = consolidated
         consolidate_on_close = False
     else:
         already_consolidated = False
         consolidate_on_close = consolidated or consolidated is None

     zstore = backends.ZarrStore.open_group(
         store=mapper,
         mode=mode,
@@ -2209,30 +2168,14 @@ def to_zarr(
         append_dim=append_dim,
         write_region=region,
         safe_chunks=safe_chunks,
-        stacklevel=4,  # for Dataset.to_zarr()
         zarr_version=zarr_version,
         zarr_format=zarr_format,
         write_empty=write_empty_chunks,
         **kwargs,
     )

-    if region is not None:
-        zstore._validate_and_autodetect_region(dataset)
-        # can't modify indexes with region writes
-        dataset = dataset.drop_vars(dataset.indexes)
-        if append_dim is not None and append_dim in region:
-            raise ValueError(
-                f"cannot list the same dimension in both ``append_dim`` and "
-                f"``region`` with to_zarr(), got {append_dim} in both"
-            )
-
-    if encoding and mode in ["a", "a-", "r+"]:
-        existing_var_names = set(zstore.zarr_group.array_keys())
-        for var_name in existing_var_names:
-            if var_name in encoding:
-                raise ValueError(
-                    f"variable {var_name!r} already exists, but encoding was provided"
-                )
+    dataset = zstore._validate_and_autodetect_region(dataset)
+    zstore._validate_encoding(encoding)

     writer = ArrayWriter()
     # TODO: figure out how to properly handle unlimited_dims
101 changes: 84 additions & 17 deletions xarray/backends/zarr.py
@@ -4,8 +4,7 @@
 import json
 import os
 import struct
-import warnings
-from collections.abc import Iterable
+from collections.abc import Hashable, Iterable, Mapping
 from typing import TYPE_CHECKING, Any, Literal

 import numpy as np
@@ -46,6 +45,66 @@
     from xarray.core.datatree import DataTree


+def _get_mappers(*, storage_options, store, chunk_store):
+    # expand str and path-like arguments
+    store = _normalize_path(store)
+    chunk_store = _normalize_path(chunk_store)
+
+    kwargs = {}
+    if storage_options is None:
+        mapper = store
+        chunk_mapper = chunk_store
+    else:
+        if not isinstance(store, str):
+            raise ValueError(
+                f"store must be a string to use storage_options. Got {type(store)}"
+            )
+
+        if _zarr_v3():
+            kwargs["storage_options"] = storage_options
+            mapper = store
+            chunk_mapper = chunk_store
+        else:
+            from fsspec import get_mapper
+
+            mapper = get_mapper(store, **storage_options)
+            if chunk_store is not None:
+                chunk_mapper = get_mapper(chunk_store, **storage_options)
+            else:
+                chunk_mapper = chunk_store
+    return kwargs, mapper, chunk_mapper
+
+
+def _choose_default_mode(
+    *,
+    mode: ZarrWriteModes | None,
+    append_dim: Hashable | None,
+    region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None,
+) -> ZarrWriteModes:
+    if mode is None:
+        if append_dim is not None:
+            mode = "a"
+        elif region is not None:
+            mode = "r+"
+        else:
+            mode = "w-"
+
+    if mode not in ["a", "a-"] and append_dim is not None:
+        raise ValueError("cannot set append_dim unless mode='a' or mode=None")
+
+    if mode not in ["a", "a-", "r+"] and region is not None:
+        raise ValueError(
+            "cannot set region unless mode='a', mode='a-', mode='r+' or mode=None"
+        )
+
+    if mode not in ["w", "w-", "a", "a-", "r+"]:
+        raise ValueError(
+            "The only supported options for mode are 'w', "
+            f"'w-', 'a', 'a-', and 'r+', but mode={mode!r}"
+        )
+    return mode
+
+
 def _zarr_v3() -> bool:
     # TODO: switch to "3" once Zarr V3 is released
     return module_available("zarr", minversion="2.99")
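As a quick sanity check of the helper extracted above, the following sketch exercises `_choose_default_mode` at this commit. It imports a private function whose signature may change, so treat it as illustrative only (`pytest` is an assumed test harness, not part of the diff):

import pytest
from xarray.backends.zarr import _choose_default_mode

# defaults: append -> "a", region write -> "r+", otherwise "w-"
assert _choose_default_mode(mode=None, append_dim="time", region=None) == "a"
assert _choose_default_mode(mode=None, append_dim=None, region="auto") == "r+"
assert _choose_default_mode(mode=None, append_dim=None, region=None) == "w-"

# invalid combinations raise, exactly as the previously inlined checks did
with pytest.raises(ValueError, match="cannot set append_dim"):
    _choose_default_mode(mode="w", append_dim="time", region=None)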
@@ -567,7 +626,6 @@ def open_store(
         append_dim=None,
         write_region=None,
         safe_chunks=True,
-        stacklevel=2,
         zarr_version=None,
         zarr_format=None,
         use_zarr_fill_value_as_mask=None,
@@ -587,7 +645,6 @@ def open_store(
             consolidate_on_close=consolidate_on_close,
             chunk_store=chunk_store,
             storage_options=storage_options,
-            stacklevel=stacklevel,
             zarr_version=zarr_version,
             use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask,
             zarr_format=zarr_format,
@@ -622,7 +679,6 @@ def open_group(
         append_dim=None,
         write_region=None,
         safe_chunks=True,
-        stacklevel=2,
         zarr_version=None,
         zarr_format=None,
         use_zarr_fill_value_as_mask=None,
@@ -642,7 +698,6 @@ def open_group(
             consolidate_on_close=consolidate_on_close,
             chunk_store=chunk_store,
             storage_options=storage_options,
-            stacklevel=stacklevel,
             zarr_version=zarr_version,
             use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask,
             zarr_format=zarr_format,
@@ -1105,7 +1160,10 @@ def _auto_detect_regions(self, ds, region):
             region[dim] = slice(idxs[0], idxs[-1] + 1)
         return region

-    def _validate_and_autodetect_region(self, ds) -> None:
+    def _validate_and_autodetect_region(self, ds: Dataset) -> Dataset:
+        if self._write_region is None:
+            return ds
+
         region = self._write_region

         if region == "auto":
@@ -1153,8 +1211,26 @@ def _validate_and_autodetect_region(self, ds) -> None:
                     f".drop_vars({non_matching_vars!r})"
                 )

+        if self._append_dim is not None and self._append_dim in region:
+            raise ValueError(
+                f"cannot list the same dimension in both ``append_dim`` and "
+                f"``region`` with to_zarr(), got {self._append_dim} in both"
+            )
+
         self._write_region = region

+        # can't modify indexes with region writes
+        return ds.drop_vars(ds.indexes)
+
+    def _validate_encoding(self, encoding) -> None:
+        if encoding and self._mode in ["a", "a-", "r+"]:
+            existing_var_names = set(self.zarr_group.array_keys())
+            for var_name in existing_var_names:
+                if var_name in encoding:
+                    raise ValueError(
+                        f"variable {var_name!r} already exists, but encoding was provided"
+                    )
+

 def open_zarr(
     store,
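The encoding check moves verbatim from `to_zarr` into `ZarrStore._validate_encoding`, so its user-visible behavior is unchanged. A hedged end-to-end illustration (`demo.zarr` is again an illustrative path):

import numpy as np
import xarray as xr

ds = xr.Dataset({"t": ("x", np.arange(4))})
ds.to_zarr("demo.zarr", mode="w")
try:
    # providing encoding for a variable that already exists in the store
    ds.to_zarr("demo.zarr", mode="a", append_dim="x",
               encoding={"t": {"dtype": "int32"}})
except ValueError as err:
    print(err)  # variable 't' already exists, but encoding was provided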
@@ -1329,7 +1405,6 @@ def open_zarr(
         "overwrite_encoded_chunks": overwrite_encoded_chunks,
         "chunk_store": chunk_store,
         "storage_options": storage_options,
-        "stacklevel": 4,
         "zarr_version": zarr_version,
         "zarr_format": zarr_format,
     }
@@ -1398,7 +1473,6 @@ def open_dataset(  # type: ignore[override]  # allow LSP violation, not supporti
         consolidated=None,
         chunk_store=None,
         storage_options=None,
-        stacklevel=3,
         zarr_version=None,
         zarr_format=None,
         store=None,
@@ -1416,7 +1490,6 @@ def open_dataset(  # type: ignore[override]  # allow LSP violation, not supporti
             consolidate_on_close=False,
             chunk_store=chunk_store,
             storage_options=storage_options,
-            stacklevel=stacklevel + 1,
             zarr_version=zarr_version,
             use_zarr_fill_value_as_mask=None,
             zarr_format=zarr_format,
@@ -1453,7 +1526,6 @@ def open_datatree(
         consolidated=None,
         chunk_store=None,
         storage_options=None,
-        stacklevel=3,
         zarr_version=None,
         zarr_format=None,
         **kwargs,
@@ -1474,7 +1546,6 @@ def open_datatree(
            consolidated=consolidated,
            chunk_store=chunk_store,
            storage_options=storage_options,
-           stacklevel=stacklevel,
            zarr_version=zarr_version,
            zarr_format=zarr_format,
            **kwargs,
@@ -1499,7 +1570,6 @@ def open_groups_as_dict(
         consolidated=None,
         chunk_store=None,
         storage_options=None,
-        stacklevel=3,
         zarr_version=None,
         zarr_format=None,
         **kwargs,
@@ -1523,7 +1593,6 @@ def open_groups_as_dict(
             consolidate_on_close=False,
             chunk_store=chunk_store,
             storage_options=storage_options,
-            stacklevel=stacklevel + 1,
             zarr_version=zarr_version,
             zarr_format=zarr_format,
         )
@@ -1569,7 +1638,6 @@ def _get_open_params(
     consolidate_on_close,
     chunk_store,
     storage_options,
-    stacklevel,
     zarr_version,
     use_zarr_fill_value_as_mask,
     zarr_format,
@@ -1614,7 +1682,7 @@ def _get_open_params(
             # ValueError in zarr-python 3.x, KeyError in 2.x.
             try:
                 zarr_group = zarr.open_group(store, **open_kwargs)
-                warnings.warn(
+                emit_user_level_warning(
                     "Failed to open Zarr store with consolidated metadata, "
                     "but successfully read with non-consolidated metadata. "
                     "This is typically much slower for opening a dataset. "
@@ -1627,7 +1695,6 @@ def _get_open_params(
                     "error in this case instead of falling back to try "
                     "reading non-consolidated metadata.",
                     RuntimeWarning,
-                    stacklevel=stacklevel,
                 )
             except missing_exc as err:
                 raise FileNotFoundError(
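`emit_user_level_warning` (from `xarray.core.utils`) computes the `stacklevel` by inspecting the call stack, which is why all of the explicit `stacklevel` plumbing above could be deleted. A simplified stand-in showing how such a helper works (the stack-walking heuristic and the `"my_library"` path marker are assumptions for illustration, not xarray's exact implementation):

import inspect
import warnings

def emit_user_level_warning(message: str, category=None) -> None:
    # Find the first frame outside the library so the warning points at the
    # user's call site instead of at library internals.
    stacklevel = 2  # minimum: the direct caller of this helper
    for frame_info in inspect.stack()[1:]:
        if "my_library" in frame_info.filename:  # assumed package path marker
            stacklevel += 1
        else:
            break
    warnings.warn(message, category=category, stacklevel=stacklevel)

Because the helper recomputes the depth at each call site, callers no longer need to thread a `stacklevel` argument through every layer, which is exactly the plumbing this commit removes from `open_store`, `open_group`, `open_dataset`, `open_datatree`, `open_groups_as_dict`, and `_get_open_params`.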
