Skip to content

Commit

Permalink
Update closest_site_info to be robust for very very weird structure…
Browse files Browse the repository at this point in the history
… inputs (e.g. mp-674158, mp-1208561)
  • Loading branch information
kavanase committed Nov 21, 2024
1 parent 1a28d56 commit 6041506
Show file tree
Hide file tree
Showing 12 changed files with 111 additions and 76 deletions.
2 changes: 1 addition & 1 deletion doped/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -2175,7 +2175,7 @@ def __eq__(self, other) -> bool:
loose ``stol`` used in ``pymatgen-analysis-defects``) and
much more efficient.
"""
if not isinstance(other, (self, core.Defect)):
if not isinstance(other, (type(self), core.Defect)):
raise TypeError("Can only compare `Defect`s with `Defect`s!")

if self.defect_type != other.defect_type:
Expand Down
143 changes: 86 additions & 57 deletions doped/generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
calculations.
"""

import contextlib
import copy
import logging
import operator
Expand All @@ -28,7 +29,7 @@
from pymatgen.core import IStructure, Structure
from pymatgen.core.composition import Composition, Element
from pymatgen.core.periodic_table import DummySpecies
from pymatgen.core.structure import PeriodicSite
from pymatgen.core.structure import PeriodicNeighbor, PeriodicSite
from pymatgen.entries.computed_entries import ComputedStructureEntry
from pymatgen.transformations.advanced_transformations import CubicSupercellTransformation
from pymatgen.util.typing import PathLike
Expand Down Expand Up @@ -199,31 +200,25 @@ def closest_site_info(
be at least 0.02 Å further away than the n-1th site.
"""
if isinstance(defect_entry_or_defect, (DefectEntry, thermo.DefectEntry)):
defect = defect_entry_or_defect.defect
# use defect_supercell_site if attribute exists, otherwise use sc_defect_frac_coords:
defect_supercell_site = parsing._get_defect_supercell_site(defect_entry_or_defect)
defect_supercell = parsing._get_defect_supercell(defect_entry_or_defect)
site = None
with contextlib.suppress(Exception):
defect = defect_entry_or_defect.defect
site = defect.site
structure = defect.structure
if site is None:
# use defect_supercell_site if attribute exists, otherwise use sc_defect_frac_coords:
site = parsing._get_defect_supercell_site(defect_entry_or_defect)
structure = parsing._get_bulk_supercell(defect_entry_or_defect)

elif isinstance(defect_entry_or_defect, (Defect, core.Defect)):
if isinstance(defect_entry_or_defect, core.Defect):
defect = doped_defect_from_pmg_defect(defect_entry_or_defect) # convert to doped Defect
else:
defect = defect_entry_or_defect

req_sc_mat = np.eye(3) * np.ceil((5 * np.sqrt(n)) / min(defect.defect_structure.lattice.abc))
if np.all(req_sc_mat == np.eye(3)): # just defect is fine
defect_supercell = defect.defect_structure
defect_supercell_site = defect.site
site = defect.site
structure = defect.structure

else:
(
defect_supercell,
defect_supercell_site,
_equivalent_supercell_sites,
) = defect.get_supercell_structure(
sc_mat=req_sc_mat,
return_sites=True,
)
else:
raise TypeError(
f"defect_entry_or_defect must be a DefectEntry or Defect object, not "
Expand All @@ -233,43 +228,68 @@ def closest_site_info(
if element_list is None:
element_list = _get_element_list(defect)

distance_matrix = defect_supercell.lattice.get_all_distances(
defect_supercell.frac_coords,
defect_supercell_site.frac_coords,
)
if distance_matrix.shape[1] == 1: # Check if it is (X, 1)
distance_matrix = distance_matrix.ravel()

# ensure the defect site itself is excluded, and ignore sites further than 5*sqrt(n) Å away
possible_close_site_indices = np.where((distance_matrix > 0.05) & (distance_matrix < 5 * np.sqrt(n)))[
0
]

site_distances = sorted( # Could make this faster using caching if it was becoming a bottleneck
[
(
distance_matrix[i],
defect_supercell.sites[i].specie.symbol,
)
for i in possible_close_site_indices
],
key=lambda x: (symmetry._custom_round(x[0], 2), _list_index_or_val(element_list, x[1]), x[1]),
)

# prune site_distances to remove any tuples with distances within 0.02 Å of the previous entry:
site_distances = [
site_distances[i]
for i in range(len(site_distances))
if i == 0
or abs(site_distances[i][0] - site_distances[i - 1][0]) > 0.02
or site_distances[i][1] != site_distances[i - 1][1]
]
def _get_site_distances_and_symbols(
site: PeriodicSite,
structure: Structure,
n: int,
element_list: list[str],
dist_tol_prefactor: float = 3.0,
):
"""
Get a list of sorted tuples of (distance, element) for the closest
sites to the input site in the input structure, and the last used
``dist_tol_prefactor``.
Dynamically increases ``dist_tol_prefactor`` until at least one other
site is found within the distance tolerance. Function defined and used
here to allow dynamic upscaling of the distance tolerance for weird
structures (e.g. mp-674158, mp-1208561).
"""
neighbours: list[PeriodicNeighbor] = []
while not neighbours: # for efficiency, ignore sites further than dist_tol*sqrt(n) Å away:
neighbours_w_site_itself = structure.get_sites_in_sphere(
site.coords, dist_tol_prefactor * np.sqrt(n)
) # exclude defect site itself:
neighbours = sorted(neighbours_w_site_itself, key=lambda x: x.nn_distance)[1:]
dist_tol_prefactor += 0.5 # increase the distance tolerance if no other sites are found
if dist_tol_prefactor > 40:
warnings.warn(
"No other sites found within 40*sqrt(n) Å of the defect site, indicating a very "
"weird structure..."
)
break
if not neighbours:
return [], dist_tol_prefactor

if site_distances:
min_distance, closest_site = site_distances[n - 1]
return f"{closest_site}{symmetry._custom_round(min_distance, 2):.2f}"
site_distances = sorted( # Could make this faster using caching if it was becoming a bottleneck
[
(
neigh.nn_distance,
neigh.specie.symbol,
)
for neigh in neighbours
],
key=lambda x: (symmetry._custom_round(x[0], 2), _list_index_or_val(element_list, x[1]), x[1]),
)
return [ # prune site_distances to remove any with distances within 0.02 Å of the previous n:
site_distances[i]
for i in range(len(site_distances))
if i == 0
or abs(site_distances[i][0] - site_distances[i - 1][0]) > 0.02
or site_distances[i][1] != site_distances[i - 1][1]
], dist_tol_prefactor

site_distances, dist_tol_prefactor = _get_site_distances_and_symbols(site, structure, n, element_list)
while len(site_distances) < n:
if dist_tol_prefactor > 40:
return "" # already warned

site_distances, dist_tol_prefactor = _get_site_distances_and_symbols(
site, structure, n, element_list, dist_tol_prefactor + 2
)

return "" # hypothetical case of very weird structure with no sites within 5*sqrt(n) Å...
min_distance, closest_site = site_distances[n - 1]
return f"{closest_site}{symmetry._custom_round(min_distance, 2):.2f}"


def get_defect_name_from_defect(
Expand Down Expand Up @@ -2497,12 +2517,21 @@ def get_Voronoi_interstitial_sites(
f"only the following keys are supported: {supported_interstitial_gen_kwargs}"
)
top = DopedTopographyAnalyzer(host_structure)
if not top.vnodes:
warnings.warn("No interstitial sites found in host structure!")
return []

sites_list = [v.frac_coords for v in top.vnodes]
sites_list = remove_collisions(
sites_list, structure=host_structure, min_dist=interstitial_gen_kwargs.get("min_dist", 0.9)
)
min_dist = interstitial_gen_kwargs.get("min_dist", 0.9)
sites_array = remove_collisions(sites_list, structure=host_structure, min_dist=min_dist)
if sites_array.size == 0:
warnings.warn(
f"No interstitial sites found after removing those within {min_dist} Å of host atoms!"
)
return []

site_frac_coords_array: np.array = _doped_cluster_frac_coords(
sites_list,
sites_array,
host_structure,
tol=interstitial_gen_kwargs.get("clustering_tol", 0.55),
symmetry_preference=interstitial_gen_kwargs.get("symmetry_preference", 0.1),
Expand Down
6 changes: 5 additions & 1 deletion doped/utils/efficiency.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,8 +457,12 @@ def _doped_cluster_frac_coords(
Returns:
np.typing.NDArray: Clustered fractional coordinates
"""
if len(fcoords) <= 1:
if len(fcoords) == 0:
return None
if len(fcoords) == 1:
return symmetry._vectorized_custom_round(
np.mod(symmetry._vectorized_custom_round(fcoords, 5), 1), 4
) # to unit cell

lattice = structure.lattice
sga = symmetry.get_sga(structure, symprec=0.1) # for getting symmetries of different sites
Expand Down
2 changes: 1 addition & 1 deletion tests/data/CdTe_defect_gen.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/data/N_diamond_defect_gen.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/data/agcu_defect_gen.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/data/cd_i_supercell_defect_gen.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/data/cu_defect_gen.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/data/lmno_defect_gen.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/data/ytos_defect_gen.json

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/data/ytos_defect_gen_supercell.json

Large diffs are not rendered by default.

20 changes: 11 additions & 9 deletions tests/test_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -736,23 +736,23 @@ def setUp(self):
Interstitials Guessed Charges Conv. Cell Coords Wyckoff
--------------- --------------------------- ------------------- ---------
Si_i_C1_Sb2.48 [+4,+3,+2,+1,0] [0.347,0.348,0.457] 18f
Si_i_C1_Si2.21 [+4,+3,+2,+1,0] [0.158,0.359,0.167] 18f
Si_i_C1_Si2.48 [+4,+3,+2,+1,0] [0.347,0.348,0.457] 18f
Si_i_C1_Te2.44 [+4,+3,+2,+1,0] [0.001,0.336,0.289] 18f
Si_i_C3_Sb2.41 [+4,+3,+2,+1,0] [0.000,0.000,0.050] 6c
Si_i_C3_Si2.64 [+4,+3,+2,+1,0] [0.000,0.000,0.318] 6c
Si_i_C3_Te2.41 [+4,+3,+2,+1,0] [0.000,0.000,0.050] 6c
Si_i_C3i_Te2.81 [+4,+3,+2,+1,0] [0.000,0.000,0.000] 3a
Sb_i_C1_Sb2.48 [+5,+4,+3,+2,+1,0,-1,-2,-3] [0.347,0.348,0.457] 18f
Sb_i_C1_Si2.21 [+5,+4,+3,+2,+1,0,-1,-2,-3] [0.158,0.359,0.167] 18f
Sb_i_C1_Si2.48 [+5,+4,+3,+2,+1,0,-1,-2,-3] [0.347,0.348,0.457] 18f
Sb_i_C1_Te2.44 [+5,+4,+3,+2,+1,0,-1,-2,-3] [0.001,0.336,0.289] 18f
Sb_i_C3_Sb2.41 [+5,+4,+3,+2,+1,0,-1,-2,-3] [0.000,0.000,0.050] 6c
Sb_i_C3_Si2.64 [+5,+4,+3,+2,+1,0,-1,-2,-3] [0.000,0.000,0.318] 6c
Sb_i_C3_Te2.41 [+5,+4,+3,+2,+1,0,-1,-2,-3] [0.000,0.000,0.050] 6c
Sb_i_C3i_Te2.81 [+5,+4,+3,+2,+1,0,-1,-2,-3] [0.000,0.000,0.000] 3a
Te_i_C1_Sb2.48 [+4,+3,+2,+1,0,-1,-2] [0.347,0.348,0.457] 18f
Te_i_C1_Si2.21 [+4,+3,+2,+1,0,-1,-2] [0.158,0.359,0.167] 18f
Te_i_C1_Si2.48 [+4,+3,+2,+1,0,-1,-2] [0.347,0.348,0.457] 18f
Te_i_C1_Te2.44 [+4,+3,+2,+1,0,-1,-2] [0.001,0.336,0.289] 18f
Te_i_C3_Sb2.41 [+4,+3,+2,+1,0,-1,-2] [0.000,0.000,0.050] 6c
Te_i_C3_Si2.64 [+4,+3,+2,+1,0,-1,-2] [0.000,0.000,0.318] 6c
Te_i_C3_Te2.41 [+4,+3,+2,+1,0,-1,-2] [0.000,0.000,0.050] 6c
Te_i_C3i_Te2.81 [+4,+3,+2,+1,0,-1,-2] [0.000,0.000,0.000] 3a
\n"""
"The number in the Wyckoff label is the site multiplicity/degeneracy of that defect "
Expand Down Expand Up @@ -854,7 +854,9 @@ def _general_defect_gen_check(self, defect_gen, charge_states_removed=False):
set(Poscar(structure).site_symbols)
) # no duplicates

assert np.isclose(defect_gen.min_image_distance, get_min_image_distance(defect_gen.bulk_supercell))
assert np.isclose(
defect_gen.min_image_distance, get_min_image_distance(defect_gen.bulk_supercell), atol=1e-2
)

print("Checking Defect/DefectEntry types")
assert all(defect.defect_type == DefectType.Vacancy for defect in defect_gen.defects["vacancies"])
Expand Down Expand Up @@ -1854,7 +1856,7 @@ def test_supercell_gen_kwargs(self):
) # gives 4x conventional cell
assert self.CdTe_defect_gen_info in output
self._general_defect_gen_check(CdTe_defect_gen)
assert CdTe_defect_gen.min_image_distance == 26.1626
assert np.isclose(CdTe_defect_gen.min_image_distance, 26.1626, atol=1e-2)
assert len(CdTe_defect_gen.bulk_supercell) == 512
assert CdTe_defect_gen.supercell_gen_kwargs["min_image_distance"] == 20
assert CdTe_defect_gen.supercell_gen_kwargs["force_cubic"] is True
Expand Down Expand Up @@ -3372,7 +3374,7 @@ def Sn5O6_defect_gen_check(self, defect_gen, manual_oxi=False):
"v_O_C1_Sn2.09 [+2,+1,0,-1] [0.642,0.323,0.461] 4e",
"Sn_Sn_C1_O2.08O2.11O2.14b [+2,+1,0,-1] [0.101,0.502,0.319] 4e",
"Sn_Sn_C1_O2.08Sn3.28O3.69a [+1,0,-1] [0.500,0.500,0.500] 2b",
"Sn_i_C1_Sn2.33O2.33O2.39d [+4,+3,+2,+1,0] [0.273,0.460,0.248] 4e",
"Sn_i_C1_Sn2.33O2.39O2.60d [+4,+3,+2,+1,0] [0.273,0.460,0.248] 4e",
"O_i_C1_O1.83Sn1.99Sn2.09a [0,-1,-2] [0.567,0.320,0.205] 4e",
]
assert set(defect_gen._bulk_oxi_states.composition.elements) == {
Expand Down

0 comments on commit 6041506

Please sign in to comment.