Skip to content

Commit

Permalink
Merge branch '14-missing-kpoint-labels-not-working' into 'master'
Browse files Browse the repository at this point in the history
Resolve "Missing kpoint labels not working"

Closes orest-d#14

See merge request schlipf/py4vasp!12
  • Loading branch information
martin-schlipf committed Feb 3, 2020
2 parents db9d2ea + b7855c9 commit 0887c76
Show file tree
Hide file tree
Showing 3 changed files with 67 additions and 38 deletions.
94 changes: 61 additions & 33 deletions src/py4vasp/data/band.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ def __init__(self, raw_band):
self._raw = raw_band
self._fermi_energy = raw_band.fermi_energy
self._kpoints = raw_band.kpoints
self._kdists = None
self._bands = raw_band.eigenvalues
self._spin_polarized = len(self._bands) == 2
scale = raw_band.cell.scale
Expand All @@ -28,37 +29,20 @@ def from_file(cls, file):
return cls(file.band())

def read(self, selection=None):
kpoints = self._kpoints[:]
return {
"kpoints": kpoints,
"kpoint_distances": self._kpoint_distances(kpoints),
res = {
"kpoints": self._kpoints[:],
"kpoint_labels": self._kpoint_labels(),
"fermi_energy": self._fermi_energy,
**self._shift_bands_by_fermi_energy(),
"projections": self._read_projections(selection),
}
res["kpoint_distances"] = self._kpoint_distances(res["kpoints"])
return res

def plot(self, selection=None, width=0.5):
kdists = self._kpoint_distances(self._kpoints[:])
fatband_kdists = np.concatenate((kdists, np.flip(kdists)))
bands = self._shift_bands_by_fermi_energy()
projections = self._read_projections(selection)
ticks = [*kdists[:: self._line_length], kdists[-1]]
ticks = self._ticks()
labels = self._ticklabels()
data = []
for key, lines in bands.items():
if len(projections) == 0:
data.append(self._scatter(key, kdists, lines))
for name, proj in projections.items():
if self._spin_polarized and not key in name:
continue
upper = lines + width * proj
lower = lines - width * proj
fatband_lines = np.concatenate((lower, np.flip(upper, axis=0)), axis=0)
plot = self._scatter(name, fatband_kdists, fatband_lines)
plot.fill = "toself"
plot.mode = "none"
data.append(plot)
data = self._band_structure(selection, width)
default = {
"xaxis": {"tickmode": "array", "tickvals": ticks, "ticktext": labels},
"yaxis": {"title": {"text": "Energy (eV)"}},
Expand All @@ -74,21 +58,59 @@ def _shift_bands_by_fermi_energy(self):
else:
return {"bands": self._bands[0] - self._fermi_energy}

def _scatter(self, name, kdists, lines):
# insert NaN to split separate lines
num_bands = lines.shape[-1]
kdists = np.tile([*kdists, np.NaN], num_bands)
lines = np.append(lines, [np.repeat(np.NaN, num_bands)], axis=0)
return go.Scatter(x=kdists, y=lines.flatten(order="F"), name=name)

def _kpoint_distances(self, kpoints):
def _kpoint_distances(self, kpoints=None):
if self._kdists is not None:
return self._kdists
if kpoints is None:
kpoints = self._kpoints
cartesian_kpoints = np.linalg.solve(self._cell, kpoints.T).T
kpoint_lines = np.split(cartesian_kpoints, self._num_lines)
kpoint_norms = [np.linalg.norm(line - line[0], axis=1) for line in kpoint_lines]
concatenate_distances = lambda current, addition: (
np.concatenate((current, addition + current[-1]))
)
return functools.reduce(concatenate_distances, kpoint_norms)
self._kdists = functools.reduce(concatenate_distances, kpoint_norms)
return self._kdists

def _band_structure(self, selection, width):
bands = self._shift_bands_by_fermi_energy()
projections = self._read_projections(selection)
if len(projections) == 0:
return self._regular_band_structure(bands)
else:
return self._fat_band_structure(bands, projections, width)

def _regular_band_structure(self, bands):
kdists = self._kpoint_distances()
return [self._scatter(name, kdists, lines) for name, lines in bands.items()]

def _fat_band_structure(self, bands, projections, width):
data = (
self._fat_band(args, width)
for args in itertools.product(bands.items(), projections.items())
)
return list(filter(None, data))

def _fat_band(self, args, width):
(key, lines), (name, projection) = args
if self._spin_polarized and not key in name:
return None
kdists = self._kpoint_distances()
fatband_kdists = np.concatenate((kdists, np.flip(kdists)))
upper = lines + width * projection
lower = lines - width * projection
fatband_lines = np.concatenate((lower, np.flip(upper, axis=0)), axis=0)
plot = self._scatter(name, fatband_kdists, fatband_lines)
plot.fill = "toself"
plot.mode = "none"
return plot

def _scatter(self, name, kdists, lines):
# insert NaN to split separate lines
num_bands = lines.shape[-1]
kdists = np.tile([*kdists, np.NaN], num_bands)
lines = np.append(lines, [np.repeat(np.NaN, num_bands)], axis=0)
return go.Scatter(x=kdists, y=lines.flatten(order="F"), name=name)

def _read_projections(self, selection):
if selection is None:
Expand All @@ -113,7 +135,7 @@ def _read_element(self, index):
return functools.reduce(sum_weight, itertools.product(*index), zero_weight)

def _kpoint_labels(self):
if len(self._labels) == 0:
if self._indices is None or self._labels is None:
return None
# convert from input kpoint list to full list
labels = np.zeros(len(self._kpoints), dtype=self._labels.dtype)
Expand All @@ -122,8 +144,14 @@ def _kpoint_labels(self):
labels[indices] = self._labels
return [l.decode().strip() for l in labels]

def _ticks(self):
kdists = self._kpoint_distances()
return [*kdists[:: self._line_length], kdists[-1]]

def _ticklabels(self):
labels = [" "] * (self._num_lines + 1)
if self._indices is None or self._labels is None:
return labels
for index, label in zip(self._indices, self._labels):
i = index // 2 # line has 2 ends
label = label.decode().strip()
Expand Down
4 changes: 2 additions & 2 deletions src/py4vasp/raw/rawdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,8 @@ class Band:
kpoints: np.ndarray
eigenvalues: np.ndarray
cell: Cell
labels: np.ndarray = np.empty(0, dtype="S")
label_indices: np.ndarray = np.empty(0, dtype="int")
labels: np.ndarray = None
label_indices: np.ndarray = None
projections: np.ndarray = None
projectors: Projectors = None
__eq__ = _dataclass_equal
7 changes: 4 additions & 3 deletions tests/raw/test_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,9 +108,10 @@ def reference_band(use_projectors, use_labels):
kpoints=np.linspace(np.zeros(3), np.ones(3), num_kpoints),
eigenvalues=np.arange(np.prod(shape_eval)).reshape(shape_eval),
cell=reference_cell(),
labels=np.array(["G", "X"], dtype="S") if use_labels else None,
label_indices=[0, 1] if use_labels else None,
)
if use_labels:
band.labels = np.array(["G", "X"], dtype="S")
band.label_indices = [0, 1]
if use_projectors:
band.projectors = reference_projectors()
band.projections = np.arange(np.prod(shape_proj)).reshape(shape_proj)
Expand All @@ -127,7 +128,7 @@ def write_band(h5f, band):
h5f["input/kpoints/positions_labels_kpoints"] = band.label_indices
if band.labels is not None:
h5f["input/kpoints/labels_kpoints"] = band.labels
if band.projectors:
if band.projectors is not None:
write_projectors(h5f, band.projectors)
if band.projections is not None:
h5f["results/projectors/par"] = band.projections
Expand Down

0 comments on commit 0887c76

Please sign in to comment.