From 9e14790f452484d3cf668b2f52b854ac3cd94885 Mon Sep 17 00:00:00 2001 From: Steffen Schneider Date: Tue, 10 Sep 2024 23:49:24 +0200 Subject: [PATCH] Add OS independent paths for monkey, hippocampus and synthetic data (#169) --- cebra/data/assets.py | 14 ++++++----- cebra/data/datasets.py | 14 +++++------ cebra/datasets/allen/ca_movie_decoding.py | 1 - cebra/datasets/allen/neuropixel_movie.py | 1 - .../allen/neuropixel_movie_decoding.py | 1 - cebra/datasets/gaussian_mixture.py | 4 +++- cebra/datasets/generate_synthetic_data.py | 4 ++-- cebra/datasets/hippocampus.py | 6 ++--- cebra/datasets/monkey_reaching.py | 23 +++++++++++++------ cebra/solver/multi_session.py | 1 - cebra/solver/single_session.py | 1 - cebra/solver/supervised.py | 1 - tests/test_datasets.py | 2 ++ 13 files changed, 40 insertions(+), 33 deletions(-) diff --git a/cebra/data/assets.py b/cebra/data/assets.py index 6b1f1daf..86695482 100644 --- a/cebra/data/assets.py +++ b/cebra/data/assets.py @@ -21,9 +21,9 @@ # import hashlib -import os import re import warnings +from pathlib import Path from typing import Optional import requests @@ -57,8 +57,10 @@ def download_file_with_progress_bar(url: str, """ # Check if the file already exists in the location - file_path = os.path.join(location, file_name) - if os.path.exists(file_path): + location_path = Path(location) + file_path = location_path / file_name + + if file_path.exists(): existing_checksum = calculate_checksum(file_path) if existing_checksum == expected_checksum: return file_path @@ -91,10 +93,10 @@ ) # Create the directory and any necessary parent directories - os.makedirs(location, exist_ok=True) + location_path.mkdir(parents=True, exist_ok=True) filename = filename_match.group(1) - file_path = os.path.join(location, filename) + file_path = location_path / filename total_size = int(response.headers.get("Content-Length", 0)) checksum = hashlib.md5() # create checksum @@ -111,7 +113,7 @@ def 
download_file_with_progress_bar(url: str, downloaded_checksum = checksum.hexdigest() # Get the checksum value if downloaded_checksum != expected_checksum: warnings.warn(f"Checksum verification failed. Deleting '{file_path}'.") - os.remove(file_path) + file_path.unlink() warnings.warn("File deleted. Retrying download...") # Retry download using a for loop diff --git a/cebra/data/datasets.py b/cebra/data/datasets.py index dfc37d5e..0b7f191d 100644 --- a/cebra/data/datasets.py +++ b/cebra/data/datasets.py @@ -67,14 +67,12 @@ class TensorDataset(cebra_data.SingleSessionDataset): """ - def __init__( - self, - neural: Union[torch.Tensor, npt.NDArray], - continuous: Union[torch.Tensor, npt.NDArray] = None, - discrete: Union[torch.Tensor, npt.NDArray] = None, - offset: int = 1, - device: str = "cpu" - ): + def __init__(self, + neural: Union[torch.Tensor, npt.NDArray], + continuous: Union[torch.Tensor, npt.NDArray] = None, + discrete: Union[torch.Tensor, npt.NDArray] = None, + offset: int = 1, + device: str = "cpu"): super().__init__(device=device) self.neural = self._to_tensor(neural, torch.FloatTensor).float() self.continuous = self._to_tensor(continuous, torch.FloatTensor) diff --git a/cebra/datasets/allen/ca_movie_decoding.py b/cebra/datasets/allen/ca_movie_decoding.py index 9a8a6317..12d6cc64 100644 --- a/cebra/datasets/allen/ca_movie_decoding.py +++ b/cebra/datasets/allen/ca_movie_decoding.py @@ -31,7 +31,6 @@ import glob import hashlib -import os import pathlib import h5py diff --git a/cebra/datasets/allen/neuropixel_movie.py b/cebra/datasets/allen/neuropixel_movie.py index 097f2105..51011407 100644 --- a/cebra/datasets/allen/neuropixel_movie.py +++ b/cebra/datasets/allen/neuropixel_movie.py @@ -28,7 +28,6 @@ """ import glob import hashlib -import os import pathlib import h5py diff --git a/cebra/datasets/allen/neuropixel_movie_decoding.py b/cebra/datasets/allen/neuropixel_movie_decoding.py index 2fbc4c51..a99f367d 100644 --- 
a/cebra/datasets/allen/neuropixel_movie_decoding.py +++ b/cebra/datasets/allen/neuropixel_movie_decoding.py @@ -28,7 +28,6 @@ """ import glob import hashlib -import os import pathlib import h5py diff --git a/cebra/datasets/gaussian_mixture.py b/cebra/datasets/gaussian_mixture.py index 63bd1009..f5508838 100644 --- a/cebra/datasets/gaussian_mixture.py +++ b/cebra/datasets/gaussian_mixture.py @@ -19,6 +19,7 @@ # See the License for the specific language governing permissions and # limitations under the License. # +import pathlib from typing import Tuple import joblib as jl @@ -51,7 +52,8 @@ def __init__(self, noise: str = "poisson"): super().__init__() self.noise = noise data = jl.load( - get_datapath(f"synthetic/continuous_label_{self.noise}.jl")) + pathlib.Path(_DEFAULT_DATADIR) / "synthetic" / + f"continuous_label_{self.noise}.jl") self.latent = data["z"] self.index = torch.from_numpy(data["u"]).float() self.neural = torch.from_numpy(data["x"]).float() diff --git a/cebra/datasets/generate_synthetic_data.py b/cebra/datasets/generate_synthetic_data.py index 0b75baf5..8a243d6d 100644 --- a/cebra/datasets/generate_synthetic_data.py +++ b/cebra/datasets/generate_synthetic_data.py @@ -25,7 +25,7 @@ Adapted from pi-VAE: https://github.com/zhd96/pi-vae/blob/main/code/pi_vae.py """ import argparse -import os +import pathlib import sys import joblib as jl @@ -245,5 +245,5 @@ def refractory_poisson(x): "lam": lam_true, "x": x }, - os.path.join(args.save_path, f"continuous_label_{args.noise}.jl"), + pathlib.Path(args.save_path) / f"continuous_label_{args.noise}.jl", ) diff --git a/cebra/datasets/hippocampus.py b/cebra/datasets/hippocampus.py index 29962cb8..a32209a3 100644 --- a/cebra/datasets/hippocampus.py +++ b/cebra/datasets/hippocampus.py @@ -32,7 +32,7 @@ """ import hashlib -import os +import pathlib import joblib import numpy as np @@ -94,8 +94,8 @@ class SingleRatDataset(cebra.data.SingleSessionDataset): """ def __init__(self, name="achilles", root=_DEFAULT_DATADIR, 
download=True): - location = os.path.join(root, "rat_hippocampus") - file_path = os.path.join(location, f"{name}.jl") + location = pathlib.Path(root) / "rat_hippocampus" + file_path = location / f"{name}.jl" super().__init__(download=download, data_url=rat_dataset_urls[name]["url"], diff --git a/cebra/datasets/monkey_reaching.py b/cebra/datasets/monkey_reaching.py index d98fddae..209fe576 100644 --- a/cebra/datasets/monkey_reaching.py +++ b/cebra/datasets/monkey_reaching.py @@ -29,8 +29,9 @@ """ import hashlib -import os +import pathlib import pickle as pk +from typing import Union import joblib as jl import numpy as np @@ -41,10 +42,11 @@ from cebra.datasets import get_datapath from cebra.datasets import register +_DEFAULT_DATADIR = get_datapath() + def _load_data( - path: str = get_datapath( - "s1_reaching/sub-Han_desc-train_behavior+ecephys.nwb"), + path: Union[str, pathlib.Path] = None, session: str = "active", split: str = "train", ): @@ -61,6 +63,13 @@ def _load_data( """ + if path is None: + path = pathlib.Path( + _DEFAULT_DATADIR + ) / "s1_reaching" / "sub-Han_desc-train_behavior+ecephys.nwb" + else: + path = pathlib.Path(path) + try: from nlb_tools.nwb_interface import NWBDataset except ImportError as e: @@ -259,7 +268,7 @@ def __init__(self, ) self.data = jl.load( - os.path.join(self.path, f"{self.load_session}_all.jl")) + pathlib.Path(self.path) / f"{self.load_session}_all.jl") self._post_load() def split(self, split): @@ -285,7 +294,7 @@ def split(self, split): file_name=f"{self.load_session}_{split}.jl", ) self.data = jl.load( - os.path.join(self.path, f"{self.load_session}_{split}.jl")) + pathlib.Path(self.path) / f"{self.load_session}_{split}.jl") self._post_load() def _post_load(self): @@ -407,7 +416,7 @@ def _create_area2_dataset(): """ - PATH = get_datapath("monkey_reaching_preload_smth_40") + PATH = pathlib.Path(_DEFAULT_DATADIR) / "monkey_reaching_preload_smth_40" for session_type in ["active", "passive", "active-passive", "all"]: 
@register(f"area2-bump-pos-{session_type}") @@ -506,7 +515,7 @@ def _create_area2_shuffled_dataset(): """ - PATH = get_datapath("monkey_reaching_preload_smth_40/") + PATH = pathlib.Path(_DEFAULT_DATADIR) / "monkey_reaching_preload_smth_40" for session_type in ["active", "active-passive"]: @register(f"area2-bump-pos-{session_type}-shuffled-trial") diff --git a/cebra/solver/multi_session.py b/cebra/solver/multi_session.py index 7f103708..8f456eb6 100644 --- a/cebra/solver/multi_session.py +++ b/cebra/solver/multi_session.py @@ -22,7 +22,6 @@ """Solver implementations for multi-session datasetes.""" import abc -import os from collections.abc import Iterable from typing import List, Optional diff --git a/cebra/solver/single_session.py b/cebra/solver/single_session.py index ded526e9..6b3b1030 100644 --- a/cebra/solver/single_session.py +++ b/cebra/solver/single_session.py @@ -23,7 +23,6 @@ import abc import copy -import os from collections.abc import Iterable from typing import List diff --git a/cebra/solver/supervised.py b/cebra/solver/supervised.py index 471e601f..f69308e6 100644 --- a/cebra/solver/supervised.py +++ b/cebra/solver/supervised.py @@ -26,7 +26,6 @@ as experimental/outdated, and the API for this particular package unstable. """ import abc -import os from collections.abc import Iterable from typing import List diff --git a/tests/test_datasets.py b/tests/test_datasets.py index bc65d4b4..adbfab64 100644 --- a/tests/test_datasets.py +++ b/tests/test_datasets.py @@ -113,6 +113,8 @@ def test_monkey(): def test_allen(): from cebra.datasets import allen + pytest.skip("Test takes too long") + ca_dataset = cebra.datasets.init("allen-movie-one-ca-VISp-100-train-10-111") ca_loader = cebra.data.ContinuousDataLoader( dataset=ca_dataset,