Skip to content

Commit

Permalink
Merge pull request #233 from neworderofjamie/update_smnist
Browse files Browse the repository at this point in the history
SMNIST dataset logic for downloading and checking was inconsistent
  • Loading branch information
fabrizio-ottati authored Jan 4, 2023
2 parents 9838d27 + e40c2e7 commit f4cb6fc
Showing 1 changed file with 46 additions and 11 deletions.
57 changes: 46 additions & 11 deletions tonic/datasets/s_mnist.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np

from tonic.dataset import Dataset
from tonic.download_utils import download_and_extract_archive
from tonic.download_utils import check_integrity, download_and_extract_archive
from tonic.io import make_structured_array


Expand All @@ -25,9 +25,6 @@ class SMNIST(Dataset):
duplicate (bool): If True, emits two spikes per threshold crossing
num_neurons (integer): How many neurons to use to encode thresholds (must be odd)
dt (float): Duration (in microseconds) of each timestep
download (bool): Choose to download data or verify existing files. If True
and a file with the same name and correct hash is already
in the directory, download is automatically skipped.
transform (callable, optional): A callable of transforms to apply to the data.
target_transform (callable, optional): A callable of transforms to apply to the targets/labels.
transforms (callable, optional): A callable of transforms that is applied to both data and
Expand All @@ -43,6 +40,12 @@ class SMNIST(Dataset):
train_labels_file = "train-labels-idx1-ubyte"
test_images_file = "t10k-images-idx3-ubyte"
test_labels_file = "t10k-labels-idx1-ubyte"

train_images_md5 = "f68b3c2dcbeaaa9fbdd348bbdeb94873"
train_labels_md5 = "d53e105ee54ea40749a09fcbcd1e9432"
test_images_md5 = "9fb629c4189551a2d022fa330f9573f3"
test_labels_md5 = "ec29112dd5afa0611ce80d1b7f02629c"

dtype = np.dtype([("t", int), ("x", int), ("p", int)])
ordering = dtype.names

Expand All @@ -66,7 +69,6 @@ def __init__(
duplicate=True,
num_neurons=99,
dt=1000.0,
download=True,
transform=None,
target_transform=None,
):
Expand All @@ -82,10 +84,18 @@ def __init__(
if (num_neurons % 2) == 0:
raise Exception("Number of neurons must be odd")

self.images_file = self.train_images_file if train else self.test_images_file
self.labels_file = self.train_labels_file if train else self.test_labels_file

if download:
if train:
self.images_file = self.train_images_file
self.labels_file = self.train_labels_file
self.images_md5 = self.train_images_md5
self.labels_md5 = self.train_labels_md5
else:
self.images_file = self.test_images_file
self.labels_file = self.test_labels_file
self.images_md5 = self.test_images_md5
self.labels_md5 = self.test_labels_md5

if not self._check_exists():
self.download()

# Open images file
Expand Down Expand Up @@ -179,7 +189,32 @@ def __len__(self):
return self.image_data.shape[0]

def download(self):
for f in [self.images_file, self.labels_file]:
for (f, m) in [(self.images_file, self.images_md5),
(self.labels_file, self.labels_md5)]:
download_and_extract_archive(
self.base_url + f + ".gz", self.location_on_system, filename=f + ".gz"
self.base_url + f + ".gz", self.location_on_system,
filename=f + ".gz", md5=m
)

def _are_labels_present(self) -> bool:
    """Report whether the labels file is present in the dataset directory.

    The path is formed from ``self.location_on_system`` and
    ``self.labels_file`` and handed to ``check_integrity`` without an MD5,
    so no hashing takes place.

    Returns:
        bool: the result of ``check_integrity`` for the labels file path.
    """
    labels_path = os.path.join(self.location_on_system, self.labels_file)
    return check_integrity(labels_path)

def _are_images_present(self) -> bool:
    """Report whether the images file is present in the dataset directory.

    The path is formed from ``self.location_on_system`` and
    ``self.images_file`` and handed to ``check_integrity`` without an MD5,
    so no hashing takes place.

    Returns:
        bool: the result of ``check_integrity`` for the images file path.
    """
    images_path = os.path.join(self.location_on_system, self.images_file)
    return check_integrity(images_path)


def _check_exists(self):
    """Tell whether both dataset files are reported present.

    The labels file is checked first; the images file is only checked when
    the labels check is truthy (same short-circuit order as a plain
    ``and`` expression).
    """
    labels_found = self._are_labels_present()
    return labels_found and self._are_images_present()

0 comments on commit f4cb6fc

Please sign in to comment.