Skip to content

Commit

Permalink
Fix indexing and enable stress prediction in NequIP & Allegro (cp2k#3428
Browse files Browse the repository at this point in the history
)

Co-authored-by: Maria Bilichenko <[email protected]>
  • Loading branch information
mariabilichenk0 and Maria Bilichenko authored May 24, 2024
1 parent 67d2f81 commit 519c511
Show file tree
Hide file tree
Showing 26 changed files with 342 additions and 124 deletions.
Binary file added data/Allegro/gra-water-deployed-neq060sp.pth
Binary file not shown.
Binary file added data/Allegro/si-deployed-neq060dp.pth
Binary file not shown.
Binary file removed data/Allegro/si-deployed.pth
Binary file not shown.
Binary file removed data/Allegro/water-gra-film-double.pth
Binary file not shown.
Binary file added data/NequIP/water-deployed-neq060dp.pth
Binary file not shown.
Binary file added data/NequIP/water-deployed-neq060sp.pth
Binary file not shown.
Binary file removed data/NequIP/water-double.pth
Binary file not shown.
Binary file removed data/NequIP/water.pth
Binary file not shown.
2 changes: 1 addition & 1 deletion src/fist_force.F
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ SUBROUTINE fist_calc_energy_force(fist_env, debug)

! Compute embedding function and manybody energy
CALL energy_manybody(fist_nonbond_env, atomic_kind_set, local_particles, particle_set, &
cell, pot_manybody, para_env, mm_section)
cell, pot_manybody, para_env, mm_section, use_virial)

! Nonbond contribution + Manybody Forces
IF (shell_present) THEN
Expand Down
180 changes: 138 additions & 42 deletions src/force_fields_input.F
Original file line number Diff line number Diff line change
Expand Up @@ -799,6 +799,7 @@ SUBROUTINE read_nequip_section(nonbonded, section, start)
nonbonded%pot(start + n_items)%pot%set(1)%nequip%unit_energy = unit_energy
nonbonded%pot(start + n_items)%pot%set(1)%nequip%unit_cell = unit_cell
CALL read_nequip_data(nonbonded%pot(start + n_items)%pot%set(1)%nequip)
CALL check_cp2k_atom_names_in_torch(atm_names, nonbonded%pot(start + n_items)%pot%set(1)%nequip%type_names_torch)
nonbonded%pot(start + n_items)%pot%rcutsq = nonbonded%pot(start + n_items)%pot%set(1)%nequip%rcutsq
n_items = n_items + 1
END DO
Expand Down Expand Up @@ -847,11 +848,11 @@ SUBROUTINE read_allegro_section(nonbonded, section, start)
nonbonded%pot(start + n_items)%pot%set(1)%allegro%unit_energy = unit_energy
nonbonded%pot(start + n_items)%pot%set(1)%allegro%unit_cell = unit_cell
CALL read_allegro_data(nonbonded%pot(start + n_items)%pot%set(1)%allegro)
CALL check_cp2k_atom_names_in_torch(atm_names, nonbonded%pot(start + n_items)%pot%set(1)%allegro%type_names_torch)
nonbonded%pot(start + n_items)%pot%rcutsq = nonbonded%pot(start + n_items)%pot%set(1)%allegro%rcutsq
n_items = n_items + 1
END DO
END DO
END SUBROUTINE read_allegro_section
! **************************************************************************************************
Expand Down Expand Up @@ -2433,8 +2434,10 @@ SUBROUTINE read_nequip_data(nequip)
TYPE(nequip_pot_type), POINTER :: nequip
CHARACTER(len=*), PARAMETER :: routineN = 'read_nequip_data'
CHARACTER(LEN=1), PARAMETER :: delimiter = ' '
CHARACTER(LEN=default_path_length) :: allow_tf32_str, config_str, cutoff_str
CHARACTER(LEN=100), ALLOCATABLE, DIMENSION(:) :: tokenized_string
CHARACTER(LEN=default_path_length) :: allow_tf32_str, cutoff_str, types_str
INTEGER :: handle
LOGICAL :: allow_tf32, found
Expand All @@ -2449,6 +2452,16 @@ SUBROUTINE read_nequip_data(nequip)
nequip%nequip_version = torch_model_read_metadata(nequip%nequip_file_name, "nequip_version")
cutoff_str = torch_model_read_metadata(nequip%nequip_file_name, "r_max")
types_str = torch_model_read_metadata(nequip%nequip_file_name, "type_names")
CALL tokenize_string(TRIM(types_str), delimiter, tokenized_string)
IF (ALLOCATED(nequip%type_names_torch)) THEN
DEALLOCATE (nequip%type_names_torch)
END IF
ALLOCATE (nequip%type_names_torch(SIZE(tokenized_string)))
nequip%type_names_torch(:) = tokenized_string(:)
READ (cutoff_str, *) nequip%rcutsq
nequip%rcutsq = cp_unit_to_cp2k(nequip%rcutsq, nequip%unit_coords)
nequip%rcutsq = nequip%rcutsq*nequip%rcutsq
Expand All @@ -2457,9 +2470,18 @@ SUBROUTINE read_nequip_data(nequip)
nequip%unit_energy_val = cp_unit_to_cp2k(nequip%unit_energy_val, nequip%unit_energy)
nequip%unit_cell_val = cp_unit_to_cp2k(nequip%unit_cell_val, nequip%unit_cell)
! look in config which contains all the .yaml file options to see if we use float32 or float64
config_str = torch_model_read_metadata(nequip%nequip_file_name, "config")
CALL read_default_dtype(config_str, nequip%do_nequip_sp)
IF (torch_model_read_metadata(nequip%nequip_file_name, "default_dtype") == "float32" .AND. &
torch_model_read_metadata(nequip%nequip_file_name, "model_dtype") == "float32") THEN
nequip%do_nequip_sp = .TRUE.
ELSE IF (torch_model_read_metadata(nequip%nequip_file_name, "default_dtype") == "float64" .AND. &
torch_model_read_metadata(nequip%nequip_file_name, "model_dtype") == "float64") THEN
nequip%do_nequip_sp = .FALSE.
ELSE
CALL cp_abort(__LOCATION__, &
"Both default_dtype and model_dtype should be either float32 or float64. Currently, default_dtype is <"// &
torch_model_read_metadata(nequip%nequip_file_name, "default_dtype")//"> and model_dtype is <"// &
torch_model_read_metadata(nequip%nequip_file_name, "model_dtype")//">.")
END IF
allow_tf32_str = torch_model_read_metadata(nequip%nequip_file_name, "allow_tf32")
allow_tf32 = (TRIM(allow_tf32_str) == "1")
Expand All @@ -2482,8 +2504,10 @@ SUBROUTINE read_allegro_data(allegro)
TYPE(allegro_pot_type), POINTER :: allegro
CHARACTER(len=*), PARAMETER :: routineN = 'read_allegro_data'
CHARACTER(LEN=1), PARAMETER :: delimiter = ' '
CHARACTER(LEN=default_path_length) :: allow_tf32_str, config_str, cutoff_str
CHARACTER(LEN=100), ALLOCATABLE, DIMENSION(:) :: tokenized_string
CHARACTER(LEN=default_path_length) :: allow_tf32_str, cutoff_str, types_str
INTEGER :: handle
LOGICAL :: allow_tf32, found
Expand All @@ -2503,6 +2527,15 @@ SUBROUTINE read_allegro_data(allegro)
"> has not been deployed; did you forget to run `nequip-deploy`?")
END IF
cutoff_str = torch_model_read_metadata(allegro%allegro_file_name, "r_max")
types_str = torch_model_read_metadata(allegro%allegro_file_name, "type_names")
CALL tokenize_string(TRIM(types_str), delimiter, tokenized_string)
IF (ALLOCATED(allegro%type_names_torch)) THEN
DEALLOCATE (allegro%type_names_torch)
END IF
ALLOCATE (allegro%type_names_torch(SIZE(tokenized_string)))
allegro%type_names_torch(:) = tokenized_string(:)
READ (cutoff_str, *) allegro%rcutsq
allegro%rcutsq = cp_unit_to_cp2k(allegro%rcutsq, allegro%unit_coords)
allegro%rcutsq = allegro%rcutsq*allegro%rcutsq
Expand All @@ -2511,9 +2544,18 @@ SUBROUTINE read_allegro_data(allegro)
allegro%unit_energy_val = cp_unit_to_cp2k(allegro%unit_energy_val, allegro%unit_energy)
allegro%unit_cell_val = cp_unit_to_cp2k(allegro%unit_cell_val, allegro%unit_cell)
! look in config which contains all the .yaml file options to see if we use float32 or float64
config_str = torch_model_read_metadata(allegro%allegro_file_name, "config")
CALL read_default_dtype(config_str, allegro%do_allegro_sp)
IF (torch_model_read_metadata(allegro%allegro_file_name, "default_dtype") == "float32" .AND. &
torch_model_read_metadata(allegro%allegro_file_name, "model_dtype") == "float32") THEN
allegro%do_allegro_sp = .TRUE.
ELSE IF (torch_model_read_metadata(allegro%allegro_file_name, "default_dtype") == "float64" .AND. &
torch_model_read_metadata(allegro%allegro_file_name, "model_dtype") == "float64") THEN
allegro%do_allegro_sp = .FALSE.
ELSE
CALL cp_abort(__LOCATION__, &
"Both default_dtype and model_dtype should be either float32 or float64. Currently, default_dtype is <"// &
torch_model_read_metadata(allegro%allegro_file_name, "default_dtype")//"> and model_dtype is <"// &
torch_model_read_metadata(allegro%allegro_file_name, "model_dtype")//">.")
END IF
allow_tf32_str = torch_model_read_metadata(allegro%allegro_file_name, "allow_tf32")
allow_tf32 = (TRIM(allow_tf32_str) == "1")
Expand All @@ -2528,47 +2570,102 @@ SUBROUTINE read_allegro_data(allegro)
END SUBROUTINE read_allegro_data
! **************************************************************************************************
!> \brief reads the default_dtype used in the Allegro/NequIP model by parsing the config file
!> \param config_str ...
!> \param do_model_sp ...
!> \author Gabriele Tocci
!> \brief returns tokenized string of kinds from .pth file
!> \param element ...
!> \param delimiter ...
!> \param tokenized_array ...
!> \author Maria Bilichenko
! **************************************************************************************************
SUBROUTINE read_default_dtype(config_str, do_model_sp)
SUBROUTINE tokenize_string(element, delimiter, tokenized_array)
CHARACTER(LEN=*), INTENT(IN) :: element
CHARACTER(LEN=1), INTENT(IN) :: delimiter
CHARACTER(LEN=100), ALLOCATABLE, DIMENSION(:), &
INTENT(OUT) :: tokenized_array
CHARACTER(LEN=default_path_length) :: config_str
LOGICAL :: do_model_sp
CHARACTER(LEN=100) :: temp_kinds
INTEGER :: end_pos, i, num_elements, start
LOGICAL, ALLOCATABLE, DIMENSION(:) :: delim_positions
CHARACTER(len=*), PARAMETER :: routineN = 'read_default_dtype'
! Find positions of delimiter within element
ALLOCATE (delim_positions(LEN(element)))
delim_positions = .FALSE.
INTEGER :: handle, i, idx, len_config
DO i = 1, LEN(element)
IF (element(i:i) == delimiter) delim_positions(i) = .TRUE.
END DO
CALL timeset(routineN, handle)
num_elements = COUNT(delim_positions) + 1
len_config = LEN_TRIM(config_str)
idx = INDEX(config_str, "default_dtype:")
IF (idx /= 0) THEN
i = idx + 14 ! skip over "default_dtype:"
DO WHILE (i <= len_config .AND. config_str(i:i) == " ")
i = i + 1 ! skip over any whitespace
END DO
ALLOCATE (tokenized_array(num_elements))
IF (i > len_config) THEN
CALL cp_abort(__LOCATION__, &
"No default_dtype found, check the Nequip/Allegro .yaml or .pth files."// &
" Default_dtype should be either <float32> or <float64>.")
ELSE IF (config_str(i:i + 6) == "float32") THEN
do_model_sp = .TRUE.
ELSE IF (config_str(i:i + 6) == "float64") THEN
do_model_sp = .FALSE.
ELSE
CALL cp_abort(__LOCATION__, &
"The default_dtype should be either <float32> or <float64>."// &
" Check the NequIP/Allegro .yaml and .pth files.")
start = 1
DO i = 1, num_elements
IF (LEN(element) < 3 .AND. COUNT(delim_positions) == 0) THEN ! if there is 1 kind only and it has one or two
!characters (C or Cl) the end_pos will be the index of the last character (1 or 2)
end_pos = LEN(element)
ELSE ! else, the end_pos is determined by the index of the space - 1
end_pos = find_end_pos(start, delim_positions)
END IF
END IF
temp_kinds = element(start:end_pos)
IF (TRIM(temp_kinds) /= '') THEN
tokenized_array(i) = temp_kinds
END IF
start = end_pos + 2
END DO
DEALLOCATE (delim_positions)
END SUBROUTINE tokenize_string
CALL timestop(handle)
END SUBROUTINE read_default_dtype
! **************************************************************************************************
!> \brief finds the position of the atom by the spacing
!> \param start ...
!> \param delim_positions ...
!> \return ...
!> \author Maria Bilichenko
! **************************************************************************************************
INTEGER FUNCTION find_end_pos(start, delim_positions)
INTEGER, INTENT(IN) :: start
LOGICAL, DIMENSION(:), INTENT(IN) :: delim_positions
INTEGER :: end_pos, i
end_pos = start
DO i = start, SIZE(delim_positions)
IF (delim_positions(i)) THEN
end_pos = i - 1
EXIT
END IF
END DO
find_end_pos = end_pos
END FUNCTION find_end_pos
! **************************************************************************************************
!> \brief checks if all the ATOMS from *.inp file are available in *.pth file
!> \param cp2k_inp_atom_types ...
!> \param torch_atom_types ...
!> \author Maria Bilichenko
! **************************************************************************************************
SUBROUTINE check_cp2k_atom_names_in_torch(cp2k_inp_atom_types, torch_atom_types)
CHARACTER(LEN=*), DIMENSION(:), INTENT(IN) :: cp2k_inp_atom_types, torch_atom_types
INTEGER :: i, j
LOGICAL :: found
DO i = 1, SIZE(cp2k_inp_atom_types)
found = .FALSE.
DO j = 1, SIZE(torch_atom_types)
IF (TRIM(cp2k_inp_atom_types(i)) == TRIM(torch_atom_types(j))) THEN
found = .TRUE.
EXIT
END IF
END DO
IF (.NOT. found) THEN
CALL cp_abort(__LOCATION__, &
"Atom "//TRIM(cp2k_inp_atom_types(i))// &
" is defined in the CP2K input file but is missing in the torch model file")
END IF
END DO
END SUBROUTINE check_cp2k_atom_names_in_torch
! **************************************************************************************************
!> \brief reads TABPOT potential from file
Expand Down Expand Up @@ -2626,5 +2723,4 @@ SUBROUTINE read_tabpot_data(tab, para_env, mm_section)
CALL cp_print_key_finished_output(iw, logger, mm_section, "PRINT%FF_INFO")
CALL timestop(handle)
END SUBROUTINE read_tabpot_data
END MODULE force_fields_input
Loading

0 comments on commit 519c511

Please sign in to comment.