diff --git a/src/py4vasp/data/__init__.py b/src/py4vasp/data/__init__.py index 7092e35..ea28757 100644 --- a/src/py4vasp/data/__init__.py +++ b/src/py4vasp/data/__init__.py @@ -1,5 +1,6 @@ from .band import Band from .dos import Dos +from .projectors import Projectors import plotly.io as pio import cufflinks as cf diff --git a/src/py4vasp/data/band.py b/src/py4vasp/data/band.py index 090a4d6..b9677c1 100644 --- a/src/py4vasp/data/band.py +++ b/src/py4vasp/data/band.py @@ -1,17 +1,11 @@ -import re import functools import itertools import numpy as np import plotly.graph_objects as go -from collections import namedtuple +from .projectors import Projectors class Band: - _Index = namedtuple("_Index", "spin, atom, orbital") - _Atom = namedtuple("_Atom", "indices, label") - _Orbital = namedtuple("_Orbital", "indices, label") - _Spin = namedtuple("_Spin", "indices, label") - def __init__(self, raw_band): self._raw = raw_band self._fermi_energy = raw_band.fermi_energy @@ -25,58 +19,14 @@ def __init__(self, raw_band): self._num_lines = len(self._kpoints) // self._line_length self._indices = raw_band.label_indices self._labels = raw_band.labels - self._has_projectors = raw_band.projectors is not None - if self._has_projectors: - self._init_projectors(raw_band.projectors) + if raw_band.projectors is not None: + self._projectors = Projectors(raw_band.projectors) + self._projections = raw_band.projections @classmethod def from_file(cls, file): return cls(file.band()) - def _init_projectors(self, raw_proj): - self._projections = raw_proj.bands - ion_types = raw_proj.ion_types - ion_types = [type.decode().strip() for type in ion_types] - self._init_atom_dict(ion_types, raw_proj.number_ion_types) - orbitals = raw_proj.orbital_types - orbitals = [orb.decode().strip() for orb in orbitals] - self._init_orbital_dict(orbitals) - self._init_spin_dict() - - def _init_atom_dict(self, ion_types, number_ion_types): - num_atoms = self._projections.shape[1] - all_atoms = self._Atom(indices=range(num_atoms), label=None) - self._atom_dict = {"*": all_atoms} - start = 0 - for type, number in zip(ion_types, number_ion_types): - _range = range(start, start + number) - self._atom_dict[type] = self._Atom(indices=_range, label=type) - for i in _range: - # create labels like Si_1, Si_2, Si_3 (starting at 1) - label = type + "_" + str(_range.index(i) + 1) - self._atom_dict[str(i + 1)] = self._Atom(indices=[i], label=label) - start += number - # atoms may be preceeded by : - for key in self._atom_dict.copy(): - self._atom_dict[key + ":"] = self._atom_dict[key] - - def _init_orbital_dict(self, orbitals): - num_orbitals = self._projections.shape[2] - all_orbitals = self._Orbital(indices=range(num_orbitals), label=None) - self._orbital_dict = {"*": all_orbitals} - for i, orbital in enumerate(orbitals): - self._orbital_dict[orbital] = self._Orbital(indices=[i], label=orbital) - if "px" in self._orbital_dict: - self._orbital_dict["p"] = self._Orbital(indices=range(1, 4), label="p") - self._orbital_dict["d"] = self._Orbital(indices=range(4, 9), label="d") - self._orbital_dict["f"] = self._Orbital(indices=range(9, 16), label="f") - - def _init_spin_dict(self): - labels = ["up", "down"] if self._spin_polarized else [None] - self._spin_dict = { - key: self._Spin(indices=[i], label=key) for i, key in enumerate(labels) - } - def read(self, selection=None): kpoints = self._kpoints[:] return { @@ -124,12 +74,6 @@ def _shift_bands_by_fermi_energy(self): else: return {"bands": self._bands[0] - self._fermi_energy} - def _read_projections(self, selection): - if selection is None: - return {} - parts = self._parse_selection(selection) - return self._read_elements(parts) - def _scatter(self, name, kdists, lines): # insert NaN to split separate lines num_bands = lines.shape[-1] @@ -146,24 +90,17 @@ def _kpoint_distances(self, kpoints): ) return functools.reduce(concatenate_distances, kpoint_norms) - def _parse_selection(self, selection): - atom = self._atom_dict["*"] - selection = re.sub("\s*:\s*", ": ", selection) - for part in re.split("[ ,]+", selection): - if part in self._orbital_dict: - orbital = self._orbital_dict[part] - else: - atom = self._atom_dict[part] - orbital = self._orbital_dict["*"] - if ":" not in part: # exclude ":" because it starts a new atom - for spin in self._spin_dict.values(): - yield atom, orbital, spin + def _read_projections(self, selection): + if selection is None: + return {} + return self._read_elements(selection) - def _read_elements(self, parts): + def _read_elements(self, selection): res = {} - for atom, orbital, spin in parts: + for select in self._projectors.parse_selection(selection): + atom, orbital, spin = self._projectors.select(*select) label = self._merge_labels([atom.label, orbital.label, spin.label]) - index = self._Index(spin.indices, atom.indices, orbital.indices) + index = (spin.indices, atom.indices, orbital.indices) res[label] = self._read_element(index) return res diff --git a/src/py4vasp/data/dos.py b/src/py4vasp/data/dos.py index b22399b..d2617eb 100644 --- a/src/py4vasp/data/dos.py +++ b/src/py4vasp/data/dos.py @@ -1,17 +1,11 @@ -import re import functools import itertools import numpy as np import pandas as pd -from collections import namedtuple +from .projectors import Projectors class Dos: - _Index = namedtuple("_Index", "spin, atom, orbital") - _Atom = namedtuple("_Atom", "indices, label") - _Orbital = namedtuple("_Orbital", "indices, label") - _Spin = namedtuple("_Spin", "indices, label") - def __init__(self, raw_dos): self._raw = raw_dos self._fermi_energy = raw_dos.fermi_energy @@ -20,56 +14,13 @@ def __init__(self, raw_dos): self._spin_polarized = self._dos.shape[0] == 2 self._has_partial_dos = raw_dos.projectors is not None if self._has_partial_dos: - self._init_partial_dos(raw_dos.projectors) + self._projectors = Projectors(raw_dos.projectors) + self._projections = raw_dos.projections @classmethod def from_file(cls, file): return cls(file.dos()) - def _init_partial_dos(self, raw_proj): - self._partial_dos = raw_proj.dos - ion_types = raw_proj.ion_types - ion_types = [type.decode().strip() for type in ion_types] - self._init_atom_dict(ion_types, raw_proj.number_ion_types) - orbitals = raw_proj.orbital_types - orbitals = [orb.decode().strip() for orb in orbitals] - self._init_orbital_dict(orbitals) - self._init_spin_dict() - - def _init_atom_dict(self, ion_types, number_ion_types): - num_atoms = self._partial_dos.shape[1] - all_atoms = self._Atom(indices=range(num_atoms), label=None) - self._atom_dict = {"*": all_atoms} - start = 0 - for type, number in zip(ion_types, number_ion_types): - _range = range(start, start + number) - self._atom_dict[type] = self._Atom(indices=_range, label=type) - for i in _range: - # create labels like Si_1, Si_2, Si_3 (starting at 1) - label = type + "_" + str(_range.index(i) + 1) - self._atom_dict[str(i + 1)] = self._Atom(indices=[i], label=label) - start += number - # atoms may be preceeded by : - for key in self._atom_dict.copy(): - self._atom_dict[key + ":"] = self._atom_dict[key] - - def _init_orbital_dict(self, orbitals): - num_orbitals = self._partial_dos.shape[2] - all_orbitals = self._Orbital(indices=range(num_orbitals), label=None) - self._orbital_dict = {"*": all_orbitals} - for i, orbital in enumerate(orbitals): - self._orbital_dict[orbital] = self._Orbital(indices=[i], label=orbital) - if "px" in self._orbital_dict: - self._orbital_dict["p"] = self._Orbital(indices=range(1, 4), label="p") - self._orbital_dict["d"] = self._Orbital(indices=range(4, 9), label="d") - self._orbital_dict["f"] = self._Orbital(indices=range(9, 16), label="f") - - def _init_spin_dict(self): - labels = ["up", "down"] if self._spin_polarized else [None] - self._spin_dict = { - key: self._Spin(indices=[i], label=key) for i, key in enumerate(labels) - } - def plot(self, selection=None): df = self.to_frame(selection) if self._spin_polarized: @@ -114,8 +65,7 @@ def _read_partial_dos(self, selection): if selection is None: return {} self._raise_error_if_partial_Dos_not_available() - parts = self._parse_filter(selection) - return self._read_elements(parts) + return self._read_elements(selection) def _raise_error_if_partial_Dos_not_available(self): if not self._has_partial_dos: @@ -123,24 +73,12 @@ def _raise_error_if_partial_Dos_not_available(self): "Filtering requires partial DOS which was not found in HDF5 file." ) - def _parse_filter(self, selection): - atom = self._atom_dict["*"] - selection = re.sub("\s*:\s*", ": ", selection) - for part in re.split("[ ,]+", selection): - if part in self._orbital_dict: - orbital = self._orbital_dict[part] - else: - atom = self._atom_dict[part] - orbital = self._orbital_dict["*"] - if ":" not in part: # exclude ":" because it starts a new atom - for spin in self._spin_dict.values(): - yield atom, orbital, spin - - def _read_elements(self, parts): + def _read_elements(self, selection): res = {} - for atom, orbital, spin in parts: + for select in self._projectors.parse_selection(selection): + atom, orbital, spin = self._projectors.select(*select) label = self._merge_labels([atom.label, orbital.label, spin.label]) - index = self._Index(spin.indices, atom.indices, orbital.indices) + index = (spin.indices, atom.indices, orbital.indices) res[label] = self._read_element(index) return res @@ -148,6 +86,6 @@ def _merge_labels(self, labels): return "_".join(filter(None, labels)) def _read_element(self, index): - sum_dos = lambda dos, i: dos + self._partial_dos[i] + sum_dos = lambda dos, i: dos + self._projections[i] zero_dos = np.zeros(len(self._energies)) return functools.reduce(sum_dos, itertools.product(*index), zero_dos) diff --git a/src/py4vasp/data/projectors.py b/src/py4vasp/data/projectors.py new file mode 100644 index 0000000..4a082e5 --- /dev/null +++ b/src/py4vasp/data/projectors.py @@ -0,0 +1,175 @@ +from __future__ import annotations +from typing import NamedTuple, Iterable, Union +from dataclasses import dataclass +import re +import numpy as np + + +_default = "*" +_spin_not_set = "not set" +_begin_spec = "(" +_end_spec = ")" +_seperators = (" ", ",") +_range_separator = "-" +_range = re.compile(r"^(\d+)" + re.escape(_range_separator) + "(\d+)$") +_whitespace_begin_spec = re.compile(r"\s*" + re.escape(_begin_spec) + r"\s*") +_whitespace_end_spec = re.compile(r"\s*" + re.escape(_end_spec) + r"\s*") +_whitespace_range = re.compile(r"\s*" + re.escape(_range_separator) + r"\s*") + + +@dataclass +class _State: + level: int = 0 + part: str = "" + specification: str = "" + complete: bool = False + + +def _split_into_parts(selection): + selection = _cleanup_whitespace(selection) + state = _State() + for char in selection + _seperators[0]: # make sure selection contains termination + state = _update_state(state, char) + if state.complete: + yield state.part, state.specification + + +def _cleanup_whitespace(selection): + selection = _whitespace_begin_spec.sub(_begin_spec, selection) + selection = _whitespace_end_spec.sub(_end_spec + _seperators[0], selection) + return _whitespace_range.sub(_range_separator, selection) + + +def _update_state(state, char): + state.level = _update_level(state, char) + state.part = _update_part(state, char) + state.specification = _update_specification(state, char) + state.complete = _is_state_complete(state, char) + return state + + +def _update_level(state, char): + return state.level + (char == _begin_spec) - (char == _end_spec) + + +def _update_part(state, char): + part = state.part if not state.complete else "" + char_used = char not in (_end_spec, *_seperators) and state.level == 0 + char = char if char_used else "" + return part + char + + +def _update_specification(state, char): + spec_used = not state.complete and (state.level != 1 or char != _begin_spec) + spec = state.specification if spec_used else "" + char_used = spec_used and state.level > 0 + char = char if char_used else "" + return spec + char + + +def _is_state_complete(state, char): + return state.level == 0 and char in _seperators and state.part != "" + + +class Projectors: + class Selection(NamedTuple): + indices: Iterable[int] + label: str = "" + + class Index(NamedTuple): + atom: Union[str, Selection] + orbital: Union[str, Selection] + spin: Union[str, Selection] + + def __init__(self, raw_proj): + self._init_atom_dict(raw_proj) + self._init_orbital_dict(raw_proj) + self._init_spin_dict(raw_proj) + self._spin_polarized = raw_proj.number_spins == 2 + + def _init_atom_dict(self, raw_proj): + num_atoms = np.sum(raw_proj.number_ion_types) + all_atoms = self.Selection(indices=range(num_atoms)) + self._atom_dict = {_default: all_atoms} + start = 0 + for type, number in zip(raw_proj.ion_types, raw_proj.number_ion_types): + type = str(type, "utf-8").strip() + _range = range(start, start + number) + self._atom_dict[type] = self.Selection(indices=_range, label=type) + for i in _range: + # create labels like Si_1, Si_2, Si_3 (starting at 1) + label = type + "_" + str(_range.index(i) + 1) + self._atom_dict[str(i + 1)] = self.Selection(indices=(i,), label=label) + start += number + + def _init_orbital_dict(self, raw_proj): + num_orbitals = len(raw_proj.orbital_types) + all_orbitals = self.Selection(indices=range(num_orbitals)) + self._orbital_dict = {_default: all_orbitals} + for i, orbital in enumerate(raw_proj.orbital_types): + orbital = str(orbital, "utf-8").strip() + self._orbital_dict[orbital] = self.Selection(indices=(i,), label=orbital) + if "px" in self._orbital_dict: + self._orbital_dict["p"] = self.Selection(indices=range(1, 4), label="p") + self._orbital_dict["d"] = self.Selection(indices=range(4, 9), label="d") + self._orbital_dict["f"] = self.Selection(indices=range(9, 16), label="f") + + def _init_spin_dict(self, raw_proj): + num_spins = raw_proj.number_spins + self._spin_dict = { + "up": self.Selection(indices=(0,), label="up"), + "down": self.Selection(indices=(1,), label="down"), + "total": self.Selection(indices=range(num_spins), label="total"), + _default: self.Selection(indices=range(num_spins)), + } + + def select(self, atom=_default, orbital=_default, spin=_default): + return self.Index( + atom=self._select_atom(atom), + orbital=self._orbital_dict[orbital], + spin=self._spin_dict[spin], + ) + + def _select_atom(self, atom): + match = _range.match(atom) + if match: + lower = self._atom_dict[match.groups()[0]].indices[0] + upper = self._atom_dict[match.groups()[1]].indices[0] + return self.Selection(indices=range(lower, upper + 1), label=atom) + else: + return self._atom_dict[atom] + + def parse_selection(self, selection): + default_index = self.Index(atom=_default, orbital=_default, spin=_spin_not_set) + yield from self._parse_recursive(selection, default_index) + + def _parse_recursive(self, selection, current_index): + for part, specification in _split_into_parts(selection): + new_index = self._update_index(current_index, part) + if specification == "": + yield from self._setup_spin_indices(new_index) + else: + yield from self._parse_recursive(specification, new_index) + + def _update_index(self, index, part): + part = part.strip() + if part == _default: + pass + elif part in self._atom_dict or _range.match(part): + index = index._replace(atom=part) + elif part in self._orbital_dict: + index = index._replace(orbital=part) + elif part in self._spin_dict: + index = index._replace(spin=part) + else: + raise KeyError("Could not find " + part + " in the list or projectors.") + return index + + def _setup_spin_indices(self, index): + if index.spin != _spin_not_set: + yield index + elif not self._spin_polarized: + yield index._replace(spin=_default) + else: + for key in ("up", "down"): + yield index._replace(spin=key) diff --git a/src/py4vasp/raw/file.py b/src/py4vasp/raw/file.py index 9156066..9f79a4f 100644 --- a/src/py4vasp/raw/file.py +++ b/src/py4vasp/raw/file.py @@ -12,6 +12,7 @@ def dos(self): energies=self._h5f["results/electron_dos/energies"], dos=self._h5f["results/electron_dos/dos"], projectors=self.projectors(), + projections=self._safe_get_key("results/electron_dos/dospar"), ) def band(self): @@ -24,6 +25,7 @@ def band(self): label_indices=self._safe_get_key("input/kpoints/positions_labels_kpoints"), cell=self.cell(), projectors=self.projectors(), + projections=self._safe_get_key("results/projectors/par"), ) def projectors(self): @@ -33,8 +35,7 @@ def projectors(self): ion_types=self._h5f["results/positions/ion_types"], number_ion_types=self._h5f["results/positions/number_ion_types"], orbital_types=self._h5f["results/projectors/lchar"], - dos=self._h5f["results/electron_dos/dospar"], - bands=self._h5f["results/projectors/par"], + number_spins=self._h5f["results/electron_eigenvalues/ispin"], ) def cell(self): diff --git a/src/py4vasp/raw/rawdata.py b/src/py4vasp/raw/rawdata.py index e26fecc..c3a27ce 100644 --- a/src/py4vasp/raw/rawdata.py +++ b/src/py4vasp/raw/rawdata.py @@ -24,8 +24,7 @@ class Projectors: number_ion_types: np.ndarray ion_types: np.ndarray orbital_types: np.ndarray - dos: np.ndarray - bands: np.ndarray + number_spins: int __eq__ = _dataclass_equal @@ -41,6 +40,7 @@ class Dos: fermi_energy: float energies: np.ndarray dos: np.ndarray + projections: np.ndarray = None projectors: Projectors = None __eq__ = _dataclass_equal @@ -54,5 +54,6 @@ class Band: cell: Cell labels: np.ndarray = np.empty(0, dtype="S") label_indices: np.ndarray = np.empty(0, dtype="int") + projections: np.ndarray = None projectors: Projectors = None __eq__ = _dataclass_equal diff --git a/tests/data/test_band.py b/tests/data/test_band.py index 6ac66ba..86adfc5 100644 --- a/tests/data/test_band.py +++ b/tests/data/test_band.py @@ -124,12 +124,12 @@ def spin_band_structure(): line_length=num_kpoints, kpoints=np.linspace(np.zeros(3), np.ones(3), num_kpoints), eigenvalues=np.arange(size_eig).reshape(shape_eig), + projections=np.random.uniform(low=0.2, size=shape_proj), projectors=raw.Projectors( number_ion_types=[1], ion_types=np.array(["Si"], dtype="S"), orbital_types=np.array(["s"], dtype="S"), - bands=np.random.uniform(low=0.2, size=shape_proj), - dos=None, + number_spins=num_spins, ), fermi_energy=0.0, cell=raw.Cell(scale=1, lattice_vectors=np.eye(3)), @@ -142,8 +142,8 @@ def test_spin_band_structure_read(spin_band_structure): band = Band(raw_band).read("s") assert_allclose(band["up"], raw_band.eigenvalues[0]) assert_allclose(band["down"], raw_band.eigenvalues[1]) - assert_allclose(band["projections"]["s_up"], raw_band.projectors.bands[0, 0]) - assert_allclose(band["projections"]["s_down"], raw_band.projectors.bands[1, 0]) + assert_allclose(band["projections"]["s_up"], raw_band.projections[0, 0]) + assert_allclose(band["projections"]["s_down"], raw_band.projections[1, 0]) def test_spin_band_structure_plot(spin_band_structure): @@ -155,7 +155,7 @@ def test_spin_band_structure_plot(spin_band_structure): for i, (spin, data) in enumerate(zip(spins, fig.data)): assert data.name == "Si_" + spin bands = np.nditer(raw_band.eigenvalues[i]) - weights = np.nditer(raw_band.projectors.bands[i, 0, 0]) + weights = np.nditer(raw_band.projections[i, 0, 0]) for band, weight in zip(bands, weights): upper = band + width * weight lower = band - width * weight @@ -174,12 +174,12 @@ def projected_band_structure(): line_length=num_kpoints, kpoints=np.linspace(np.zeros(3), np.ones(3), num_kpoints), eigenvalues=np.arange(np.prod(shape_eig)).reshape(shape_eig), + projections=np.random.uniform(low=0.2, size=shape_proj), projectors=raw.Projectors( number_ion_types=[1], ion_types=np.array(["Si"], dtype="S"), orbital_types=np.array(["s"], dtype="S"), - bands=np.random.uniform(low=0.2, size=shape_proj), - dos=None, + number_spins=num_spins, ), fermi_energy=0.0, cell=raw.Cell(scale=1, lattice_vectors=np.eye(3)), @@ -189,8 +189,8 @@ def projected_band_structure(): def test_projected_band_structure_read(projected_band_structure): raw_band = projected_band_structure - band = Band(raw_band).read("Si:s") - assert_allclose(band["projections"]["Si_s"], raw_band.projectors.bands[0, 0, 0]) + band = Band(raw_band).read("Si(s)") + assert_allclose(band["projections"]["Si_s"], raw_band.projections[0, 0, 0]) def test_projected_band_structure_plot(projected_band_structure): @@ -208,7 +208,7 @@ def test_projected_band_structure_plot(projected_band_structure): num_NaN_y = np.count_nonzero(np.isnan(data.y)) assert num_NaN_x == num_NaN_y > 0 bands = np.nditer(raw_band.eigenvalues[0]) - weights = np.nditer(raw_band.projectors.bands[0, 0, 0]) + weights = np.nditer(raw_band.projections[0, 0, 0]) for band, weight in zip(bands, weights): upper = band + default_width * weight lower = band - default_width * weight diff --git a/tests/data/test_dos.py b/tests/data/test_dos.py index f61f84f..c1b7029 100644 --- a/tests/data/test_dos.py +++ b/tests/data/test_dos.py @@ -99,7 +99,7 @@ def test_magnetic_Dos_plot(magnetic_Dos): @pytest.fixture -def nonmagnetic_projections(): +def nonmagnetic_projections(nonmagnetic_Dos): """ Setup a l resolved Dos containing all relevant quantities.""" ref = { "Si_s": np.random.random(num_energies), @@ -117,25 +117,25 @@ def nonmagnetic_projections(): number_ion_types=[1, 2], ion_types=np.array(["Si", "C "], dtype="S"), orbital_types=np.array([" s", " p", " d", " f"], dtype="S"), - dos=np.zeros((num_spins, len(atoms), lmax, num_energies)), - bands=None, + number_spins=num_spins, ) + nonmagnetic_Dos.projectors = raw_proj + nonmagnetic_Dos.projections = np.zeros((num_spins, len(atoms), lmax, num_energies)) orbitals = ["s", "p", "d"] for iatom, atom in enumerate(atoms): for l, orbital in enumerate(orbitals): key = atom + "_" + orbital if key in ref: - raw_proj.dos[:, iatom, l] = ref[key] - return raw_proj, ref + nonmagnetic_Dos.projections[:, iatom, l] = ref[key] + return nonmagnetic_Dos, ref -def test_nonmagnetic_l_Dos_to_frame(nonmagnetic_Dos, nonmagnetic_projections): +def test_nonmagnetic_l_Dos_to_frame(nonmagnetic_projections): """ Test whether reading the nonmagnetic l resolved Dos yields the expected results.""" - raw_dos = nonmagnetic_Dos - raw_dos.projectors, ref = nonmagnetic_projections + raw_dos, ref = nonmagnetic_projections equivalent_selections = [ - "s Si:d Si C:s,p 1:p 2 3:s", - "1: p, C : s Si : d, *: s, 2 Si:* C: p 3 : s", + "s Si(d) Si C(s,p) 1(p) 2 3(s)", + "1( p), C(s) Si(d), *(s), 2 Si(*) p(C) s(3)", ] for selection in equivalent_selections: dos = Dos(raw_dos).to_frame(selection) @@ -149,11 +149,10 @@ def test_nonmagnetic_l_Dos_to_frame(nonmagnetic_Dos, nonmagnetic_projections): assert_allclose(dos.C_2_s, ref["C2_s"]) -def test_nonmagnetic_l_Dos_plot(nonmagnetic_Dos, nonmagnetic_projections): +def test_nonmagnetic_l_Dos_plot(nonmagnetic_projections): """ Test whether plotting the nonmagnetic l resolved Dos yields the expected results.""" - raw_dos = nonmagnetic_Dos - raw_dos.projectors, ref = nonmagnetic_projections - selection = "p 3 Si:d" + raw_dos, ref = nonmagnetic_projections + selection = "p 3 Si(d)" fig = Dos(raw_dos).plot(selection) assert len(fig.data) == 4 # total Dos + 3 selections assert_allclose(fig.data[1].y, ref["Si_p"] + ref["C1_p"] + ref["C2_p"]) @@ -162,7 +161,7 @@ def test_nonmagnetic_l_Dos_plot(nonmagnetic_Dos, nonmagnetic_projections): @pytest.fixture -def magnetic_projections(): +def magnetic_projections(magnetic_Dos): """ Setup a lm resolved Dos containing all relevant quantities.""" num_spins = 2 lm_size = 16 @@ -174,22 +173,22 @@ def magnetic_projections(): number_ion_types=[1], ion_types=np.array(["Fe"], dtype="S"), orbital_types=np.array(orbitals, dtype="S"), - dos=np.zeros((num_spins, 1, lm_size, num_energies)), - bands=None, + number_spins=num_spins, ) + magnetic_Dos.projectors = raw_proj + magnetic_Dos.projections = np.zeros((num_spins, 1, lm_size, num_energies)) ref = {} for ispin, spin in enumerate(["up", "down"]): for lm, orbital in enumerate(orbitals): key = orbital.strip() + "_" + spin ref[key] = np.random.random(num_energies) - raw_proj.dos[ispin, :, lm] = ref[key] - return raw_proj, ref + magnetic_Dos.projections[ispin, :, lm] = ref[key] + return magnetic_Dos, ref -def test_magnetic_lm_Dos_read(magnetic_Dos, magnetic_projections): +def test_magnetic_lm_Dos_read(magnetic_projections): """ Test whether reading lm resolved Dos works as expected.""" - raw_dos = magnetic_Dos - raw_dos.projectors, ref = magnetic_projections + raw_dos, ref = magnetic_projections dos = Dos(raw_dos).read("px p d f") assert_allclose(dos["px_up"], ref["px_up"]) assert_allclose(dos["px_down"], ref["px_down"]) @@ -204,8 +203,7 @@ def test_magnetic_lm_Dos_read(magnetic_Dos, magnetic_projections): def test_magnetic_lm_Dos_plot(magnetic_Dos, magnetic_projections): """ Test whether plotting lm resolved Dos works as expected.""" - raw_dos = magnetic_Dos - raw_dos.projectors, ref = magnetic_projections + raw_dos, ref = magnetic_projections fig = Dos(raw_dos).plot("dxz p") data = fig.data assert len(data) == 6 # spin resolved total + 2 selections diff --git a/tests/data/test_projectors.py b/tests/data/test_projectors.py new file mode 100644 index 0000000..158fd2f --- /dev/null +++ b/tests/data/test_projectors.py @@ -0,0 +1,212 @@ +from py4vasp.data import Projectors +import py4vasp.raw as raw +import pytest +import numpy as np +from typing import NamedTuple, Iterable + +Selection = Projectors.Selection +Index = Projectors.Index + + +class SelectionTestCase(NamedTuple): + equivalent_formats: Iterable[str] + reference_selections: Iterable[Index] + + +@pytest.fixture +def without_spin(): + proj = raw.Projectors( + number_ion_types=np.array((2, 1, 4)), + ion_types=np.array(("Sr", "Ti", "O "), dtype="S"), + orbital_types=np.array( + (" s", "py", "pz", "px", "dxy", "dyz", "dz2", "dxz", "x2-y2") + + ("fy3x2", "fxyz", "fyz2", "fz3", "fxz2", "fzx2", "fx3"), + dtype="S", + ), + number_spins=1, + ) + return proj + + +@pytest.fixture +def spin_polarized(without_spin): + without_spin.number_spins = 2 + return without_spin + + +@pytest.fixture +def for_selection(spin_polarized): + index = np.cumsum(spin_polarized.number_ion_types) + ref = { + "atom": { + "Sr": Selection(indices=range(index[0]), label="Sr"), + "Ti": Selection(indices=range(index[0], index[1]), label="Ti"), + "O": Selection(indices=range(index[1], index[2]), label="O"), + "1": Selection(indices=(0,), label="Sr_1"), + "2": Selection(indices=(1,), label="Sr_2"), + "3": Selection(indices=(2,), label="Ti_1"), + "4": Selection(indices=(3,), label="O_1"), + "5": Selection(indices=(4,), label="O_2"), + "6": Selection(indices=(5,), label="O_3"), + "7": Selection(indices=(6,), label="O_4"), + "1-3": Selection(indices=range(0, 3), label="1-3"), + "4-7": Selection(indices=range(3, 7), label="4-7"), + "*": Selection(indices=range(index[-1])), + }, + "orbital": { + "s": Selection(indices=(0,), label="s"), + "px": Selection(indices=(3,), label="px"), + "py": Selection(indices=(1,), label="py"), + "pz": Selection(indices=(2,), label="pz"), + "dxy": Selection(indices=(4,), label="dxy"), + "dxz": Selection(indices=(7,), label="dxz"), + "dyz": Selection(indices=(5,), label="dyz"), + "dz2": Selection(indices=(6,), label="dz2"), + "x2-y2": Selection(indices=(8,), label="x2-y2"), + "fxyz": Selection(indices=(10,), label="fxyz"), + "fxz2": Selection(indices=(13,), label="fxz2"), + "fx3": Selection(indices=(15,), label="fx3"), + "fyz2": Selection(indices=(11,), label="fyz2"), + "fy3x2": Selection(indices=(9,), label="fy3x2"), + "fzx2": Selection(indices=(14,), label="fzx2"), + "fz3": Selection(indices=(12,), label="fz3"), + "p": Selection(indices=range(1, 4), label="p"), + "d": Selection(indices=range(4, 9), label="d"), + "f": Selection(indices=range(9, 16), label="f"), + "*": Selection(indices=range(len(spin_polarized.orbital_types))), + }, + "spin": { + "up": Selection(indices=(0,), label="up"), + "down": Selection(indices=(1,), label="down"), + "total": Selection( + indices=range(spin_polarized.number_spins), label="total" + ), + "*": Selection(indices=range(spin_polarized.number_spins)), + }, + } + return Projectors(spin_polarized), ref + + +def test_selection(for_selection): + proj, ref = for_selection + default = Index(ref["atom"]["*"], ref["orbital"]["*"], ref["spin"]["*"]) + for atom, ref_atom in ref["atom"].items(): + assert proj.select(atom=atom) == default._replace(atom=ref_atom) + for orbital, ref_orbital in ref["orbital"].items(): + assert proj.select(orbital=orbital) == default._replace(orbital=ref_orbital) + for spin, ref_spin in ref["spin"].items(): + assert proj.select(spin=spin) == default._replace(spin=ref_spin) + + +@pytest.fixture +def for_parse_selection(without_spin): + testcases = ( + SelectionTestCase( + equivalent_formats=("Sr", "Sr(*)"), + reference_selections=(Index(atom="Sr", orbital="*", spin="*"),), + ), + SelectionTestCase( + equivalent_formats=("Ti(s,p)", "Ti (s) Ti (p)"), + reference_selections=( + Index(atom="Ti", orbital="s", spin="*"), + Index(atom="Ti", orbital="p", spin="*"), + ), + ), + SelectionTestCase( + equivalent_formats=("Ti 5", "Ti( * ), 5( * )"), + reference_selections=( + Index(atom="Ti", orbital="*", spin="*"), + Index(atom="5", orbital="*", spin="*"), + ), + ), + SelectionTestCase( + equivalent_formats=("p, d", "*(p) *(d)"), + reference_selections=( + Index(atom="*", orbital="p", spin="*"), + Index(atom="*", orbital="d", spin="*"), + ), + ), + SelectionTestCase( + equivalent_formats=("O(d), 1 s", "O(d), 1(*), *(s)"), + reference_selections=( + Index(atom="O", orbital="d", spin="*"), + Index(atom="1", orbital="*", spin="*"), + Index(atom="*", orbital="s", spin="*"), + ), + ), + SelectionTestCase( + equivalent_formats=("Sr(p)Ti(s)O(s)", "p(Sr) s(Ti, O)"), + reference_selections=( + Index(atom="Sr", orbital="p", spin="*"), + Index(atom="Ti", orbital="s", spin="*"), + Index(atom="O", orbital="s", spin="*"), + ), + ), + SelectionTestCase( + equivalent_formats=("1 - 4", "1-4", " 1 - 4 "), + reference_selections=(Index(atom="1-4", orbital="*", spin="*"),), + ), + ) + return Projectors(without_spin), testcases + + +@pytest.fixture +def for_spin_polarized_parse_selection(spin_polarized): + testcases = ( + SelectionTestCase( + equivalent_formats=("Sr", "Sr(up,down)", "Sr(*(up)), Sr(down)"), + reference_selections=( + Index(atom="Sr", orbital="*", spin="up"), + Index(atom="Sr", orbital="*", spin="down"), + ), + ), + SelectionTestCase( + equivalent_formats=("Ti( s(up) p(down) )", "Ti(s(up))Ti(p(down))"), + reference_selections=( + Index(atom="Ti", orbital="s", spin="up"), + Index(atom="Ti", orbital="p", spin="down"), + ), + ), + SelectionTestCase( + equivalent_formats=("s p(up) d(total)", "s(up, down), *(p(up)), d(total)"), + reference_selections=( + Index(atom="*", orbital="s", spin="up"), + Index(atom="*", orbital="s", spin="down"), + Index(atom="*", orbital="p", spin="up"), + Index(atom="*", orbital="d", spin="total"), + ), + ), + SelectionTestCase( + equivalent_formats=("up (s) down (p, d)", "s(up) p(down) d(down)"), + reference_selections=( + Index(atom="*", orbital="s", spin="up"), + Index(atom="*", orbital="p", spin="down"), + Index(atom="*", orbital="d", spin="down"), + ), + ), + SelectionTestCase( + equivalent_formats=("2( px(up) )", "px(2(up))", "up(2(px))"), + reference_selections=(Index(atom="2", orbital="px", spin="up"),), + ), + SelectionTestCase( + equivalent_formats=("3-4(up)", "up (3 - 4)"), + reference_selections=(Index(atom="3-4", orbital="*", spin="up"),), + ), + ) + return Projectors(spin_polarized), testcases + + +def test_parse_selection(for_parse_selection): + run_parse_selection(for_parse_selection) + + +def test_spin_polarized_parse_selection(for_spin_polarized_parse_selection): + run_parse_selection(for_spin_polarized_parse_selection) + + +def run_parse_selection(setup): + proj, testcases = setup + for testcase in testcases: + for format in testcase.equivalent_formats: + selections = proj.parse_selection(format) + assert list(selections) == list(testcase.reference_selections) diff --git a/tests/raw/test_file.py b/tests/raw/test_file.py index a8c6508..7b3f398 100644 --- a/tests/raw/test_file.py +++ b/tests/raw/test_file.py @@ -13,6 +13,8 @@ num_energies = 20 num_kpoints = 10 num_bands = 3 +num_atoms = 10 # sum(range(5)) +lmax = 3 fermi_energy = 0.123 SetupTest = namedtuple( @@ -80,8 +82,10 @@ def write_dos(h5f, dos): h5f["results/electron_dos/efermi"] = dos.fermi_energy h5f["results/electron_dos/energies"] = dos.energies h5f["results/electron_dos/dos"] = dos.dos - if dos.projectors: + if dos.projectors is not None: write_projectors(h5f, dos.projectors) + if dos.projections is not None: + h5f["results/electron_dos/dospar"] = proj.dos def test_band(tmpdir): @@ -96,17 +100,21 @@ def test_band(tmpdir): def reference_band(use_projectors, use_labels): - shape = (num_spins, num_kpoints, num_bands) - return raw.Band( + shape_eval = (num_spins, num_kpoints, num_bands) + shape_proj = (num_spins, num_atoms, lmax, num_kpoints, num_bands) + band = raw.Band( fermi_energy=fermi_energy, line_length=num_kpoints, kpoints=np.linspace(np.zeros(3), np.ones(3), num_kpoints), - eigenvalues=np.arange(np.prod(shape)).reshape(shape), + eigenvalues=np.arange(np.prod(shape_eval)).reshape(shape_eval), cell=reference_cell(), labels=np.array(["G", "X"], dtype="S") if use_labels else None, label_indices=[0, 1] if use_labels else None, - projectors=reference_projectors() if use_projectors else None, ) + if use_projectors: + band.projectors = reference_projectors() + band.projections = np.arange(np.prod(shape_proj)).reshape(shape_proj) + return band def write_band(h5f, band): @@ -121,6 +129,8 @@ def write_band(h5f, band): h5f["input/kpoints/labels_kpoints"] = band.labels if band.projectors: write_projectors(h5f, band.projectors) + if band.projections is not None: + h5f["results/projectors/par"] = band.projections def test_projectors(tmpdir): @@ -135,16 +145,12 @@ def test_projectors(tmpdir): def reference_projectors(): - num_atoms = 10 # sum(range(5)) - lmax = 3 shape_dos = (num_spins, num_atoms, lmax, num_energies) - shape_bands = (num_spins, num_atoms, lmax, num_kpoints, num_bands) return raw.Projectors( number_ion_types=np.arange(5), ion_types=np.array(["B", "C", "N", "O", "F"], dtype="S"), orbital_types=np.array(["s", "p", "d", "f"], dtype="S"), - dos=np.arange(np.prod(shape_dos)).reshape(shape_dos), - bands=np.arange(np.prod(shape_bands)).reshape(shape_bands), + number_spins=num_spins, ) @@ -152,8 +158,7 @@ def write_projectors(h5f, proj): h5f["results/positions/number_ion_types"] = proj.number_ion_types h5f["results/positions/ion_types"] = proj.ion_types h5f["results/projectors/lchar"] = proj.orbital_types - h5f["results/electron_dos/dospar"] = proj.dos - h5f["results/projectors/par"] = proj.bands + h5f["results/electron_eigenvalues/ispin"] = proj.number_spins def test_cell(tmpdir):