Skip to content

Commit

Permalink
Add better typing
Browse files Browse the repository at this point in the history
  • Loading branch information
CeliaBenquet committed Sep 25, 2024
1 parent d83986c commit e11b4dd
Showing 1 changed file with 17 additions and 14 deletions.
31 changes: 17 additions & 14 deletions cebra/data/datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@
import abc
import collections
import types
from typing import List, Tuple, Union
from typing import List, Literal, Optional, Tuple, Union

import literate_dataclasses as dataclasses
import numpy as np
Expand Down Expand Up @@ -67,25 +67,28 @@ class TensorDataset(cebra_data.SingleSessionDataset):
"""

def __init__(
self,
neural: Union[torch.Tensor, npt.NDArray],
continuous: Union[torch.Tensor, npt.NDArray] = None,
discrete: Union[torch.Tensor, npt.NDArray] = None,
offset: int = 1,
device: str = "cpu"
):
def __init__(self,
neural: Union[torch.Tensor, npt.NDArray],
continuous: Union[torch.Tensor, npt.NDArray] = None,
discrete: Union[torch.Tensor, npt.NDArray] = None,
offset: int = 1,
device: str = "cpu"):
super().__init__(device=device)
self.neural = self._to_tensor(neural, check_dtype="float").float()
self.continuous = self._to_tensor(continuous, check_dtype="float").float()
self.continuous = self._to_tensor(continuous,
check_dtype="float").float()
self.discrete = self._to_tensor(discrete, check_dtype="integer")
if self.continuous is None and self.discrete is None:
raise ValueError(
"You have to pass at least one of the arguments 'continuous' or 'discrete'."
)
self.offset = offset

def _to_tensor(self, array, check_dtype: str = None):
def _to_tensor(
self,
array: Union[torch.Tensor, npt.NDArray],
check_dtype: Optional[Literal["int",
"float"]] = None) -> torch.Tensor:
"""Convert :py:func:`numpy.array` to :py:class:`torch.Tensor` if necessary and check the dtype.
Args:
Expand All @@ -101,9 +104,9 @@ def _to_tensor(self, array, check_dtype: str = None):
if isinstance(array, np.ndarray):
array = torch.from_numpy(array)
if check_dtype is not None:
if (check_dtype == "integer" and not cebra.helper._is_integer(array)
) or (check_dtype == "float" and not cebra.helper._is_floating(array)
) or (check_dtype == "float_integer" and not cebra.helper._is_floating_or_integer(array)):
if (check_dtype == "int" and not cebra.helper._is_integer(array)
) or (check_dtype == "float" and
not cebra.helper._is_floating(array)):
raise TypeError(f"{array.dtype} instead of {check_dtype}.")
return array

Expand Down

0 comments on commit e11b4dd

Please sign in to comment.