Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Soft import #9561

Open
wants to merge 9 commits into
base: main
Choose a base branch
from
8 changes: 4 additions & 4 deletions xarray/backends/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@
from xarray.core.dataset import Dataset, _get_chunk, _maybe_chunk
from xarray.core.indexes import Index
from xarray.core.types import NetcdfWriteModes, ZarrWriteModes
from xarray.core.utils import is_remote_uri
from xarray.core.utils import check_fsspec_installed, is_remote_uri
from xarray.namedarray.daskmanager import DaskManager
from xarray.namedarray.parallelcompat import guess_chunkmanager

Expand Down Expand Up @@ -1692,15 +1692,15 @@ def to_zarr(
mapper = store
chunk_mapper = chunk_store
else:
from fsspec import get_mapper
fsspec = check_fsspec_installed()

if not isinstance(store, str):
raise ValueError(
f"store must be a string to use storage_options. Got {type(store)}"
)
mapper = get_mapper(store, **storage_options)
mapper = fsspec.get_mapper(store, **storage_options)
if chunk_store is not None:
chunk_mapper = get_mapper(chunk_store, **storage_options)
chunk_mapper = fsspec.get_mapper(chunk_store, **storage_options)
else:
chunk_mapper = chunk_store

Expand Down
18 changes: 9 additions & 9 deletions xarray/backends/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,12 @@

from xarray.conventions import cf_encoder
from xarray.core import indexing
from xarray.core.utils import FrozenDict, NdimSizeLenMixin, is_remote_uri
from xarray.core.utils import (
FrozenDict,
NdimSizeLenMixin,
check_fsspec_installed,
is_remote_uri,
)
from xarray.namedarray.parallelcompat import get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array

Expand Down Expand Up @@ -83,14 +88,9 @@ def _find_absolute_paths(
"""
if isinstance(paths, str):
if is_remote_uri(paths) and kwargs.get("engine", None) == "zarr":
try:
from fsspec.core import get_fs_token_paths
except ImportError as e:
raise ImportError(
"The use of remote URLs for opening zarr requires the package fsspec"
) from e

fs, _, _ = get_fs_token_paths(
fsspec = check_fsspec_installed()

fs, _, _ = fsspec.core.get_fs_token_paths(
paths,
mode="rb",
storage_options=kwargs.get("backend_kwargs", {}).get(
Expand Down
5 changes: 3 additions & 2 deletions xarray/backends/zarr.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from xarray.core.utils import (
FrozenDict,
HiddenKeyDict,
check_zarr_installed,
close_on_error,
)
from xarray.core.variable import Variable
Expand Down Expand Up @@ -634,7 +635,7 @@ def store(
dimension on which the zarray will be appended
only needed in append mode
"""
import zarr
zarr = check_zarr_installed()

existing_keys = tuple(self.zarr_group.array_keys())

Expand Down Expand Up @@ -1317,7 +1318,7 @@ def _get_open_params(
stacklevel,
zarr_version,
):
import zarr
zarr = check_zarr_installed()

# zarr doesn't support pathlib.Path objects yet. zarr-python#601
if isinstance(store, os.PathLike):
Expand Down
61 changes: 25 additions & 36 deletions xarray/coding/cftime_offsets.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,13 +67,9 @@
nanosecond_precision_timestamp,
no_default,
)
from xarray.core.utils import emit_user_level_warning

try:
import cftime
except ImportError:
cftime = None
from xarray.core.utils import check_cftime_installed, emit_user_level_warning

cftime = check_cftime_installed(strict=False)

if TYPE_CHECKING:
from xarray.core.types import InclusiveOptions, Self, SideOptions, TypeAlias
Expand All @@ -93,24 +89,23 @@ def _nanosecond_precision_timestamp(*args, **kwargs):

def get_date_type(calendar, use_cftime=True):
"""Return the cftime date type for a given calendar name."""
if cftime is None:
raise ImportError("cftime is required for dates with non-standard calendars")
else:
if _is_standard_calendar(calendar) and not use_cftime:
return _nanosecond_precision_timestamp

calendars = {
"noleap": cftime.DatetimeNoLeap,
"360_day": cftime.Datetime360Day,
"365_day": cftime.DatetimeNoLeap,
"366_day": cftime.DatetimeAllLeap,
"gregorian": cftime.DatetimeGregorian,
"proleptic_gregorian": cftime.DatetimeProlepticGregorian,
"julian": cftime.DatetimeJulian,
"all_leap": cftime.DatetimeAllLeap,
"standard": cftime.DatetimeGregorian,
}
return calendars[calendar]
cftime = check_cftime_installed()

if _is_standard_calendar(calendar) and not use_cftime:
return _nanosecond_precision_timestamp

calendars = {
"noleap": cftime.DatetimeNoLeap,
"360_day": cftime.Datetime360Day,
"365_day": cftime.DatetimeNoLeap,
"366_day": cftime.DatetimeAllLeap,
"gregorian": cftime.DatetimeGregorian,
"proleptic_gregorian": cftime.DatetimeProlepticGregorian,
"julian": cftime.DatetimeJulian,
"all_leap": cftime.DatetimeAllLeap,
"standard": cftime.DatetimeGregorian,
}
return calendars[calendar]


class BaseCFTimeOffset:
Expand Down Expand Up @@ -141,8 +136,7 @@ def __add__(self, other):
return self.__apply__(other)

def __sub__(self, other):
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
cftime = check_cftime_installed()

if isinstance(other, cftime.datetime):
raise TypeError("Cannot subtract a cftime.datetime from a time offset.")
Expand Down Expand Up @@ -293,8 +287,7 @@ def _adjust_n_years(other, n, month, reference_day):

def _shift_month(date, months, day_option: DayOption = "start"):
"""Shift the date to a month start or end a given number of months away."""
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
_ = check_cftime_installed()

has_year_zero = date.has_year_zero
delta_year = (date.month + months) // 12
Expand Down Expand Up @@ -458,8 +451,7 @@ def onOffset(self, date) -> bool:
return mod_month == 0 and date.day == self._get_offset_day(date)

def __sub__(self, other: Self) -> Self:
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
cftime = check_cftime_installed()

if isinstance(other, cftime.datetime):
raise TypeError("Cannot subtract cftime.datetime from offset.")
Expand Down Expand Up @@ -544,8 +536,7 @@ def __apply__(self, other):
return _shift_month(other, months, self._day_option)

def __sub__(self, other):
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
cftime = check_cftime_installed()

if isinstance(other, cftime.datetime):
raise TypeError("Cannot subtract cftime.datetime from offset.")
Expand Down Expand Up @@ -828,8 +819,7 @@ def delta_to_tick(delta: timedelta | pd.Timedelta) -> Tick:


def to_cftime_datetime(date_str_or_date, calendar=None):
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
cftime = check_cftime_installed()

if isinstance(date_str_or_date, str):
if calendar is None:
Expand Down Expand Up @@ -867,8 +857,7 @@ def _maybe_normalize_date(date, normalize):
def _generate_linear_range(start, end, periods):
"""Generate an equally-spaced sequence of cftime.datetime objects between
and including two dates (whose length equals the number of periods)."""
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
cftime = check_cftime_installed()

total_seconds = (end - start).total_seconds()
values = np.linspace(0.0, total_seconds, periods, endpoint=True)
Expand Down
17 changes: 6 additions & 11 deletions xarray/coding/cftimeindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,12 +58,9 @@
)
from xarray.core.common import _contains_cftime_datetimes
from xarray.core.options import OPTIONS
from xarray.core.utils import is_scalar
from xarray.core.utils import check_cftime_installed, is_scalar

try:
import cftime
except ImportError:
cftime = None
cftime = check_cftime_installed(strict=False)

if TYPE_CHECKING:
from xarray.coding.cftime_offsets import BaseCFTimeOffset
Expand Down Expand Up @@ -130,8 +127,7 @@ def parse_iso8601_like(datetime_string):


def _parse_iso8601_with_reso(date_type, timestr):
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
_ = check_cftime_installed()

default = date_type(1, 1, 1)
result = parse_iso8601_like(timestr)
Expand Down Expand Up @@ -200,8 +196,7 @@ def _field_accessor(name, docstring=None, min_cftime_version="0.0"):
"""Adapted from pandas.tseries.index._field_accessor"""

def f(self, min_cftime_version=min_cftime_version):
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
cftime = check_cftime_installed()

if Version(cftime.__version__) >= Version(min_cftime_version):
return get_date_field(self._data, name)
Expand All @@ -225,8 +220,7 @@ def get_date_type(self):


def assert_all_valid_date_type(data):
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
cftime = check_cftime_installed()

if len(data) > 0:
sample = data[0]
Expand Down Expand Up @@ -803,6 +797,7 @@ def round(self, freq):

@property
def is_leap_year(self):
cftime = check_cftime_installed()
func = np.vectorize(cftime.is_leap_year)
return func(self.year, calendar=self.calendar)

Expand Down
8 changes: 3 additions & 5 deletions xarray/coding/times.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from xarray.core.duck_array_ops import asarray, ravel, reshape
from xarray.core.formatting import first_n_items, format_timestamp, last_item
from xarray.core.pdcompat import nanosecond_precision_timestamp
from xarray.core.utils import emit_user_level_warning
from xarray.core.utils import check_cftime_installed, emit_user_level_warning
from xarray.core.variable import Variable
from xarray.namedarray.parallelcompat import T_ChunkedArray, get_chunked_array_type
from xarray.namedarray.pycompat import is_chunked_array
Expand Down Expand Up @@ -235,8 +235,7 @@ def _decode_cf_datetime_dtype(
def _decode_datetime_with_cftime(
num_dates: np.ndarray, units: str, calendar: str
) -> np.ndarray:
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
cftime = check_cftime_installed()
if num_dates.size > 0:
return np.asarray(
cftime.num2date(num_dates, units, calendar, only_use_cftime_datetimes=True)
Expand Down Expand Up @@ -625,8 +624,7 @@ def _encode_datetime_with_cftime(dates, units: str, calendar: str) -> np.ndarray
This method is more flexible than xarray's parsing using datetime64[ns]
arrays but also slower because it loops over each element.
"""
if cftime is None:
raise ModuleNotFoundError("No module named 'cftime'")
cftime = check_cftime_installed()

if np.issubdtype(dates.dtype, np.datetime64):
# numpy's broken datetime conversion only works for us precision
Expand Down
62 changes: 61 additions & 1 deletion xarray/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

import contextlib
import functools
import importlib
import inspect
import io
import itertools
Expand All @@ -63,7 +64,7 @@
)
from enum import Enum
from pathlib import Path
from types import EllipsisType
from types import EllipsisType, ModuleType
from typing import TYPE_CHECKING, Any, Generic, Literal, TypeGuard, TypeVar, overload

import numpy as np
Expand Down Expand Up @@ -1146,3 +1147,62 @@ def _resolve_doubly_passed_kwarg(
)

return kwargs_dict


def soft_import(
name: str,
*,
purpose: str,
strict: bool = True,
) -> ModuleType | None:
"""Import optional dependencies, providing informative errors on failure.

Parameters
----------
name : str
The name of the module to import. For example, ``'matplotlib'``.
purpose : str
A very brief statement explaining why the package is needed.
For example, ``'plotting'``.
strict : bool
If ``True``, raise an ImportError if the package is not found. If ``False``,
return ``None`` if the package is not found. Default is ``True``.

Returns
-------
module | None
The imported module, or ``None`` if the package is not found and strict=False.
"""
install_mapping = {
"matplotlib.pyplot": "matplotlib",
"hypothesis.strategies": "hypothesis",
"nc_time_axis": "nc-time-axis",
}
package_name = install_mapping.get(name, name)

if module_available(name):
return importlib.import_module(name)
if strict:
raise ImportError(
f"For {purpose}, {package_name} is required. "
f"Please install it via pip or conda."
)
return


def check_fsspec_installed(strict=True):
"""Import fsspec if available, otherwise raise an ImportError."""
purpose = "opening Zarr stores with remote URLs"
return soft_import("fsspec", purpose=purpose, strict=strict)


def check_cftime_installed(strict=True):
"""Import cftime if available, otherwise raise an ImportError."""
purpose = "working with dates with non-standard calendars"
return soft_import("cftime", purpose=purpose, strict=strict)


def check_zarr_installed(strict=True):
"""Import zarr if available, otherwise raise an ImportError."""
purpose = "working with Zarr stores"
return soft_import("zarr", purpose=purpose, strict=strict)
Loading