Add support for atom-wise rmax in the sktb module #209

Open. Wants to merge 62 commits into base: main.

Changes from all commits (62 commits)
7cf9323
feat: add logging for self-interaction warnings and implement tests f…
QG-phy Sep 14, 2024
4feaf4e
feat: add support for atom-wise rs rmax setting in deeptb sk modula.
QG-phy Sep 20, 2024
453a86a
Merge branch 'deepmodeling:main' into multi_rmax
QG-phy Sep 20, 2024
a0524a8
remove check atomic_number shape in get_r_map
QG-phy Sep 20, 2024
3706df0
Merge branch 'multi_rmax' of https://github.com/QG-phy/DeePTB into mu…
QG-phy Sep 20, 2024
60d2b89
add Covalent_radii database
QG-phy Sep 23, 2024
241f4c7
Add support for basis input in the format : ['s','p','d'], previous o…
QG-phy Sep 24, 2024
955c658
fix test error when add the support for basis input format of ['s','p…
QG-phy Sep 24, 2024
93c8dbd
change the unit of bond_length_list from bohr to \AA
QG-phy Sep 24, 2024
a5c2722
add test_Covalent_radii
QG-phy Sep 24, 2024
2514fbd
add some comment in dftbsk
QG-phy Sep 24, 2024
e074b38
add BondLenCovalent
QG-phy Sep 24, 2024
cd66749
add new onsite formula uniform_noref
QG-phy Sep 24, 2024
5a98af3
remove the unit transition from bohr to \AA since gthe bondlenth are …
QG-phy Sep 24, 2024
cced890
rename BondLenCovalent to R_cov_list
QG-phy Sep 24, 2024
c48612a
chore: Update test_SKHamiltonian to use torch.allclose with specified…
QG-phy Sep 24, 2024
87cf70a
Refactor AtomicData.py to use environment variable for error handling
QG-phy Sep 27, 2024
ab6ed11
add atomic radius in bondlengthDB
QG-phy Sep 27, 2024
803da72
Refactor covalent radii database to use \AA as unit
QG-phy Sep 27, 2024
7409e6d
feat: Calculate minimum and maximum atomic radii based on skdata
QG-phy Sep 27, 2024
d34a7b1
add new dftb2nnsk
QG-phy Oct 2, 2024
4ebaa77
add new dftb2nnsk
QG-phy Oct 9, 2024
ce17523
feat: Update SKParam class to update atomic radii and format_skparam…
QG-phy Oct 9, 2024
f5d0e09
feat: Add range check for bond distance in HoppingIntp
QG-phy Oct 9, 2024
6518591
feat: Add poly4pow hopping formula to HoppingFormula class and simpli…
QG-phy Oct 10, 2024
f2dc9ee
update dftb2nnsk to accept large model
QG-phy Oct 10, 2024
6d9b212
delete temp
QG-phy Oct 10, 2024
b320681
add get_rmap for bond_wise cutoff seting style
QG-phy Oct 11, 2024
6a2a4a2
feat: Add poly3exp and poly4exp hopping formulas to HoppingFormula class
QG-phy Oct 11, 2024
7455959
feat: update AtomicData.neighbor_list_and_relative_vec to support bon…
QG-phy Oct 11, 2024
c499fdc
feat:add support bond-wise rmax in nnsk
QG-phy Oct 11, 2024
4a9cf95
feat: Update NNSK class to support bond-wise rs values
QG-phy Oct 11, 2024
0fd4422
feat: Update AtomicData.neighbor_list_and_relative_vec to support bon…
QG-phy Oct 11, 2024
78ac760
feat: Add support for bond-wise rmax in build_dataset tests
QG-phy Oct 11, 2024
bec788d
test:update some tests
QG-phy Oct 11, 2024
afec4a4
feat: add bondwise cal rmax and rim in skparam
QG-phy Oct 14, 2024
bb46aab
use bondwise rmin and rmax in skparam and dftb2nnsk
QG-phy Oct 15, 2024
6feeb39
Merge branch 'deepmodeling:main' into multi_rmax
QG-phy Oct 15, 2024
23d27b2
Refactor bondwise rmax calculation in sk_param.py and dftb2nnsk.py
QG-phy Oct 15, 2024
47f75fc
Refactor DFTB2NNSK add save and load model
QG-phy Oct 18, 2024
17e53fd
Refactor NRL_OVERLAP0 and NRL_OVERLAP1 in hopping.py to use torch.Ten…
QG-phy Oct 20, 2024
df83035
Refactor get_lr_scheduler to add support for cosine annealing learnin…
QG-phy Oct 20, 2024
4d2103a
Refactor optimise funciton in dftb2nnsk.py
QG-phy Oct 20, 2024
d86a51d
Refactor collectskf.py and main.py to add support for converting sk f…
QG-phy Oct 20, 2024
5e8afd9
Refactor argcheck.py to add support for RMSprop and LBFGS optimizers …
QG-phy Oct 20, 2024
1a9bf31
Refactor argcheck.py to remove duplicate formulas in the hopping func…
QG-phy Oct 20, 2024
5149dea
update test for dftb2nnsk
QG-phy Oct 20, 2024
7aec1f1
Refactor argcheck.py to remove duplicate formulas in the hopping func…
QG-phy Oct 20, 2024
cd7407f
rename functype to method in dftb2nnsk and fix bug in argcheck.
QG-phy Oct 20, 2024
9aca6e8
Add output inputpara
QG-phy Oct 20, 2024
fb95814
add bond integrl plot viz
QG-phy Oct 20, 2024
fdd5e96
add bond integrl plot viz
QG-phy Oct 20, 2024
3cfe8b1
Refactor NNSK class to include support for different atomic radius op…
QG-phy Oct 20, 2024
b0044ca
Refactor atomic radius initialization in DFTB2NNSK and NNSK classes
QG-phy Oct 20, 2024
000f9d8
fix: the support for uniform_noref onsite mode in nnsk
QG-phy Oct 21, 2024
871c711
Refactor OnsiteFormula class in onsite.py
QG-phy Oct 23, 2024
0c5345d
Refactor save method in NNSK class add device and dtype in dftb2nnsk…
QG-phy Oct 23, 2024
0038aea
Refactor ElecStruCal class to handle atomic radius options and fix a …
QG-phy Oct 23, 2024
3568650
Refactor get_cutoffs_from_model_options function to handle the case w…
QG-phy Oct 23, 2024
bc5dcdf
add example of dftb2nnsk
QG-phy Oct 23, 2024
3aff70e
test: add skf2nnsk example
QG-phy Oct 23, 2024
2fa3d22
Refactor load method in DFTB2NNSK class to include an optional output…
QG-phy Oct 24, 2024
110 changes: 99 additions & 11 deletions dptb/data/AtomicData.py
@@ -23,6 +23,9 @@
from .util import _TORCH_INTEGER_DTYPES
from dptb.utils.torch_geometric.data import Data
from dptb.utils.constants import atomic_num_dict
import logging

log = logging.getLogger(__name__)

# A type representing ASE-style periodic boundary conditions, which can be partial (the tuple case)
PBC = Union[bool, Tuple[bool, bool, bool]]
@@ -874,11 +877,10 @@ def without_nodes(self, which_nodes):
return type(self)(**new_dict)


_ERROR_ON_NO_EDGES: bool = os.environ.get("NEQUIP_ERROR_ON_NO_EDGES", "true").lower()
assert _ERROR_ON_NO_EDGES in ("true", "false")
_ERROR_ON_NO_EDGES = os.environ.get("NEQUIP_ERROR_ON_NO_EDGES", "true").lower()
assert _ERROR_ON_NO_EDGES in ("true", "false"), "NEQUIP_ERROR_ON_NO_EDGES must be 'true' or 'false'"
_ERROR_ON_NO_EDGES = _ERROR_ON_NO_EDGES == "true"


def neighbor_list_and_relative_vec(
pos,
r_max,
@@ -1026,6 +1028,13 @@ def neighbor_list_and_relative_vec(
# so, only when key_rev is not in the dict do we keep the bond; that is, when rev_dict.get(key_rev, False) is False, we set o_mask[i] = True.
if not (rev_dict.get(key_rev, False) and rev_dict.get(key, False)):
o_mask[i] = True

if self_interaction:
log.warning("self_interaction is True, but usually we do not want the self-interaction, please check if it is correct.")
# for self-interaction, the above would remove the self-interaction bonds, i.e. i == j with shift == [0, 0, 0], since -0 = 0.
if (o_shift[i] == np.array([0, 0, 0])).all():
o_mask[i] = True

del rev_dict
del o_first_idex
del o_second_idex
@@ -1038,6 +1047,7 @@ def neighbor_list_and_relative_vec(
shifts = torch.as_tensor(shifts[mask], dtype=out_dtype, device=out_device)

if not reduce:
assert self_interaction == False, "for self_interaction = True, i i 0 0 0 will be duplicated."
first_idex, second_idex = torch.cat((first_idex, second_idex), dim=0), torch.cat((second_idex, first_idex), dim=0)
shifts = torch.cat((shifts, -shifts), dim=0)

@@ -1049,7 +1059,7 @@ def neighbor_list_and_relative_vec(
# TODO: mask the edges that are larger than r_max
if mask_r:
edge_vec = pos[edge_index[1]] - pos[edge_index[0]]
if cell is not None:
if cell is not None :
edge_vec = edge_vec + torch.einsum(
"ni,ij->nj",
shifts,
@@ -1058,17 +1068,36 @@ def neighbor_list_and_relative_vec(

edge_length = torch.linalg.norm(edge_vec, dim=-1)

atom_species_num = [atomic_num_dict[k] for k in r_max.keys()]
for i in set(atomic_numbers):
assert i in atom_species_num
r_map = torch.zeros(max(atom_species_num))
for k, v in r_max.items():
r_map[atomic_num_dict[k]-1] = v
edge_length_max = 0.5 * (r_map[atomic_numbers[edge_index[0]]-1] + r_map[atomic_numbers[edge_index[1]]-1])
# atom_species_num = [atomic_num_dict[k] for k in r_max.keys()]
# for i in set(atomic_numbers):
# assert i in atom_species_num
# r_map = torch.zeros(max(atom_species_num))
# for k, v in r_max.items():
# r_map[atomic_num_dict[k]-1] = v

first_key = next(iter(r_max.keys()))
key_parts = first_key.split("-")

if len(key_parts)==1:
r_map = get_r_map(r_max, atomic_numbers)
edge_length_max = 0.5 * (r_map[atomic_numbers[edge_index[0]]-1] + r_map[atomic_numbers[edge_index[1]]-1])

elif len(key_parts)==2:
r_map = get_r_map_bondwise(r_max, atomic_numbers)
edge_length_max = r_map[atomic_numbers[edge_index[0]]-1,atomic_numbers[edge_index[1]]-1]
else:
raise ValueError("The r_max keys should be either atomic number or atomic number pair.")

r_mask = edge_length <= edge_length_max
if any(~r_mask):
edge_index = edge_index[:, r_mask]
shifts = shifts[r_mask]
# Collect the different edge types and their corresponding maximum cutoff radii
#edge_types = {}
#for i in range(edge_index.shape[1]):
# atom_type_pair = (atomic_numbers[edge_index[0, i]], atomic_numbers[edge_index[1, i]])
# if atom_type_pair not in edge_types:
# edge_types[atom_type_pair] = edge_length_max[i].item()

del edge_length
del edge_vec
@@ -1077,3 +1106,62 @@ def neighbor_list_and_relative_vec(
del r_mask

return edge_index, shifts, cell_tensor

def get_r_map(r_max: dict, atomic_numbers=None):
"""
Returns a torch tensor representing the mapping of atomic species to their maximum distances.

Args:
r_max (dict): A dictionary mapping atomic species to their maximum distances.
atomic_numbers (list, optional): A list of atomic numbers to validate against the atomic species. Defaults to None.

Returns:
torch.Tensor: A torch tensor representing the mapping of atomic species to their maximum distances.
"""
atom_species_num = [atomic_num_dict[k] for k in r_max.keys()]
if atomic_numbers is not None:
for i in atomic_numbers:
assert i in atom_species_num
r_map = torch.zeros(max(atom_species_num))
for k, v in r_max.items():
r_map[atomic_num_dict[k]-1] = v
return r_map

def get_r_map_bondwise(r_max:dict, atomic_numbers=None):
"""
Calculate the bondwise distance map based on the maximum bond length dictionary.

Args:
r_max (dict): A dictionary containing the maximum bond lengths for different atom pairs.
atomic_numbers (list, optional): A list of atomic numbers. Defaults to None.

Returns:
torch.Tensor: A torch tensor representing the bondwise distance map.
"""
atom_species_num = []
for k in r_max.keys():
assert len(k.split('-')) == 2
atom_a, atom_b = k.split('-')
if atomic_num_dict[atom_a] not in atom_species_num:
atom_species_num.append(atomic_num_dict[atom_a])
if atomic_num_dict[atom_b] not in atom_species_num:
atom_species_num.append(atomic_num_dict[atom_b])

if atomic_numbers is not None:
for i in atomic_numbers:
assert i in atom_species_num

r_map = torch.zeros(max(atom_species_num), max(atom_species_num))
for k, v in r_max.items():
atom_a, atom_b = k.split('-')

inv_value = r_map[atomic_num_dict[atom_b]-1, atomic_num_dict[atom_a]-1]
if inv_value == 0:
r_map[atomic_num_dict[atom_a]-1, atomic_num_dict[atom_b]-1] = v
r_map[atomic_num_dict[atom_b]-1, atomic_num_dict[atom_a]-1] = v
else:
mean_val = (v + inv_value) / 2
r_map[atomic_num_dict[atom_a]-1, atomic_num_dict[atom_b]-1] = mean_val
r_map[atomic_num_dict[atom_b]-1, atomic_num_dict[atom_a]-1] = mean_val

return r_map
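
The two helpers above are the core of the new cutoff handling: an atom-wise r_max dict becomes a per-species radius vector, while a bond-wise dict becomes a symmetric per-pair matrix. A minimal usage sketch follows; the species and radii are illustrative, and the imports assume these helpers are exposed from dptb.data.AtomicData on this branch.

from dptb.data.AtomicData import get_r_map, get_r_map_bondwise
from dptb.utils.constants import atomic_num_dict

# Atom-wise cutoffs: one radius per species; the per-edge cutoff used above is
# 0.5 * (r_map[Z_i - 1] + r_map[Z_j - 1]).
r_map = get_r_map({"C": 4.0, "H": 3.0})
print(r_map[atomic_num_dict["C"] - 1])          # tensor(4.)

# Bond-wise cutoffs: one radius per species pair; asymmetric entries are averaged.
r_map_bw = get_r_map_bondwise({"C-C": 4.0, "C-H": 3.5, "H-H": 3.0})
print(r_map_bw[atomic_num_dict["C"] - 1, atomic_num_dict["H"] - 1])   # tensor(3.5000)
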
2 changes: 1 addition & 1 deletion dptb/data/transforms.py
@@ -560,7 +560,7 @@ def __init__(
for ib in self.basis.keys():
self.basis[ib] = sorted(
self.basis[ib],
key=lambda s: (anglrMId[re.findall(r"[a-z]",s)[0]], re.findall(r"[1-9*]",s)[0])
key=lambda s: (anglrMId[re.findall(r"[a-z]",s)[0]], re.findall(r"[1-9*]",s)[0] if re.findall(r"[1-9*]",s) else '0')
)

# TODO: get full basis set
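
The one-line change above makes the basis sorting in transforms.py tolerant of orbitals written without a principal quantum number, i.e. the new ['s','p','d'] input format mentioned in the commit log: when no digit (or '*') is present, the sort key falls back to '0'. A standalone sketch of the behaviour; the anglrMId values here are illustrative stand-ins.

import re

anglrMId = {"s": 0, "p": 1, "d": 2}   # illustrative subset

def sort_key(s):
    num = re.findall(r"[1-9*]", s)
    return (anglrMId[re.findall(r"[a-z]", s)[0]], num[0] if num else "0")

print(sorted(["p", "s", "d"], key=sort_key))            # ['s', 'p', 'd']
print(sorted(["2p", "2s", "3s", "d*"], key=sort_key))   # ['2s', '3s', '2p', 'd*']
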
87 changes: 84 additions & 3 deletions dptb/entrypoints/collectskf.py
@@ -1,14 +1,18 @@
from typing import Dict, List, Optional, Any
from typing import Dict, List, Optional, Any, Union
import json
from pathlib import Path
import os
import torch
import glob
from dptb.nn.dftb.sk_param import SKParam

from dptb.nn.dftb2nnsk import DFTB2NNSK
import logging
from dptb.utils.loggers import set_log_handles
from dptb.utils.tools import j_loader, setup_seed, j_must_have
from dptb.utils.argcheck import normalize, collect_cutoffs, normalize_skf2nnsk


__all__ = ["skf2pth"]
__all__ = ["skf2pth", "skf2nnsk"]


log = logging.getLogger(__name__)
@@ -45,3 +49,80 @@ def skf2pth(
torch.save(skdict, output)


def skf2nnsk(
INPUT:str,
init_model: Optional[str],
output:str,
log_level: int,
log_path: Optional[str] = None,
**kwargs
):
run_opt = {
"init_model": init_model,
"log_path": log_path,
"log_level": log_level
}

# setup output path
if output:
Path(output).parent.mkdir(exist_ok=True, parents=True)
Path(output).mkdir(exist_ok=True, parents=True)
if not log_path:
log_path = os.path.join(str(output), "log.txt")
Path(log_path).parent.mkdir(exist_ok=True, parents=True)

run_opt.update({
"output": str(Path(output).absolute()),
"log_path": str(Path(log_path).absolute())
})
set_log_handles(log_level, Path(log_path) if log_path else None)

jdata = j_loader(INPUT)
jdata = normalize_skf2nnsk(jdata)

common_options = jdata['common_options']
model_options = jdata['model_options']
train_options = jdata['train_options']

basis = j_must_have(common_options, "basis")
skdata_file = j_must_have(common_options, "skdata")

if skdata_file.split('.')[-1] != 'pth':
log.error("The skdata file should be a pth file.")
raise ValueError("The skdata file should be a pth file.")
log.info(f"Loading skdata from {skdata_file}")
skdata = torch.load(skdata_file)

if isinstance(basis, str) and basis == "auto":
log.info("Automatically determining basis")
basis = dict(zip(skdata['OnsiteE'], [['s', 'p', 'd']] * len(skdata['OnsiteE'])))
else:
assert isinstance(basis, dict), "basis must be a dict or 'auto'"

train_options = jdata['train_options']

if init_model:
dftb2nn = DFTB2NNSK.load(ckpt=init_model,
skdata=skdata,
train_options=train_options,
output=run_opt.get('output', './')
)

else:
dftb2nn = DFTB2NNSK(
basis = basis,
skdata = skdata,
method=j_must_have(model_options, "method"),
rs=model_options.get('rs', None),
w = j_must_have(model_options, "w"),
cal_rcuts= model_options.get('rs', None) is None,
atomic_radius= model_options.get('atomic_radius', 'cov'),
train_options=train_options,
output=run_opt.get('output', './')
)

dftb2nn.optimize()
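
For reference, a hypothetical minimal input file for this entry point, built only from the keys skf2nnsk reads above (basis/skdata in common_options, method/rs/w/atomic_radius in model_options). The exact schema is defined by normalize_skf2nnsk, so the field values below are placeholders rather than a documented example.

import json

config = {
    "common_options": {
        "basis": "auto",              # or an explicit dict such as {"Si": ["s", "p", "d"]}
        "skdata": "./skparams.pth",   # .pth file produced by `dptb cskf` (skf2pth)
    },
    "model_options": {
        "method": "poly4pow",         # hopping formula name (assumed valid choice)
        "w": 0.2,
        "rs": None,                   # None -> cutoffs computed from atomic radii (cal_rcuts=True)
        "atomic_radius": "cov",
    },
    "train_options": {},              # optimizer/scheduler settings go here (schema assumed)
}

with open("skf2nn_input.json", "w") as f:
    json.dump(config, f, indent=4)
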




36 changes: 35 additions & 1 deletion dptb/entrypoints/main.py
@@ -12,7 +12,7 @@
from dptb.entrypoints.data import data
from dptb.utils.loggers import set_log_handles
from dptb.utils.config_check import check_config_train
from dptb.entrypoints.collectskf import skf2pth
from dptb.entrypoints.collectskf import skf2pth, skf2nnsk
from dptb import __version__


@@ -364,6 +364,37 @@ def main_parser() -> argparse.ArgumentParser:
help="The output pth files of sk params from skfiles."
)

# skf2nn: convert sk files to an nn-sk model
parser_skf2nn = subparsers.add_parser(
"skf2nn",
parents=[parser_log],
help="Convert the sk files to nn-sk TB model",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)

parser_skf2nn.add_argument(
"INPUT", help="the input parameter file in json or yaml format",
type=str,
default=None
)

parser_skf2nn.add_argument(
"-i",
"--init-model",
type=str,
default=None,
help="Initialize the model by the provided checkpoint.",
)

parser_skf2nn.add_argument(
"-o",
"--output",
type=str,
default="./",
help="The output files in training.",
)


return parser

def parse_args(args: Optional[List[str]] = None) -> argparse.Namespace:
@@ -424,3 +455,6 @@ def main():

elif args.command == 'cskf':
skf2pth(**dict_args)

elif args.command == 'skf2nn':
skf2nnsk(**dict_args)
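
With the parser and dispatch above, the conversion can be run as `dptb skf2nn <input.json> -o <output_dir>`, optionally with -i/--init-model to restart from a checkpoint. A hypothetical programmatic equivalent, calling the entry point directly with the signature shown in collectskf.py; the file paths are placeholders.

import logging
from dptb.entrypoints.collectskf import skf2nnsk

skf2nnsk(
    INPUT="skf2nn_input.json",   # hypothetical input file, see the config sketch above
    init_model=None,             # or a path to a saved DFTB2NNSK checkpoint
    output="./skf2nn_out",
    log_level=logging.INFO,      # an int, as expected by set_log_handles
    log_path=None,
)
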
17 changes: 15 additions & 2 deletions dptb/nn/dftb/hopping_dftb.py
@@ -1,7 +1,8 @@
from dptb.nn.sktb.hopping import BaseHopping
import torch
from dptb.utils._xitorch.interpolate import Interp1D

import logging
log = logging.getLogger(__name__)
class HoppingIntp(BaseHopping):

def __init__(
@@ -36,7 +37,19 @@ def dftb(self, rij:torch.Tensor, xx:torch.Tensor, yy:torch.Tensor, **kwargs):
assert rij.shape[0] == self.num_ingrls, "the bond distance shape rij is not correct."
else:
raise ValueError("The shape of rij is not correct.")
# Check whether rij lies within the range of xx
min_x, max_x = self.xx.min(), self.xx.max()
mask_in_range = (rij >= min_x) & (rij <= max_x)
mask_out_range = ~mask_in_range
if mask_out_range.any():
# log.warning("Some rij values are outside the interpolation range and will be set to 0.")
# Make a copy of rij and replace out-of-range values with an in-range value (e.g., min_x)
rij_modified = rij.clone()
rij_modified[mask_out_range] = (min_x + max_x) / 2
yyintp = self.intpfunc(xq=rij_modified, y=yy)
yyintp[mask_out_range] = 0.0
else:
yyintp = self.intpfunc(xq=rij, y=yy)

yyintp = self.intpfunc(xq=rij,y=yy)
return yyintp.T
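
The block above guards the interpolator against bond distances outside the tabulated SK grid: out-of-range points are temporarily replaced by an in-range dummy value so the interpolation call stays valid, and the corresponding outputs are then zeroed. A self-contained sketch of the same pattern, with a plain linear interpolator standing in for self.intpfunc and illustrative tabulated data:

import torch

def safe_interp(xq: torch.Tensor, xx: torch.Tensor, yy: torch.Tensor) -> torch.Tensor:
    """Linear stand-in for self.intpfunc; returns 0 outside [xx.min(), xx.max()]."""
    min_x, max_x = xx.min(), xx.max()
    out_of_range = (xq < min_x) | (xq > max_x)
    xq_safe = xq.clone()
    xq_safe[out_of_range] = (min_x + max_x) / 2   # any in-range dummy value works
    # simple piecewise-linear interpolation as a placeholder for the xitorch Interp1D call
    idx = torch.clamp(torch.searchsorted(xx, xq_safe), 1, len(xx) - 1)
    x0, x1, y0, y1 = xx[idx - 1], xx[idx], yy[idx - 1], yy[idx]
    y = y0 + (y1 - y0) * (xq_safe - x0) / (x1 - x0)
    y[out_of_range] = 0.0
    return y

xx = torch.linspace(1.0, 6.0, 11)   # tabulated bond distances (illustrative)
yy = torch.exp(-xx)                 # tabulated SK integral values (illustrative)
print(safe_interp(torch.tensor([2.0, 7.5]), xx, yy))   # ~tensor([0.1353, 0.0000]); out-of-range query is zeroed
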
