From 7467b1ec242924fc3257986cab813a4db22d8517 Mon Sep 17 00:00:00 2001 From: Deepak Cherian Date: Thu, 31 Oct 2024 09:11:22 -0700 Subject: [PATCH] Refactor out utility functions from to_zarr (#9695) * Refactor out utility functions from to_zarr * Use `emit_user_level_warning` instead of explicit `stacklevel` * tiny reordering * Some more * comment --- xarray/backends/api.py | 83 ++++++--------------------------- xarray/backends/zarr.py | 101 +++++++++++++++++++++++++++++++++------- 2 files changed, 97 insertions(+), 87 deletions(-) diff --git a/xarray/backends/api.py b/xarray/backends/api.py index a52e73701ab..9147f750330 100644 --- a/xarray/backends/api.py +++ b/xarray/backends/api.py @@ -33,7 +33,6 @@ _normalize_path, ) from xarray.backends.locks import _get_scheduler -from xarray.backends.zarr import _zarr_v3 from xarray.core import indexing from xarray.core.combine import ( _infer_concat_order_from_positions, @@ -2131,66 +2130,25 @@ def to_zarr( See `Dataset.to_zarr` for full API docs. """ + from xarray.backends.zarr import _choose_default_mode, _get_mappers + + # validate Dataset keys, DataArray names + _validate_dataset_names(dataset) # Load empty arrays to avoid bug saving zero length dimensions (Issue #5741) + # TODO: delete when min dask>=2023.12.1 + # https://github.com/dask/dask/pull/10506 for v in dataset.variables.values(): if v.size == 0: v.load() - # expand str and path-like arguments - store = _normalize_path(store) - chunk_store = _normalize_path(chunk_store) - - kwargs = {} - if storage_options is None: - mapper = store - chunk_mapper = chunk_store - else: - if not isinstance(store, str): - raise ValueError( - f"store must be a string to use storage_options. Got {type(store)}" - ) - - if _zarr_v3(): - kwargs["storage_options"] = storage_options - mapper = store - chunk_mapper = chunk_store - else: - from fsspec import get_mapper - - mapper = get_mapper(store, **storage_options) - if chunk_store is not None: - chunk_mapper = get_mapper(chunk_store, **storage_options) - else: - chunk_mapper = chunk_store - if encoding is None: encoding = {} - if mode is None: - if append_dim is not None: - mode = "a" - elif region is not None: - mode = "r+" - else: - mode = "w-" - - if mode not in ["a", "a-"] and append_dim is not None: - raise ValueError("cannot set append_dim unless mode='a' or mode=None") - - if mode not in ["a", "a-", "r+"] and region is not None: - raise ValueError( - "cannot set region unless mode='a', mode='a-', mode='r+' or mode=None" - ) - - if mode not in ["w", "w-", "a", "a-", "r+"]: - raise ValueError( - "The only supported options for mode are 'w', " - f"'w-', 'a', 'a-', and 'r+', but mode={mode!r}" - ) - - # validate Dataset keys, DataArray names - _validate_dataset_names(dataset) + kwargs, mapper, chunk_mapper = _get_mappers( + storage_options=storage_options, store=store, chunk_store=chunk_store + ) + mode = _choose_default_mode(mode=mode, append_dim=append_dim, region=region) if mode == "r+": already_consolidated = consolidated @@ -2198,6 +2156,7 @@ def to_zarr( else: already_consolidated = False consolidate_on_close = consolidated or consolidated is None + zstore = backends.ZarrStore.open_group( store=mapper, mode=mode, @@ -2209,30 +2168,14 @@ def to_zarr( append_dim=append_dim, write_region=region, safe_chunks=safe_chunks, - stacklevel=4, # for Dataset.to_zarr() zarr_version=zarr_version, zarr_format=zarr_format, write_empty=write_empty_chunks, **kwargs, ) - if region is not None: - zstore._validate_and_autodetect_region(dataset) - # can't modify indexes with region writes - dataset = dataset.drop_vars(dataset.indexes) - if append_dim is not None and append_dim in region: - raise ValueError( - f"cannot list the same dimension in both ``append_dim`` and " - f"``region`` with to_zarr(), got {append_dim} in both" - ) - - if encoding and mode in ["a", "a-", "r+"]: - existing_var_names = set(zstore.zarr_group.array_keys()) - for var_name in existing_var_names: - if var_name in encoding: - raise ValueError( - f"variable {var_name!r} already exists, but encoding was provided" - ) + dataset = zstore._validate_and_autodetect_region(dataset) + zstore._validate_encoding(encoding) writer = ArrayWriter() # TODO: figure out how to properly handle unlimited_dims diff --git a/xarray/backends/zarr.py b/xarray/backends/zarr.py index 5090ec7728e..bed7d84d60d 100644 --- a/xarray/backends/zarr.py +++ b/xarray/backends/zarr.py @@ -4,8 +4,7 @@ import json import os import struct -import warnings -from collections.abc import Iterable +from collections.abc import Hashable, Iterable, Mapping from typing import TYPE_CHECKING, Any, Literal import numpy as np @@ -46,6 +45,66 @@ from xarray.core.datatree import DataTree +def _get_mappers(*, storage_options, store, chunk_store): + # expand str and path-like arguments + store = _normalize_path(store) + chunk_store = _normalize_path(chunk_store) + + kwargs = {} + if storage_options is None: + mapper = store + chunk_mapper = chunk_store + else: + if not isinstance(store, str): + raise ValueError( + f"store must be a string to use storage_options. Got {type(store)}" + ) + + if _zarr_v3(): + kwargs["storage_options"] = storage_options + mapper = store + chunk_mapper = chunk_store + else: + from fsspec import get_mapper + + mapper = get_mapper(store, **storage_options) + if chunk_store is not None: + chunk_mapper = get_mapper(chunk_store, **storage_options) + else: + chunk_mapper = chunk_store + return kwargs, mapper, chunk_mapper + + +def _choose_default_mode( + *, + mode: ZarrWriteModes | None, + append_dim: Hashable | None, + region: Mapping[str, slice | Literal["auto"]] | Literal["auto"] | None, +) -> ZarrWriteModes: + if mode is None: + if append_dim is not None: + mode = "a" + elif region is not None: + mode = "r+" + else: + mode = "w-" + + if mode not in ["a", "a-"] and append_dim is not None: + raise ValueError("cannot set append_dim unless mode='a' or mode=None") + + if mode not in ["a", "a-", "r+"] and region is not None: + raise ValueError( + "cannot set region unless mode='a', mode='a-', mode='r+' or mode=None" + ) + + if mode not in ["w", "w-", "a", "a-", "r+"]: + raise ValueError( + "The only supported options for mode are 'w', " + f"'w-', 'a', 'a-', and 'r+', but mode={mode!r}" + ) + return mode + + def _zarr_v3() -> bool: # TODO: switch to "3" once Zarr V3 is released return module_available("zarr", minversion="2.99") @@ -567,7 +626,6 @@ def open_store( append_dim=None, write_region=None, safe_chunks=True, - stacklevel=2, zarr_version=None, zarr_format=None, use_zarr_fill_value_as_mask=None, @@ -587,7 +645,6 @@ def open_store( consolidate_on_close=consolidate_on_close, chunk_store=chunk_store, storage_options=storage_options, - stacklevel=stacklevel, zarr_version=zarr_version, use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask, zarr_format=zarr_format, @@ -622,7 +679,6 @@ def open_group( append_dim=None, write_region=None, safe_chunks=True, - stacklevel=2, zarr_version=None, zarr_format=None, use_zarr_fill_value_as_mask=None, @@ -642,7 +698,6 @@ def open_group( consolidate_on_close=consolidate_on_close, chunk_store=chunk_store, storage_options=storage_options, - stacklevel=stacklevel, zarr_version=zarr_version, use_zarr_fill_value_as_mask=use_zarr_fill_value_as_mask, zarr_format=zarr_format, @@ -1105,7 +1160,10 @@ def _auto_detect_regions(self, ds, region): region[dim] = slice(idxs[0], idxs[-1] + 1) return region - def _validate_and_autodetect_region(self, ds) -> None: + def _validate_and_autodetect_region(self, ds: Dataset) -> Dataset: + if self._write_region is None: + return ds + region = self._write_region if region == "auto": @@ -1153,8 +1211,26 @@ def _validate_and_autodetect_region(self, ds) -> None: f".drop_vars({non_matching_vars!r})" ) + if self._append_dim is not None and self._append_dim in region: + raise ValueError( + f"cannot list the same dimension in both ``append_dim`` and " + f"``region`` with to_zarr(), got {self._append_dim} in both" + ) + self._write_region = region + # can't modify indexes with region writes + return ds.drop_vars(ds.indexes) + + def _validate_encoding(self, encoding) -> None: + if encoding and self._mode in ["a", "a-", "r+"]: + existing_var_names = set(self.zarr_group.array_keys()) + for var_name in existing_var_names: + if var_name in encoding: + raise ValueError( + f"variable {var_name!r} already exists, but encoding was provided" + ) + def open_zarr( store, @@ -1329,7 +1405,6 @@ def open_zarr( "overwrite_encoded_chunks": overwrite_encoded_chunks, "chunk_store": chunk_store, "storage_options": storage_options, - "stacklevel": 4, "zarr_version": zarr_version, "zarr_format": zarr_format, } @@ -1398,7 +1473,6 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti consolidated=None, chunk_store=None, storage_options=None, - stacklevel=3, zarr_version=None, zarr_format=None, store=None, @@ -1416,7 +1490,6 @@ def open_dataset( # type: ignore[override] # allow LSP violation, not supporti consolidate_on_close=False, chunk_store=chunk_store, storage_options=storage_options, - stacklevel=stacklevel + 1, zarr_version=zarr_version, use_zarr_fill_value_as_mask=None, zarr_format=zarr_format, @@ -1453,7 +1526,6 @@ def open_datatree( consolidated=None, chunk_store=None, storage_options=None, - stacklevel=3, zarr_version=None, zarr_format=None, **kwargs, @@ -1474,7 +1546,6 @@ def open_datatree( consolidated=consolidated, chunk_store=chunk_store, storage_options=storage_options, - stacklevel=stacklevel, zarr_version=zarr_version, zarr_format=zarr_format, **kwargs, @@ -1499,7 +1570,6 @@ def open_groups_as_dict( consolidated=None, chunk_store=None, storage_options=None, - stacklevel=3, zarr_version=None, zarr_format=None, **kwargs, @@ -1523,7 +1593,6 @@ def open_groups_as_dict( consolidate_on_close=False, chunk_store=chunk_store, storage_options=storage_options, - stacklevel=stacklevel + 1, zarr_version=zarr_version, zarr_format=zarr_format, ) @@ -1569,7 +1638,6 @@ def _get_open_params( consolidate_on_close, chunk_store, storage_options, - stacklevel, zarr_version, use_zarr_fill_value_as_mask, zarr_format, @@ -1614,7 +1682,7 @@ def _get_open_params( # ValueError in zarr-python 3.x, KeyError in 2.x. try: zarr_group = zarr.open_group(store, **open_kwargs) - warnings.warn( + emit_user_level_warning( "Failed to open Zarr store with consolidated metadata, " "but successfully read with non-consolidated metadata. " "This is typically much slower for opening a dataset. " @@ -1627,7 +1695,6 @@ def _get_open_params( "error in this case instead of falling back to try " "reading non-consolidated metadata.", RuntimeWarning, - stacklevel=stacklevel, ) except missing_exc as err: raise FileNotFoundError(