Skip to content

Commit

Permalink
Resolve orest-d#17: close file automatically (orest-d#24)
Browse files Browse the repository at this point in the history
Implement context manager for File and the data processing classes.
  • Loading branch information
martin-schlipf authored Feb 18, 2020
1 parent 056ebab commit e5789dc
Show file tree
Hide file tree
Showing 12 changed files with 145 additions and 26 deletions.
12 changes: 12 additions & 0 deletions src/py4vasp/data/_util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from contextlib import contextmanager, nullcontext
import py4vasp.raw as raw


@contextmanager
def from_file(cls, file, attr):
if file is None or isinstance(file, str):
context = raw.File(file)
else:
context = nullcontext(file)
with context as file:
yield cls(getattr(file, attr)())
5 changes: 3 additions & 2 deletions src/py4vasp/data/band.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import plotly.graph_objects as go
from .projectors import Projectors
from py4vasp.data import _util


class Band:
Expand All @@ -25,8 +26,8 @@ def __init__(self, raw_band):
self._projections = raw_band.projections

@classmethod
def from_file(cls, file):
return cls(file.band())
def from_file(cls, file=None):
return _util.from_file(cls, file, "band")

def read(self, selection=None):
res = {
Expand Down
11 changes: 8 additions & 3 deletions src/py4vasp/data/convergence.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,22 @@
import plotly.graph_objects as go
from py4vasp.data import _util


class Convergence:
def __init__(self, raw_conv):
self._conv = raw_conv
self._raw = raw_conv

@classmethod
def from_file(cls, file=None):
return _util.from_file(cls, file, "convergence")

def read(self, selection=None):
if selection is None:
selection = "TOTEN"
for i, label in enumerate(self._conv.labels):
for i, label in enumerate(self._raw.labels):
label = str(label, "utf-8").strip()
if selection in label:
return label, self._conv.energies[:, i]
return label, self._raw.energies[:, i]

def plot(self, selection=None):
label, data = self.read(selection)
Expand Down
5 changes: 3 additions & 2 deletions src/py4vasp/data/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import numpy as np
import pandas as pd
from .projectors import Projectors
from py4vasp.data import _util


class Dos:
Expand All @@ -18,8 +19,8 @@ def __init__(self, raw_dos):
self._projections = raw_dos.projections

@classmethod
def from_file(cls, file):
return cls(file.dos())
def from_file(cls, file=None):
return _util.from_file(cls, file, "dos")

def plot(self, selection=None):
df = self.to_frame(selection)
Expand Down
6 changes: 6 additions & 0 deletions src/py4vasp/data/projectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from dataclasses import dataclass
import re
import numpy as np
from py4vasp.data import _util


_default = "*"
Expand Down Expand Up @@ -82,11 +83,16 @@ class Index(NamedTuple):
spin: Union[str, Selection]

def __init__(self, raw_proj):
self._raw = raw_proj
self._init_atom_dict(raw_proj)
self._init_orbital_dict(raw_proj)
self._init_spin_dict(raw_proj)
self._spin_polarized = raw_proj.number_spins == 2

@classmethod
def from_file(cls, file=None):
return _util.from_file(cls, file, "projectors")

def _init_atom_dict(self, raw_proj):
num_atoms = np.sum(raw_proj.number_ion_types)
all_atoms = self.Selection(indices=range(num_atoms))
Expand Down
21 changes: 19 additions & 2 deletions src/py4vasp/raw/file.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,18 @@
from contextlib import AbstractContextManager
import h5py
import py4vasp.raw as raw


class File:
def __init__(self, filename="vaspout.h5"):
class File(AbstractContextManager):
default_filename = "vaspout.h5"

def __init__(self, filename=None):
filename = filename or File.default_filename
self._h5f = h5py.File(filename, "r")
self.closed = False

def dos(self):
self._assert_not_closed()
return raw.Dos(
fermi_energy=self._h5f["results/electron_dos/efermi"][()],
energies=self._h5f["results/electron_dos/energies"],
Expand All @@ -16,6 +22,7 @@ def dos(self):
)

def band(self):
self._assert_not_closed()
return raw.Band(
fermi_energy=self._h5f["results/electron_dos/efermi"][()],
line_length=self._h5f["input/kpoints/number_kpoints"][()],
Expand All @@ -29,6 +36,7 @@ def band(self):
)

def projectors(self):
self._assert_not_closed()
if "results/projectors" not in self._h5f:
return None
return raw.Projectors(
Expand All @@ -39,19 +47,28 @@ def projectors(self):
)

def cell(self):
self._assert_not_closed()
return raw.Cell(
scale=self._h5f["results/positions/scale"][()],
lattice_vectors=self._h5f["results/positions/lattice_vectors"],
)

def convergence(self):
self._assert_not_closed()
return raw.Convergence(
labels=self._h5f["intermediate/history/energies_tags"],
energies=self._h5f["intermediate/history/energies"],
)

def close(self):
self._h5f.close()
self.closed = True

def __exit__(self, exc_type, exc_value, traceback):
self.close()

def _assert_not_closed(self):
assert not self.closed, "I/O operation on closed file."

def _safe_get_key(self, key):
if key in self._h5f:
Expand Down
56 changes: 56 additions & 0 deletions tests/data/conftest.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
from numpy.testing import assert_array_almost_equal_nulp
from contextlib import contextmanager
from unittest.mock import patch
import pytest
import py4vasp.raw as raw


class _Assert:
Expand All @@ -11,3 +14,56 @@ def allclose(actual, desired):
@pytest.fixture
def Assert():
return _Assert


@pytest.fixture
def mock_file():
@contextmanager
def _mock_file(name, ref):
cm_init = patch.object(raw.File, "__init__", autospec=True, return_value=None)
cm_sut = patch.object(raw.File, name, autospec=True, return_value=ref)
cm_close = patch.object(raw.File, "close", autospec=True)
with cm_init as init, cm_sut as sut, cm_close as close:
yield {"init": init, "sut": sut, "close": close}

return _mock_file


@pytest.fixture
def check_read():
def _check_read(cls, mocks, ref):
ref = cls(ref)
_check_read_from_open_file(cls, mocks, ref)
_check_read_from_default_file(cls, mocks, ref)
_check_read_from_filename(cls, mocks, ref)

def _check_read_from_open_file(cls, mocks, ref):
with raw.File() as file:
_reset_mocks(mocks)
with cls.from_file(file) as actual:
assert actual._raw == ref._raw
mocks["init"].assert_not_called()
mocks["sut"].assert_called_once()
mocks["close"].assert_not_called()

def _check_read_from_default_file(cls, mocks, ref):
_reset_mocks(mocks)
with cls.from_file() as actual:
assert actual._raw == ref._raw
mocks["init"].assert_called_once()
mocks["sut"].assert_called_once()
mocks["close"].assert_called_once()

def _check_read_from_filename(cls, mocks, ref):
_reset_mocks(mocks)
with cls.from_file("filename") as actual:
assert actual._raw == ref._raw
mocks["init"].assert_called_once()
mocks["sut"].assert_called_once()
mocks["close"].assert_called_once()

def _reset_mocks(mocks):
for mock in mocks.values():
mock.reset_mock()

return _check_read
10 changes: 3 additions & 7 deletions tests/data/test_band.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from py4vasp.data import Band
import py4vasp.raw as raw
import pytest
import types
import numpy as np


Expand Down Expand Up @@ -54,13 +53,10 @@ def test_parabolic_band_plot(two_parabolic_bands, Assert):
Assert.allclose(bands, ref_bands)


def test_parabolic_band_from_file(two_parabolic_bands):
def test_parabolic_band_from_file(two_parabolic_bands, mock_file, check_read):
raw_band, _ = two_parabolic_bands
file = types.SimpleNamespace()
file.band = lambda: raw_band
reference = Band(raw_band)
actual = Band.from_file(file)
assert actual._raw == reference._raw
with mock_file("band", raw_band) as mocks:
check_read(Band, mocks, raw_band)


@pytest.fixture
Expand Down
8 changes: 6 additions & 2 deletions tests/data/test_convergence.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
from py4vasp.data import Convergence
import py4vasp.raw as raw
import pytest
import types
import numpy as np
import py4vasp.raw as raw


@pytest.fixture
Expand Down Expand Up @@ -33,3 +32,8 @@ def test_plot_convergence(reference_convergence, Assert):
fig = conv.plot("temperature")
assert fig.layout.yaxis.title.text == "Temperature (K)"
Assert.allclose(fig.data[0].y, reference_convergence.energies[:, 1])


def test_convergence_from_file(reference_convergence, mock_file, check_read):
with mock_file("convergence", reference_convergence) as mocks:
check_read(Convergence, mocks, reference_convergence)
11 changes: 3 additions & 8 deletions tests/data/test_dos.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
from py4vasp.data import Dos
import py4vasp.raw as raw
import pytest
import types
import numpy as np

num_energies = 50
Expand Down Expand Up @@ -51,13 +50,9 @@ def test_nonmagnetic_Dos_plot(nonmagnetic_Dos, Assert):
Assert.allclose(fig.data[0].y, raw_dos.dos[0])


def test_nonmagnetic_Dos_from_file(nonmagnetic_Dos):
raw_dos = nonmagnetic_Dos
file = types.SimpleNamespace()
file.dos = lambda: raw_dos
reference = Dos(raw_dos)
actual = Dos.from_file(file)
assert actual._raw == reference._raw
def test_nonmagnetic_Dos_from_file(nonmagnetic_Dos, mock_file, check_read):
with mock_file("dos", nonmagnetic_Dos) as mocks:
check_read(Dos, mocks, nonmagnetic_Dos)


@pytest.fixture
Expand Down
5 changes: 5 additions & 0 deletions tests/data/test_projectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@ def without_spin():
return proj


def test_from_file(without_spin, mock_file, check_read):
with mock_file("projectors", without_spin) as mocks:
check_read(Projectors, mocks, without_spin)


@pytest.fixture
def spin_polarized(without_spin):
without_spin.number_spins = 2
Expand Down
21 changes: 21 additions & 0 deletions tests/raw/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
import os
import numpy as np
import itertools
import inspect
from tempfile import TemporaryFile
from collections import namedtuple
from numbers import Number, Integral
Expand Down Expand Up @@ -35,6 +36,26 @@ def working_directory(path):
os.chdir(prev_cwd)


def test_file_as_context():
tf = TemporaryFile()
h5f = h5py.File(tf, "w")
h5f.close()
with File(tf) as file:
assert not file.closed
h5f = file._h5f
# check that file is closed and accessing it raises ValueError
assert file.closed
with pytest.raises(ValueError):
h5f.file
for func in inspect.getmembers(file, predicate=inspect.isroutine):
name = func[0]
if name[0] == "_" or name in ["close"]:
continue
with pytest.raises(AssertionError):
print(name)
getattr(file, name)()


def generic_test(setup):
with working_directory(setup.directory):
for option in setup.options:
Expand Down

0 comments on commit e5789dc

Please sign in to comment.