Skip to content

Commit

Permalink
Merge branch '15-file-reads-non-array-data' into 'master'
Browse files Browse the repository at this point in the history
Resolve "File reads non-array data"

Closes orest-d#15

See merge request schlipf/py4vasp!13
  • Loading branch information
martin-schlipf committed Feb 4, 2020
2 parents 0887c76 + c789175 commit 21219ad
Show file tree
Hide file tree
Showing 4 changed files with 41 additions and 16 deletions.
6 changes: 3 additions & 3 deletions src/py4vasp/data/band.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@ def __init__(self, raw_band):
self._spin_polarized = len(self._bands) == 2
scale = raw_band.cell.scale
lattice_vectors = raw_band.cell.lattice_vectors
self._cell = scale * np.array(lattice_vectors)
self._line_length = np.array(raw_band.line_length)
self._cell = scale * lattice_vectors
self._line_length = raw_band.line_length
self._num_lines = len(self._kpoints) // self._line_length
self._indices = raw_band.label_indices
self._labels = raw_band.labels
Expand Down Expand Up @@ -62,7 +62,7 @@ def _kpoint_distances(self, kpoints=None):
if self._kdists is not None:
return self._kdists
if kpoints is None:
kpoints = self._kpoints
kpoints = self._kpoints[:]
cartesian_kpoints = np.linalg.solve(self._cell, kpoints.T).T
kpoint_lines = np.split(cartesian_kpoints, self._num_lines)
kpoint_norms = [np.linalg.norm(line - line[0], axis=1) for line in kpoint_lines]
Expand Down
2 changes: 1 addition & 1 deletion src/py4vasp/data/dos.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ def to_dict(self, selection=None):

def to_frame(self, selection=None):
df = pd.DataFrame(self._read_data(selection))
df.fermi_energy = np.array(self._fermi_energy)
df.fermi_energy = self._fermi_energy
return df

def _read_data(self, selection):
Expand Down
10 changes: 5 additions & 5 deletions src/py4vasp/raw/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def __init__(self, filename="vaspout.h5"):

def dos(self):
return raw.Dos(
fermi_energy=self._h5f["results/electron_dos/efermi"],
fermi_energy=self._h5f["results/electron_dos/efermi"][()],
energies=self._h5f["results/electron_dos/energies"],
dos=self._h5f["results/electron_dos/dos"],
projectors=self.projectors(),
Expand All @@ -17,8 +17,8 @@ def dos(self):

def band(self):
return raw.Band(
fermi_energy=self._h5f["results/electron_dos/efermi"],
line_length=self._h5f["input/kpoints/number_kpoints"],
fermi_energy=self._h5f["results/electron_dos/efermi"][()],
line_length=self._h5f["input/kpoints/number_kpoints"][()],
kpoints=self._h5f["results/electron_eigenvalues/kpoint_coords"],
eigenvalues=self._h5f["results/electron_eigenvalues/eigenvalues"],
labels=self._safe_get_key("input/kpoints/labels_kpoints"),
Expand All @@ -35,12 +35,12 @@ def projectors(self):
ion_types=self._h5f["results/positions/ion_types"],
number_ion_types=self._h5f["results/positions/number_ion_types"],
orbital_types=self._h5f["results/projectors/lchar"],
number_spins=self._h5f["results/electron_eigenvalues/ispin"],
number_spins=self._h5f["results/electron_eigenvalues/ispin"][()],
)

def cell(self):
return raw.Cell(
scale=self._h5f["results/positions/scale"],
scale=self._h5f["results/positions/scale"][()],
lattice_vectors=self._h5f["results/positions/lattice_vectors"],
)

Expand Down
39 changes: 32 additions & 7 deletions tests/raw/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import itertools
from tempfile import TemporaryFile
from collections import namedtuple
from numbers import Number, Integral

num_spins = 2
num_energies = 20
Expand All @@ -18,7 +19,7 @@
fermi_energy = 0.123

SetupTest = namedtuple(
"SetupTest", "directory, options, create_reference, write_reference, attribute"
"SetupTest", "directory, options, create_reference, write_reference, check_actual"
)


Expand All @@ -43,8 +44,7 @@ def generic_test(setup):
setup.write_reference(h5f, reference)
h5f.close()
file = open_vasp_file(use_default, filename)
actual = getattr(file, setup.attribute)()
assert actual == reference
setup.check_actual(file, reference)
file.close() # must be after comparison, because file is read lazily


Expand All @@ -63,7 +63,7 @@ def test_dos(tmpdir):
options=itertools.product((True, False), repeat=2),
create_reference=reference_dos,
write_reference=write_dos,
attribute="dos",
check_actual=check_dos,
)
generic_test(setup)

Expand All @@ -88,13 +88,19 @@ def write_dos(h5f, dos):
h5f["results/electron_dos/dospar"] = proj.dos


def check_dos(file, reference):
actual = file.dos()
assert actual == reference
assert isinstance(actual.fermi_energy, Number)


def test_band(tmpdir):
setup = SetupTest(
directory=tmpdir,
options=itertools.product((True, False), repeat=3),
create_reference=reference_band,
write_reference=write_band,
attribute="band",
check_actual=check_band,
)
generic_test(setup)

Expand Down Expand Up @@ -134,13 +140,20 @@ def write_band(h5f, band):
h5f["results/projectors/par"] = band.projections


def check_band(file, reference):
actual = file.band()
assert actual == reference
assert isinstance(actual.fermi_energy, Number)
assert isinstance(actual.line_length, Integral)


def test_projectors(tmpdir):
setup = SetupTest(
directory=tmpdir,
options=((True,), (False,)),
create_reference=reference_projectors,
write_reference=write_projectors,
attribute="projectors",
check_actual=check_projectors,
)
generic_test(setup)

Expand All @@ -162,13 +175,19 @@ def write_projectors(h5f, proj):
h5f["results/electron_eigenvalues/ispin"] = proj.number_spins


def check_projectors(file, reference):
actual = file.projectors()
assert actual == reference
assert isinstance(actual.number_spins, Integral)


def test_cell(tmpdir):
setup = SetupTest(
directory=tmpdir,
options=((True,), (False,)),
create_reference=reference_cell,
write_reference=write_cell,
attribute="cell",
check_actual=check_actual,
)
generic_test(setup)

Expand All @@ -180,3 +199,9 @@ def reference_cell():
def write_cell(h5f, cell):
h5f["results/positions/scale"] = cell.scale
h5f["results/positions/lattice_vectors"] = cell.lattice_vectors


def check_actual(file, reference):
actual = file.cell()
assert actual == reference
assert isinstance(actual.scale, Number)

0 comments on commit 21219ad

Please sign in to comment.