diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index 7ec5f2f..55a02c7 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -24,7 +24,7 @@ import abc import collections import types -from typing import List, Tuple, Union +from typing import List, Literal, Optional, Tuple, Union import literate_dataclasses as dataclasses import numpy as np @@ -67,17 +67,16 @@ class TensorDataset(cebra_data.SingleSessionDataset): """ - def __init__( - self, - neural: Union[torch.Tensor, npt.NDArray], - continuous: Union[torch.Tensor, npt.NDArray] = None, - discrete: Union[torch.Tensor, npt.NDArray] = None, - offset: int = 1, - device: str = "cpu" - ): + def __init__(self, + neural: Union[torch.Tensor, npt.NDArray], + continuous: Union[torch.Tensor, npt.NDArray] = None, + discrete: Union[torch.Tensor, npt.NDArray] = None, + offset: int = 1, + device: str = "cpu"): super().__init__(device=device) self.neural = self._to_tensor(neural, check_dtype="float").float() - self.continuous = self._to_tensor(continuous, check_dtype="float").float() + self.continuous = self._to_tensor(continuous, + check_dtype="float").float() self.discrete = self._to_tensor(discrete, check_dtype="integer") if self.continuous is None and self.discrete is None: raise ValueError( @@ -85,7 +84,11 @@ def __init__( ) self.offset = offset - def _to_tensor(self, array, check_dtype: str = None): + def _to_tensor( + self, + array: Union[torch.Tensor, npt.NDArray], + check_dtype: Optional[Literal["int", + "float"]] = None) -> torch.Tensor: """Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype. Args: @@ -101,9 +104,9 @@ def _to_tensor(self, array, check_dtype: str = None): if isinstance(array, np.ndarray): array = torch.from_numpy(array) if check_dtype is not None: - if (check_dtype == "integer" and not cebra.helper._is_integer(array) - ) or (check_dtype == "float" and not cebra.helper._is_floating(array) - ) or (check_dtype == "float_integer" and not cebra.helper._is_floating_or_integer(array)): + if (check_dtype == "int" and not cebra.helper._is_integer(array) + ) or (check_dtype == "float" and + not cebra.helper._is_floating(array)): raise TypeError(f"{array.dtype} instead of {check_dtype}.") return array