Skip to content

Commit

Permalink
Add Wannier TB for initial training SK TB parameters. (#41)
Browse files Browse the repository at this point in the history
* 📃 docs(train): update comments

1. remove useless comments in train_nnsk
2. add comment in trainloss for block_l2

* feat: update using wannier as training labels

* feat: update block_l2 in trainloss.py
remove the random rate used when calculating the loss

* 🧪 test: create test_process_wannier.py

* ✨ feat(dataprocess): update process_wannier.py:

Add the onsite energy shift: shift the minimum onsite energy of the Wannier TB so that it is aligned with the onsite database values.

* update process_wannier.py

* test: update hop_bondwise.npy and test_process_wannier.py

* example: add  example of  using wannier as training labels on Si

* test: add test for train using wannier TB as labels.

* ci: update unit_test.yml

* ✨ feat(dataload  argcheck): add new tag for wannier

Add a parameter `use_wannier` in data_options to set whether to use Wannier as labels.

* test: update input_wan.json for tests

* update docstring process_wannier.py

* Stop tracking and remove tests/data/hBN/data/set.0/xdat2.traj

* chore: delete xdat2.traj

* Stop tracking and remove dptb/tests/data/postrun/run_config.json

* Stop tracking and remove dptb/tests/data/test_all/test_config.json

* Stop tracking and remove

/tests/data/test_all/checkpoint/best_nnsk_b4.000_c4.000_w0.300.json
tests/data/test_all/fancy_ones/checkpoint/best_nnsk_b4.000_c4.000_w0.300.json

---------

Co-authored-by: Yinzhanghao Zhou <[email protected]>
  • Loading branch information
QG-phy and floatingCatty authored Nov 3, 2023
1 parent 54d9778 commit dc6055b
Show file tree
Hide file tree
Showing 30 changed files with 165,459 additions and 522 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/unit_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
id: s1
uses: actions/checkout@v3
with:
fetch-depth: 0
fetch-depth: 0
ref: "refs/pull/${{ github.event.number }}/merge"
- name: Install DeePTB & Run Test
id: s2
Expand Down
6 changes: 6 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,4 +1,10 @@
test*.ipynb
dptb/tests/data/hBN/data/set.0/xdat2.traj
dptb/tests/data/postrun/run_config.json
dptb/tests/data/test_all/test_config.json
dptb/tests/data/test_all/checkpoint/best_nnsk_b5.000_c6.615_w0.265.json
dptb/tests/data/test_all/checkpoint/best_nnsk_b4.000_c4.000_w0.300.json
dptb/tests/data/test_all/fancy_ones/checkpoint/best_nnsk_b4.000_c4.000_w0.300.json
dptb/nnet/__pycache__/
dptb/sktb/__pycache__/
dptb/negf/__pycache__/
Expand Down
36 changes: 21 additions & 15 deletions dptb/dataprocess/datareader.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,18 +6,19 @@
import torch
from dptb.structure.structure import BaseStruct
from dptb.dataprocess.processor import Processor
from dptb.dataprocess.process_wannier import get_wannier_blocks
from dptb.utils.tools import j_loader
from dptb.utils.argcheck import normalize_bandinfo

def read_data(path, prefix, cutoff, proj_atom_anglr_m, proj_atom_neles, onsitemode:str='uniform', time_symm=True, **kwargs):
def read_data(path, prefix, cutoff, proj_atom_anglr_m, proj_atom_neles, onsitemode:str='uniform', time_symm=True, use_wannier=False, **kwargs):
"""根据文件路径和prefix的读取文件夹下的数据文件,并存储为神经网络模型的输入格式数据
"""
filenames = {
"xdat_file": "xdat.traj",
"eigen_file": "eigs.npy",
"kpoints_file" : "kpoints.npy",
"bandinfo_file": "bandinfo.json",
"wannier_file": "wannier.npy"
"wannier_file": "wannier90_hr.dat"
}

filenames.update(kwargs)
Expand All @@ -38,31 +39,36 @@ def read_data(path, prefix, cutoff, proj_atom_anglr_m, proj_atom_neles, onsitemo
kpoints = np.load(data_dirs[ii] + "/" + filenames['kpoints_file'])
eigs = np.load(data_dirs[ii] + "/" + filenames['eigen_file'])
bandinfo = j_loader(data_dirs[ii] + "/" + filenames['bandinfo_file'])
if os.path.exists(data_dirs[ii] + "/" + filenames['wannier_file']):
wannier = np.load(data_dirs[ii] + "/" + filenames['wannier_file'], allow_pickle=True)
wannier = [x.tolist() for x in wannier]
else:
wannier = [None]


bandinfo = normalize_bandinfo(bandinfo)
bandinfo_sets.append(bandinfo)
if len(eigs.shape)==2:
eigs = eigs[np.newaxis]
assert len(eigs.shape) == 3
kpoints_sets.append(kpoints)
eigens_sets.append(eigs)
if wannier[0] is None:
wannier = [None] * eigs.shape[0]
wannier_sets.append(wannier)


for iatom in asetrajs:

for iatom in asetrajs:
struct = BaseStruct(atom=iatom, format='ase', cutoff=cutoff, proj_atom_anglr_m=proj_atom_anglr_m, proj_atom_neles=proj_atom_neles, onsitemode=onsitemode, time_symm=time_symm)
struct_list.append(struct)
struct_list_sets.append(struct_list)


if use_wannier:
assert os.path.exists(data_dirs[ii] + "/" + filenames['wannier_file'])
#wannier = np.load(data_dirs[ii] + "/" + filenames['wannier_file'], allow_pickle=True)
assert len(struct_list) == 1, "wannier90_hr.dat should be calculated for one structure only!"
wannier_proj = bandinfo['wannier_proj']
orb_wan = bandinfo.get('orb_wan', None)
wannier = get_wannier_blocks(file=data_dirs[ii] + "/" + filenames['wannier_file'],
struct=struct_list[0], wannier_proj_orbital=wannier_proj,orb_wan=orb_wan)
wannier = [wannier]
else:
wannier = [None]

if wannier[0] is None:
wannier = [None] * eigs.shape[0]
wannier_sets.append(wannier)

return struct_list_sets, kpoints_sets, eigens_sets, bandinfo_sets, wannier_sets


Expand Down
266 changes: 266 additions & 0 deletions dptb/dataprocess/process_wannier.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,266 @@
import re
import numpy as np
from ase import Atoms
from dptb.utils.constants import anglrMId, Orbital_Order_Wan_Default, Orbital_Order_SK
from dptb.structure.structure import BaseStruct
from dptb.nnsktb.onsiteDB import onsite_energy_database

def get_wannier_blocks(file:str, struct:BaseStruct, wannier_proj_orbital:dict, orb_wan:dict=None):
    """ Read a wannier90_hr.dat file and return its hopping blocks in the order of dptb.

    Parameters:
    -----------
        file: path of the wannier90_hr.dat file, handed to ``read_hr``.
        struct: structure object, forwarded to ``wan_orbital_orders`` and ``transfrom_Hwan``.
        wannier_proj_orbital: dict, the projection orbitals per element, forwarded to
            ``wan_orbital_orders``.
        orb_wan: dict, optional orbital order of the wannier file, forwarded to
            ``wan_orbital_orders``.
    Returns:
    --------
        hopping_bonds: whatever ``transfrom_Hwan`` returns for the parsed file.
    """
    # Step 1: parse the raw wannier90 data.
    lattice_vectors, hop_matrices, home_cell_index = read_hr(file)

    # Step 2: work out both orbital orderings and the per-atom orbital counts.
    orderings = wan_orbital_orders(struct, wannier_proj_orbital, orb_wan)
    wan_order, sk_order, norbs_per_atom = orderings

    # Step 3: reorder the matrices and cut them into atom-pair blocks.
    return transfrom_Hwan(hop_matrices, lattice_vectors, home_cell_index, struct, wan_order, sk_order, norbs_per_atom)


def read_hr(file='wannier90_hr.dat'):
    """ Read wannier90_hr.dat. The wannier TB parameters files are generated by wannier90.x.

    The parser expects: one header line, the number of wannier functions, the number
    of Wigner-Seitz grid points, the degeneracy weights (15 integers per line), and
    then one line ``R1 R2 R3 m n Re Im`` per matrix element. Blank lines in the
    matrix-element section are ignored.

    Parameters:
    ----------
        file: filename of wannier90_hr.dat
    Returns:
    --------
        Rlatt: lattice vectors of Wigner-Seitz grid points, shape (nrpts,3)
        hopps: hopping matrices, already divided by the degeneracy of each grid
            point, shape (nrpts,num_wann,num_wann)
        indR0: index of R=0, where the hopps[indR0] corresponding to the onsite Hamiltonian block.
    Raises:
    -------
        ValueError: if the number of complete hopping blocks differs from nrpts, or
            if the file contains no R = (0, 0, 0) block.
    """

    # Use a context manager so the file handle is always released.
    with open(file, 'r') as f:
        data = f.readlines()

    num_wann = int(data[1])
    nrpts = int(data[2])

    # The degeneracy of each Wigner-Seitz grid point is written 15 values per line.
    skiplines = int(np.ceil(nrpts / 15.0))
    istart = 3 + skiplines
    deg = np.array([int(j) for i in range(3, istart) for j in data[i].split()], dtype=int)

    block_size = num_wann * num_wann
    r_hop = np.zeros([num_wann, num_wann], dtype=complex)
    Rlatt = []
    hopps = []
    icount = 0
    for raw_line in data[istart:]:
        line = raw_line.split()
        if not line:
            # Tolerate blank lines (e.g. a trailing newline at the end of the file).
            continue
        m = int(line[3]) - 1
        n = int(line[4]) - 1
        # Matrix elements are rounded to 6 decimals, as in the original parser.
        r_hop[m, n] = complex(round(float(line[5]), 6), round(float(line[6]), 6))
        icount += 1
        if icount % block_size == 0:
            # A full num_wann x num_wann block has been read: store it together
            # with the lattice vector found on its last line, then start a new one.
            Rlatt.append(np.array([float(x) for x in line[0:3]]))
            hopps.append(r_hop)
            r_hop = np.zeros([num_wann, num_wann], dtype=complex)

    if len(hopps) != nrpts:
        raise ValueError(
            f'{file}: expected {nrpts} hopping blocks but found {len(hopps)} complete blocks')

    Rlatt = np.asarray(Rlatt, dtype=int).reshape(-1, 3)
    hopps = np.asarray(hopps)
    deg = np.reshape(deg, [nrpts, 1, 1])
    hopps = hopps / deg

    # Locate the home cell; the last match is used, as the original loop did.
    home_cells = np.flatnonzero((Rlatt == 0).all(axis=1))
    if home_cells.size == 0:
        raise ValueError(f'{file}: no R = (0, 0, 0) block found')
    indR0 = int(home_cells[-1])

    return Rlatt, hopps, indR0


def _as_shell_list(shells, name):
    """ Return ``shells`` as a list: a list is returned as is, a single string is wrapped. """
    if isinstance(shells, list):
        return shells
    if isinstance(shells, str):
        return [shells]
    raise ValueError(f'{name} should be a list or a string')


def wan_orbital_orders(struct:BaseStruct, wannier_proj_orbital:dict, orb_wan:dict=None):
    """ get the wannier orbital orders for the wannier orbitals in wannier90_hr.dat.
        by default it should be in the order of:
            atom-0-s, atom-0-pz,atom-0-px, atom-0-py, atom-1-s, atom-1-pz, ..., etc.
    Parameters:
    -----------
        struct: BaseStruct, must provide ``proj_atom_anglr_m`` and ``projected_struct``
            (an ase.Atoms object).
        wannier_proj_orbital: dict,the orbital defined in projection in wannierizaion process.
            e.g.: {'N': ['s','p'], 'B': ['s']} or {'N': 'p', 'B': 's'}
        orb_wan: dict, the orbital order in wannier90_hr.dat.
            e.g.: {'s': ['s'], 'p': ['pz','px','py'], 'd': ['dz2','dxz','dyz','dx2-y2','dxy']}
            Defaults to Orbital_Order_Wan_Default when None.
    Returns:
    --------
        wannier_orbital_order: list, the orbital order in wannier90_hr.dat.
        sk_orbital_order: list, the orbital order in dptb.
        iatom_nors: integer array, the number of orbitals on every atom.
    Raises:
    -------
        AssertionError: if the projections and ``proj_atom_anglr_m`` are inconsistent, or
            if the wannier and dptb orbital labels do not describe the same set of orbitals.
        ValueError: if an entry of either orbital dict is neither a list nor a string.
    """

    proj_atom_anglr_m = struct.proj_atom_anglr_m
    if orb_wan is None:
        orb_wan = Orbital_Order_Wan_Default
    orb_sk = Orbital_Order_SK

    # take the projected_struct as from the structure class, which should be Atoms object.
    projected_struct = struct.projected_struct
    assert isinstance(projected_struct, Atoms), 'projected_struct should be ase.Atoms'
    # check the consistency of wannier_proj_orbital and proj_atom_anglr_m.
    assert set(wannier_proj_orbital.keys()) == set(proj_atom_anglr_m.keys())
    for ii in proj_atom_anglr_m:
        assert len(wannier_proj_orbital[ii]) == len(proj_atom_anglr_m[ii]), 'proj_atom_anglr_m and wannier_proj_orbital are not consistent'
        for iorb in proj_atom_anglr_m[ii]:
            # keep only the letters of the shell label, e.g. '3s' -> 's'
            ishell_symbol = ''.join(re.findall(r'[A-Za-z]', iorb))
            assert ishell_symbol in wannier_proj_orbital[ii], 'proj_atom_anglr_m and wannier_proj_orbital are not consistent'

    # ------------------------------
    # get the wannier orbitals in order.
    # ------------------------------
    # wannier_orbital_order: the wannier orbitals in the order of wannier90_hr.dat
    #                        e.g.: ['0-s','0-pz','0-px', ... ]
    # sk_orbital_order: the orbitals in the order of dptb.
    # iatom_nors: the total number of orbitals on every atom.
    # ------------------------------

    iatom_nors = []  # number of orbitals on every atom
    wannier_orbital_order = []
    sk_orbital_order = []

    for ia, iatom_symbol in enumerate(projected_struct.get_chemical_symbols()):
        # labels in the wannier90 ordering, plus the orbital count of this atom
        ii_num_orbs = 0
        for iorb in _as_shell_list(wannier_proj_orbital[iatom_symbol], 'wannier_proj_orbital'):
            ii_num_orbs += 2 * anglrMId[iorb] + 1
            for ii_orb in orb_wan[iorb]:
                wannier_orbital_order.append(f'{ia}-{ii_orb}')
        iatom_nors.append(ii_num_orbs)

        # labels in the dptb (SK) ordering
        for iorb in _as_shell_list(proj_atom_anglr_m[iatom_symbol], 'proj_atom_anglr_m'):
            ishell_symbol = ''.join(re.findall(r'[A-Za-z]', iorb))
            for ii_orb in orb_sk[ishell_symbol]:
                sk_orbital_order.append(f'{ia}-{ii_orb}')

    # Both orderings must label exactly the same orbitals. The previous version
    # compared wannier_orbital_order with itself, so these checks could never fail.
    assert len(wannier_orbital_order) == len(sk_orbital_order), 'wannier_orb_in and sk_orb_in are not consistent'
    assert set(wannier_orbital_order) == set(sk_orbital_order), 'wannier_orb_in and sk_orb_in are not consistent'

    iatom_nors = np.array(iatom_nors, dtype=int)

    return wannier_orbital_order, sk_orbital_order, iatom_nors

def get_onsite_shift(hopps_r00, struct, wannier_orbital_order, unit='eV'):
    '''Compute the energy offset between the wannier onsite energies and the onsite database.

    The lowest real diagonal element of ``hopps_r00`` is located, the atom and shell it
    belongs to are read from its label, and the database onsite energy of that shell
    (scaled to ``unit``) is subtracted from it.

    Parameters
    ----------
    hopps_r00
        The onsite (R=0) Hamiltonian matrix; only the real part of its diagonal is used.
    struct
        Structure object providing ``projected_struct`` (queried with
        ``get_chemical_symbols()``) and ``proj_atom_anglr_m``.
    wannier_orbital_order
        List of wannier orbital labels in the format "atom_index-orbital_symbol",
        e.g. '0-s' or '1-pz', paired element by element with the diagonal of ``hopps_r00``.
    unit, optional
        Unit of the returned shift: 'eV', 'Ry' or 'Hartree'. The database values are
        multiplied by the corresponding Hartree-to-unit factor.

    Returns
    -------
    The lowest wannier onsite energy minus the matching database onsite energy.

    Raises
    ------
    ValueError
        If ``unit`` is not one of 'eV', 'Ry' or 'Hartree'.
    '''

    atom_symbols = struct.projected_struct.get_chemical_symbols()  # one symbol per projected atom

    # Map every orbital label to its onsite energy and pick the lowest one.
    diag_by_label = dict(zip(wannier_orbital_order, np.diag(hopps_r00).real))
    lowest_label, lowest_energy = min(diag_by_label.items(), key=lambda item: item[1])

    # Label layout is "<atom index>-<orbital>"; the first letter of the orbital is its shell.
    label_parts = lowest_label.split('-')
    atom_index = int(label_parts[0])
    shell_letter = label_parts[1][0]
    element = atom_symbols[atom_index]

    shells_by_element = struct.proj_atom_anglr_m

    # Conversion factors from Hartree to the requested unit.
    hartree_to = {
        'eV': 13.605662285137 * 2,
        'Ry': 2.0,
        'Hartree': 1.0,
    }
    try:
        factor = hartree_to[unit]
    except (KeyError, TypeError):
        raise ValueError('unit must be eV, Ry or Hartree')

    # Database onsite energies, keyed by element and by the letters of the shell label.
    reference = {
        elem: {
            ''.join(re.findall(r'[A-Za-z]', shell)): onsite_energy_database[elem][shell] * factor
            for shell in shells
        }
        for elem, shells in shells_by_element.items()
    }

    return lowest_energy - reference[element][shell_letter]

def transfrom_Hwan(hopps, Rlatt, indR0, struct, wannier_orbital_order, sk_orbital_order, iatom_nors):
    """ transform the hopping matrices from the order of wannier90_hr.dat to the order of dptb.
    Parameters:
    -----------
        hopps: hopping matrices, shape (nrpts,num_wann,num_wann)
        Rlatt: lattice vectors of Wigner-Seitz grid points, shape (nrpts,3)
        indR0: index of the R=0 block in hopps; the onsite shift is applied to this block.
        struct: structure object, forwarded to get_onsite_shift.
        wannier_orbital_order: list, the orbital order in wannier90_hr.dat.
        sk_orbital_order: list, the orbital order in dptb.
        iatom_nors: list or integer array, the number of orbitals on every atom.
    Returns:
    --------
        hopping_bonds: dict, the hopping matrices in the order of dptb.
            e.g.: hopping_bonds = {'0_0_0_0_0': H_0,0 block at R=000, '0_1_0_0_1': H_0,1 block at R=001, ... }
            Only the real part of each block is stored.
    """

    norb = len(sk_orbital_order)

    # Permutation matrix: row i selects the wannier orbital carrying the i-th dptb label.
    mtrans = np.zeros([norb, norb], dtype=int)
    for i, iorb in enumerate(sk_orbital_order):
        assert iorb in wannier_orbital_order
        mtrans[i, wannier_orbital_order.index(iorb)] += 1

    # The shift is computed from the onsite block while it is still in wannier order.
    onsite_shift = get_onsite_shift(hopps[indR0], struct, wannier_orbital_order, unit='eV')

    # mtrans broadcasts over the leading R axis, reordering rows and columns of every block.
    hopps_skorder = mtrans @ hopps @ mtrans.T
    hopps_skorder[indR0] = hopps_skorder[indR0] - onsite_shift * np.eye(norb)

    # Integer start/end offsets of every atom, computed once. Summing slices inside the
    # loops repeated the work for every R vector and gave a float 0.0 (an invalid slice
    # index) for the first atom when iatom_nors was a plain list.
    natoms = len(iatom_nors)
    bounds = np.concatenate(([0], np.cumsum(np.asarray(iatom_nors, dtype=int))))

    hopping_bonds = {}
    for ir, iR in enumerate(Rlatt):
        for ia in range(natoms):
            ist, ied = bounds[ia], bounds[ia + 1]
            for ja in range(natoms):
                jst, jed = bounds[ja], bounds[ja + 1]
                hopping_bonds[f'{ia}_{ja}_{iR[0]}_{iR[1]}_{iR[2]}'] = hopps_skorder[ir, ist:ied, jst:jed].real

    return hopping_bonds
Loading

0 comments on commit dc6055b

Please sign in to comment.