Skip to content

Commit

Permalink
Torch: Use ALLOCATABLEs to prevent passing a temporary array
Browse files Browse the repository at this point in the history
  • Loading branch information
oschuett committed Oct 9, 2024
1 parent cf6a4d8 commit 6733ffb
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 124 deletions.
15 changes: 8 additions & 7 deletions src/manybody_allegro.F
Original file line number Diff line number Diff line change
Expand Up @@ -250,14 +250,14 @@ SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atom
INTEGER, ALLOCATABLE, DIMENSION(:) :: work_list
INTEGER, DIMENSION(:, :), POINTER :: list, sort_list
LOGICAL, ALLOCATABLE :: use_atom(:)
REAL(kind=dp) :: drij, lattice(3, 3), rab2_max, rij(3)
REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts, new_edge_cell_shifts, &
pos
REAL(kind=dp) :: drij, rab2_max, rij(3)
REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts, lattice, &
new_edge_cell_shifts, pos
REAL(kind=dp), DIMENSION(3) :: cell_v, cvi
REAL(kind=dp), DIMENSION(:, :), POINTER :: atomic_energy, forces, virial
REAL(kind=dp), DIMENSION(:, :, :), POINTER :: virial3d
REAL(kind=sp) :: lattice_sp(3, 3)
REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :) :: new_edge_cell_shifts_sp, pos_sp
REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :) :: lattice_sp, new_edge_cell_shifts_sp, &
pos_sp
REAL(kind=sp), DIMENSION(:, :), POINTER :: atomic_energy_sp, forces_sp
TYPE(allegro_data_type), POINTER :: allegro_data
TYPE(neighbor_kind_pairs_type), POINTER :: neighbor_kind_pair
Expand Down Expand Up @@ -401,8 +401,9 @@ SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atom

t_edge_index(:, :) = TRANSPOSE(temp_edge_index)
DEALLOCATE (temp_edge_index, edge_index)
lattice = cell%hmat/pot%set(1)%allegro%unit_cell_val
lattice_sp = REAL(lattice, kind=sp)
ALLOCATE (lattice(3, 3), lattice_sp(3, 3))
lattice(:, :) = cell%hmat/pot%set(1)%allegro%unit_cell_val
lattice_sp(:, :) = REAL(lattice, kind=sp)
iat_use = 0
ALLOCATE (pos(3, n_atoms_use), atom_types(n_atoms_use))
DO iat = 1, n_atoms_use
Expand Down
12 changes: 6 additions & 6 deletions src/manybody_nequip.F
Original file line number Diff line number Diff line change
Expand Up @@ -229,15 +229,14 @@ SUBROUTINE nequip_energy_store_force_virial(nonbonded, particle_set, cell, atomi
edge_count_cell, work_list
INTEGER, DIMENSION(:, :), POINTER :: list, sort_list
LOGICAL, ALLOCATABLE :: use_atom(:)
REAL(kind=dp) :: drij, lattice(3, 3), rab2_max, rij(3)
REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts, pos, &
REAL(kind=dp) :: drij, rab2_max, rij(3)
REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts, lattice, pos, &
temp_edge_cell_shifts
REAL(kind=dp), DIMENSION(3) :: cell_v, cvi
REAL(kind=dp), DIMENSION(:, :), POINTER :: atomic_energy, forces, total_energy, &
virial
REAL(kind=dp), DIMENSION(:, :, :), POINTER :: virial3d
REAL(kind=sp) :: lattice_sp(3, 3)
REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts_sp, pos_sp
REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts_sp, lattice_sp, pos_sp
REAL(kind=sp), DIMENSION(:, :), POINTER :: atomic_energy_sp, forces_sp, &
total_energy_sp
TYPE(neighbor_kind_pairs_type), POINTER :: neighbor_kind_pair
Expand Down Expand Up @@ -403,8 +402,9 @@ SUBROUTINE nequip_energy_store_force_virial(nonbonded, particle_set, cell, atomi
t_edge_index(:, :) = TRANSPOSE(edge_index)
DEALLOCATE (temp_edge_index, temp_edge_cell_shifts, edge_index)

lattice = cell%hmat/nequip%unit_cell_val
lattice_sp = REAL(lattice, kind=sp)
ALLOCATE (lattice(3, 3), lattice_sp(3, 3))
lattice(:, :) = cell%hmat/nequip%unit_cell_val
lattice_sp(:, :) = REAL(lattice, kind=sp)

iat_use = 0
ALLOCATE (pos(3, n_atoms_use), atom_types(n_atoms_use))
Expand Down
Loading

0 comments on commit 6733ffb

Please sign in to comment.