Skip to content

Commit

Permalink
Merge pull request #3 from din14970/testing
Browse files Browse the repository at this point in the history
Testing
  • Loading branch information
din14970 authored Feb 5, 2021
2 parents 1c99ebc + 9a15b9b commit 4601547
Show file tree
Hide file tree
Showing 6 changed files with 193 additions and 312 deletions.
1 change: 0 additions & 1 deletion binder/environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ channels:
- conda-forge
dependencies:
- hyperspy
- temmeta
- pywget
- match-series
- pip:
Expand Down
284 changes: 0 additions & 284 deletions examples/example-minimal.ipynb

This file was deleted.

5 changes: 1 addition & 4 deletions pymatchseries/config_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,10 +68,7 @@ def _get_frame_list(self):

def _get_frame_index_iterator(self):
"""Get a range to loop over all the indices of the output"""
numframes = self["numTemplates"]
sf = self["templateSkipNums"]
numstep = self["templateNumStep"]
return range((numframes - len(sf)) // numstep)
return range(len(self._get_frame_list()))


def load_config(path=None):
Expand Down
23 changes: 16 additions & 7 deletions pymatchseries/matchseries.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,14 +122,22 @@ def __setup_new_calculation(self, data, path, **kwargs):
else:
self.__metadata["lazy"] = True
elif isinstance(data, hs.signals.Signal2D):
self.__metadata["x_name"] = data.axes_manager[-2].name
self.__metadata["x_name"] = (
str(data.axes_manager[-2].name).replace("<", "").replace(">", "")
)
self.__metadata["x_scale"] = data.axes_manager[-2].scale
self.__metadata["x_offset"] = data.axes_manager[-2].offset
self.__metadata["x_unit"] = data.axes_manager[-2].units
self.__metadata["y_name"] = data.axes_manager[-1].name
self.__metadata["x_unit"] = (
str(data.axes_manager[-2].units).replace("<", "").replace(">", "")
)
self.__metadata["y_name"] = (
str(data.axes_manager[-1].name).replace("<", "").replace(">", "")
)
self.__metadata["y_scale"] = data.axes_manager[-1].scale
self.__metadata["y_offset"] = data.axes_manager[-1].offset
self.__metadata["y_unit"] = data.axes_manager[-1].units
self.__metadata["y_unit"] = (
str(data.axes_manager[-1].units).replace("<", "").replace(">", "")
)
self.__metadata["input_type"] = "hyperspy"
EXT = "hspy"
self.__metadata["lazy"] = data._lazy
Expand All @@ -142,7 +150,8 @@ def __setup_new_calculation(self, data, path, **kwargs):
if path is None:
try:
# this will only work for some hspy datasets with the right metadata
filename, _ = os.path.splitext(data.metadata.General.original_filename)
p, fn = os.path.split(data.metadata.General.original_filename)
filename, _ = os.path.splitext(fn)
title = data.metadata.General.title
path = f"./{filename}_{title}/"
except AttributeError:
Expand Down Expand Up @@ -253,7 +262,7 @@ def metadata_file_path(self):

def __prepare_calculation(self):
if os.path.isdir(self.path):
warnings.warn("The calculation {self.metadata['path']} already exists!")
warnings.warn(f"The calculation {self.path} already exists!")
if ioutls.overwrite_dir(self.path):
shutil.rmtree(self.path)
else:
Expand Down Expand Up @@ -395,7 +404,7 @@ def __get_default_axlist(self, numframes):

def __is_valid_data(self, data):
"""Check whether data is the same shape as the calculation"""
frames = self.configuration.__get_frame_list()
frames = self.configuration._get_frame_list()
maxframes = np.max(frames)
xdim = self.metadata["x_dim"]
ydim = self.metadata["y_dim"]
Expand Down
91 changes: 75 additions & 16 deletions pymatchseries/tests/test_config_tools.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,19 +27,78 @@ def test_create_config_file():
os.remove("test.par")


class TestConfigDict:
def test_init(self):
testdict = {"foo": "bar", "batman": 1}
c = ctools.config_dict(testdict)
assert isinstance(c, ctools.config_dict)

def test_get(self):
testdict = {"foo": "bar", "batman": 1}
c = ctools.config_dict(testdict)
assert c.foo == "bar"

def test_set(self):
testdict = {"foo": "bar", "batman": 1}
c = ctools.config_dict(testdict)
c.foo == "foo"
assert c.foo == "bar"
def test_init():
testdict = {"foo": "bar", "batman": 1}
c = ctools.config_dict(testdict)
assert isinstance(c, ctools.config_dict)


def test_get():
testdict = {"foo": "bar", "batman": 1}
c = ctools.config_dict(testdict)
assert c.foo == "bar"


def test_set():
testdict = {"foo": "bar", "batman": 1}
c = ctools.config_dict(testdict)
c.foo == "foo"
assert c.foo == "bar"


@pytest.mark.xfail(raises=KeyError)
def test_setfail():
testdict = {"foo": "bar", "batman": 1}
c = ctools.config_dict(testdict)
c.superman == "foo"


def test_getitem():
testdict = {"foo": "bar", "batman": "1", "superman": "1.1", "joker": "{ 1 3 135 }"}
c = ctools.config_dict(testdict)
assert c["foo"] == "bar"
assert c["batman"] == 1
assert c["superman"] == 1.1
assert c["joker"] == [1, 3, 135]


def test_setitem():
testdict = {"foo": "bar", "batman": "1", "superman": "1.1", "joker": "{ 1 3 135 }"}
c = ctools.config_dict(testdict)
c["foo"] = "bla"
c["batman"] = True
c["joker"] = [1, 3, 135]


@pytest.mark.xfail(raises=KeyError)
def test_setitemfail():
testdict = {"foo": "bar", "batman": "1", "superman": "1.1", "joker": "{ 1 3 135 }"}
c = ctools.config_dict(testdict)
c["noexist"] = 1


def test_bznum():
cf = ctools.load_config()
s, bz = cf._get_stage_bznum()
assert bz == "08"
assert s == 3


def test_framelist():
cf = ctools.load_config()
cf.templateSkipNums = [3, 9]
cf.numTemplates = 10
cf.templateNumOffset = 1
cf.templateNumStep = 2
frms = cf._get_frame_list()
assert frms == [1, 5, 7]


def test_frameindexgen():
cf = ctools.load_config()
cf.templateSkipNums = [3, 9]
cf.numTemplates = 10
cf.templateNumOffset = 1
cf.templateNumStep = 2
frms = cf._get_frame_index_iterator()
assert list(frms) == [0, 1, 2]
101 changes: 101 additions & 0 deletions pymatchseries/tests/test_matchseries.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
from pymatchseries import matchseries as ms
import os
import gc
import shutil
import pytest
from hyperspy.signals import Signal2D, EDSTEMSpectrum
import dask.array as da
import numpy as np


# numpy dataset
np.random.seed(1001)
imageds = np.random.randint(1, 5, size=(3, 32, 32))
imageds_da = da.from_array(imageds)
hsimage = Signal2D(imageds)
hsimage_da = Signal2D(imageds_da).as_lazy()
hsimage.metadata.General.original_filename = "non/existant/path.emd"
hsimage.metadata.General.title = "dummy"
hsimage_da.metadata.General.original_filename = "non/existant/path.emd"
hsimage_da.metadata.General.title = "dummy"

# spectrum maps
specmap = np.random.randint(1, 5, size=(3, 32, 32, 3))
specmap_da = da.from_array(specmap)
hsspecmap = EDSTEMSpectrum(specmap)
hsspecmap_da = EDSTEMSpectrum(specmap_da).as_lazy()


@pytest.mark.parametrize("data", [imageds, imageds_da, hsimage, hsimage_da])
def test_match_series_create(data):
mso = ms.MatchSeries(data)
assert not mso.completed
assert type(mso.data) == type(data)


@pytest.mark.xfail(raises=ValueError)
def test_match_series_fail():
ms.MatchSeries()


@pytest.mark.xfail(raises=NotImplementedError)
def test_match_series_typefail():
ms.MatchSeries(data="won't work")


@pytest.mark.xfail(raises=ValueError)
def test_match_series_dimfail():
ms.MatchSeries(data=np.ones((4, 4, 4, 4)))


@pytest.mark.xfail(raises=ValueError)
def test_match_series_sizefail():
ms.MatchSeries(data=np.ones((4, 6, 6)))


@pytest.mark.parametrize(
"data, f",
[
(imageds, 1),
(imageds_da, 2),
(hsimage, 3),
(hsimage_da, 4),
],
)
def test_match_series_save_load(data, f):
mso = ms.MatchSeries(data, path=f"test{f}")
mso._MatchSeries__prepare_calculation()
mso.save_data(mso.input_data_file)
msl = ms.MatchSeries.load(mso.path)
assert mso.metadata == msl.metadata
assert mso.configuration == msl.configuration
assert type(mso.data) == type(msl.data)
path = mso.path
del mso
del msl
gc.collect()
shutil.rmtree(path)


@pytest.fixture(scope="module")
def match_series_dummy():
mso = ms.MatchSeries(imageds)
mso.run()
yield mso
shutil.rmtree(mso.path)


@pytest.mark.parametrize("data", [None, imageds, imageds_da, hsimage, hsimage_da])
def test_match_series_apply_images(data, match_series_dummy):
defdat = match_series_dummy.get_deformed_images(data)
if data is None:
data = match_series_dummy.data
assert type(data) == type(defdat)


@pytest.mark.parametrize("data", [specmap, specmap_da, hsspecmap, hsspecmap_da])
def test_match_series_apply_spectra(data, match_series_dummy):
defspec = match_series_dummy.apply_deformations_to_spectra(data)
defspec2 = match_series_dummy.apply_deformations_to_spectra(data, sum_frames=False)
assert type(data) == type(defspec)
assert type(data) == type(defspec2)

0 comments on commit 4601547

Please sign in to comment.