Skip to content

Commit

Permalink
Fix for issue orest-d#18 (orest-d#26)
Browse files Browse the repository at this point in the history
We create new Kpoints classes to read and refine the data from the HDF5 file. This removes some logic from the Band class that more logically belongs to the Kpoints class. In the process, we change the logic such that different input modes for different KPOINTS files do not crash the Band class.
  • Loading branch information
martin-schlipf authored Feb 20, 2020
1 parent e5789dc commit b5647af
Show file tree
Hide file tree
Showing 11 changed files with 342 additions and 106 deletions.
1 change: 1 addition & 0 deletions src/py4vasp/data/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from .band import Band
from .convergence import Convergence
from .dos import Dos
from .kpoints import Kpoints
from .projectors import Projectors

import plotly.io as pio
Expand Down
7 changes: 7 additions & 0 deletions src/py4vasp/data/_util.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,3 +10,10 @@ def from_file(cls, file, attr):
context = nullcontext(file)
with context as file:
yield cls(getattr(file, attr)())


def decode_if_possible(string):
try:
return string.decode()
except (UnicodeDecodeError, AttributeError):
return string
97 changes: 34 additions & 63 deletions src/py4vasp/data/band.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,46 +3,34 @@
import numpy as np
import plotly.graph_objects as go
from .projectors import Projectors
from .kpoints import Kpoints
from py4vasp.data import _util


class Band:
def __init__(self, raw_band):
self._raw = raw_band
self._fermi_energy = raw_band.fermi_energy
self._kpoints = raw_band.kpoints
self._kdists = None
self._bands = raw_band.eigenvalues
self._spin_polarized = len(self._bands) == 2
scale = raw_band.cell.scale
lattice_vectors = raw_band.cell.lattice_vectors
self._cell = scale * lattice_vectors
self._line_length = raw_band.line_length
self._num_lines = len(self._kpoints) // self._line_length
self._indices = raw_band.label_indices
self._labels = raw_band.labels
self._kpoints = Kpoints(raw_band.kpoints)
self._spin_polarized = len(raw_band.eigenvalues) == 2
if raw_band.projectors is not None:
self._projectors = Projectors(raw_band.projectors)
self._projections = raw_band.projections

@classmethod
def from_file(cls, file=None):
return _util.from_file(cls, file, "band")

def read(self, selection=None):
res = {
"kpoints": self._kpoints[:],
"kpoint_labels": self._kpoint_labels(),
"fermi_energy": self._fermi_energy,
"kpoint_distances": self._kpoints.distances(),
"kpoint_labels": self._kpoints.labels(),
"fermi_energy": self._raw.fermi_energy,
**self._shift_bands_by_fermi_energy(),
"projections": self._read_projections(selection),
}
res["kpoint_distances"] = self._kpoint_distances(res["kpoints"])
return res

def plot(self, selection=None, width=0.5):
ticks = self._ticks()
labels = self._ticklabels()
ticks, labels = self._ticks_and_labels()
data = self._band_structure(selection, width)
default = {
"xaxis": {"tickmode": "array", "tickvals": ticks, "ticktext": labels},
Expand All @@ -53,25 +41,11 @@ def plot(self, selection=None, width=0.5):
def _shift_bands_by_fermi_energy(self):
if self._spin_polarized:
return {
"up": self._bands[0] - self._fermi_energy,
"down": self._bands[1] - self._fermi_energy,
"up": self._raw.eigenvalues[0] - self._raw.fermi_energy,
"down": self._raw.eigenvalues[1] - self._raw.fermi_energy,
}
else:
return {"bands": self._bands[0] - self._fermi_energy}

def _kpoint_distances(self, kpoints=None):
if self._kdists is not None:
return self._kdists
if kpoints is None:
kpoints = self._kpoints[:]
cartesian_kpoints = np.linalg.solve(self._cell, kpoints.T).T
kpoint_lines = np.split(cartesian_kpoints, self._num_lines)
kpoint_norms = [np.linalg.norm(line - line[0], axis=1) for line in kpoint_lines]
concatenate_distances = lambda current, addition: (
np.concatenate((current, addition + current[-1]))
)
self._kdists = functools.reduce(concatenate_distances, kpoint_norms)
return self._kdists
return {"bands": self._raw.eigenvalues[0] - self._raw.fermi_energy}

def _band_structure(self, selection, width):
bands = self._shift_bands_by_fermi_energy()
Expand All @@ -82,7 +56,7 @@ def _band_structure(self, selection, width):
return self._fat_band_structure(bands, projections, width)

def _regular_band_structure(self, bands):
kdists = self._kpoint_distances()
kdists = self._kpoints.distances()
return [self._scatter(name, kdists, lines) for name, lines in bands.items()]

def _fat_band_structure(self, bands, projections, width):
Expand All @@ -96,7 +70,7 @@ def _fat_band(self, args, width):
(key, lines), (name, projection) = args
if self._spin_polarized and not key in name:
return None
kdists = self._kpoint_distances()
kdists = self._kpoints.distances()
fatband_kdists = np.concatenate((kdists, np.flip(kdists)))
upper = lines + width * projection
lower = lines - width * projection
Expand Down Expand Up @@ -131,30 +105,27 @@ def _merge_labels(self, labels):
return "_".join(filter(None, labels))

def _read_element(self, index):
sum_weight = lambda weight, i: weight + self._projections[i]
zero_weight = np.zeros(self._bands.shape[1:])
sum_weight = lambda weight, i: weight + self._raw.projections[i]
zero_weight = np.zeros(self._raw.eigenvalues.shape[1:])
return functools.reduce(sum_weight, itertools.product(*index), zero_weight)

def _kpoint_labels(self):
if self._indices is None or self._labels is None:
return None
# convert from input kpoint list to full list
labels = np.zeros(len(self._kpoints), dtype=self._labels.dtype)
indices = np.array(self._indices)
indices = self._line_length * (indices // 2) + indices % 2 - 1
labels[indices] = self._labels
return [l.decode().strip() for l in labels]

def _ticks(self):
kdists = self._kpoint_distances()
return [*kdists[:: self._line_length], kdists[-1]]

def _ticklabels(self):
labels = [" "] * (self._num_lines + 1)
if self._indices is None or self._labels is None:
return labels
for index, label in zip(self._indices, self._labels):
i = index // 2 # line has 2 ends
label = label.decode().strip()
labels[i] = (labels[i] + "|" + label) if labels[i].strip() else label
return labels
def _ticks_and_labels(self):
labels = self._kpoints.labels()
if labels is None:
return None, None
labels = np.array(labels)
indices = np.arange(len(self._raw.kpoints.coordinates))
line_length = self._kpoints.line_length()
edge_of_line = (indices + 1) % line_length == 0
edge_of_line[0] = True
mask = np.logical_or(edge_of_line, labels != "")
masked_dists = self._kpoints.distances()[mask]
masked_labels = labels[mask]
ticks, indices = np.unique(masked_dists, return_inverse=True)
labels = [""] * len(ticks)
for i, label in zip(indices, masked_labels):
if labels[i].strip():
labels[i] = labels[i] + "|" + label
else:
labels[i] = label or " "
return ticks, labels
72 changes: 72 additions & 0 deletions src/py4vasp/data/kpoints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
from py4vasp.data import _util
from py4vasp.exceptions import RefinementException
import functools
import numpy as np


class Kpoints:
def __init__(self, raw_kpoints):
self._raw = raw_kpoints
self._distances = None

def read(self):
return {
"mode": self.mode(),
"line_length": self.line_length(),
"coordinates": self._raw.coordinates[:],
"weights": self._raw.weights[:],
"labels": self.labels(),
}

def line_length(self):
if self.mode() == "line":
return self._raw.number
return len(self._raw.coordinates)

def number_lines(self):
return len(self._raw.coordinates) // self.line_length()

def distances(self):
if self._distances is not None:
return self._distances
cell = self._raw.cell.lattice_vectors * self._raw.cell.scale
cartesian_kpoints = np.linalg.solve(cell, self._raw.coordinates[:].T).T
kpoint_lines = np.split(cartesian_kpoints, self.number_lines())
kpoint_norms = [np.linalg.norm(line - line[0], axis=1) for line in kpoint_lines]
concatenate_distances = lambda current, addition: (
np.concatenate((current, addition + current[-1]))
)
self._distances = functools.reduce(concatenate_distances, kpoint_norms)
return self._distances

def mode(self):
mode = _util.decode_if_possible(self._raw.mode).strip() or "# empty string"
first_char = mode[0].lower()
if first_char == "a":
return "automatic"
elif first_char == "e":
return "explicit"
elif first_char == "g":
return "gamma"
elif first_char == "l":
return "line"
elif first_char == "m":
return "monkhorst"
else:
raise RefinementException(
"Could not understand the mode '{}' ".format(mode)
+ "when refining the raw kpoints data."
)

def labels(self):
if self._raw.labels is None or self._raw.label_indices is None:
return None
labels = [""] * len(self._raw.coordinates)
use_line_mode = self.mode() == "line"
for label, index in zip(self._raw.labels, self._raw.label_indices):
label = _util.decode_if_possible(label.strip())
if use_line_mode:
index = self.line_length() * (index // 2) + index % 2
index -= 1 # convert from Fortran to Python
labels[index] = label
return labels
10 changes: 10 additions & 0 deletions src/py4vasp/exceptions/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
from .exceptions import *


class Py4VaspException(Exception):
"""Base class for all exceptions raised by py4vasp"""


class RefinementException(Py4VaspException):
"""When refining the raw dataclass into the class handling e.g. reading and
plotting of the data an error occured"""
Empty file.
18 changes: 13 additions & 5 deletions src/py4vasp/raw/file.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,8 @@ def band(self):
self._assert_not_closed()
return raw.Band(
fermi_energy=self._h5f["results/electron_dos/efermi"][()],
line_length=self._h5f["input/kpoints/number_kpoints"][()],
kpoints=self._h5f["results/electron_eigenvalues/kpoint_coords"],
kpoints=self.kpoints(),
eigenvalues=self._h5f["results/electron_eigenvalues/eigenvalues"],
labels=self._safe_get_key("input/kpoints/labels_kpoints"),
label_indices=self._safe_get_key("input/kpoints/positions_labels_kpoints"),
cell=self.cell(),
projectors=self.projectors(),
projections=self._safe_get_key("results/projectors/par"),
)
Expand All @@ -46,6 +42,18 @@ def projectors(self):
number_spins=self._h5f["results/electron_eigenvalues/ispin"][()],
)

def kpoints(self):
self._assert_not_closed()
return raw.Kpoints(
mode=self._h5f["input/kpoints/mode"][()],
number=self._h5f["input/kpoints/number_kpoints"][()],
coordinates=self._h5f["results/electron_eigenvalues/kpoint_coords"],
weights=self._h5f["results/electron_eigenvalues/kpoints_symmetry_weight"],
labels=self._safe_get_key("input/kpoints/labels_kpoints"),
label_indices=self._safe_get_key("input/kpoints/positions_labels_kpoints"),
cell=self.cell(),
)

def cell(self):
self._assert_not_closed()
return raw.Cell(
Expand Down
18 changes: 13 additions & 5 deletions src/py4vasp/raw/rawdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,18 @@ class Cell:
__eq__ = _dataclass_equal


@dataclass
class Kpoints:
mode: str
number: int
coordinates: np.ndarray
weights: np.ndarray
cell: Cell
labels: np.ndarray = None
label_indices: np.ndarray = None
__eq__ = _dataclass_equal


@dataclass
class Dos:
fermi_energy: float
Expand All @@ -48,12 +60,8 @@ class Dos:
@dataclass
class Band:
fermi_energy: float
line_length: int
kpoints: np.ndarray
kpoints: Kpoints
eigenvalues: np.ndarray
cell: Cell
labels: np.ndarray = None
label_indices: np.ndarray = None
projections: np.ndarray = None
projectors: Projectors = None
__eq__ = _dataclass_equal
Expand Down
Loading

0 comments on commit b5647af

Please sign in to comment.