From 6733ffbc94b73b94a03a5d9edf3bc458fda6d16f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Ole=20Sch=C3=BCtt?= Date: Wed, 9 Oct 2024 23:36:17 +0200 Subject: [PATCH] Torch: Use ALLOCATABLEs to prevent passing a temporary array --- src/manybody_allegro.F | 15 +-- src/manybody_nequip.F | 12 +-- src/nequip_unittest.F | 223 +++++++++++++++++++++-------------------- src/torch_api.F | 3 +- 4 files changed, 129 insertions(+), 124 deletions(-) diff --git a/src/manybody_allegro.F b/src/manybody_allegro.F index c6d279f36d..6470f671d4 100644 --- a/src/manybody_allegro.F +++ b/src/manybody_allegro.F @@ -250,14 +250,14 @@ SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atom INTEGER, ALLOCATABLE, DIMENSION(:) :: work_list INTEGER, DIMENSION(:, :), POINTER :: list, sort_list LOGICAL, ALLOCATABLE :: use_atom(:) - REAL(kind=dp) :: drij, lattice(3, 3), rab2_max, rij(3) - REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts, new_edge_cell_shifts, & - pos + REAL(kind=dp) :: drij, rab2_max, rij(3) + REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts, lattice, & + new_edge_cell_shifts, pos REAL(kind=dp), DIMENSION(3) :: cell_v, cvi REAL(kind=dp), DIMENSION(:, :), POINTER :: atomic_energy, forces, virial REAL(kind=dp), DIMENSION(:, :, :), POINTER :: virial3d - REAL(kind=sp) :: lattice_sp(3, 3) - REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :) :: new_edge_cell_shifts_sp, pos_sp + REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :) :: lattice_sp, new_edge_cell_shifts_sp, & + pos_sp REAL(kind=sp), DIMENSION(:, :), POINTER :: atomic_energy_sp, forces_sp TYPE(allegro_data_type), POINTER :: allegro_data TYPE(neighbor_kind_pairs_type), POINTER :: neighbor_kind_pair @@ -401,8 +401,9 @@ SUBROUTINE allegro_energy_store_force_virial(nonbonded, particle_set, cell, atom t_edge_index(:, :) = TRANSPOSE(temp_edge_index) DEALLOCATE (temp_edge_index, edge_index) - lattice = cell%hmat/pot%set(1)%allegro%unit_cell_val - lattice_sp = REAL(lattice, kind=sp) + ALLOCATE (lattice(3, 3), lattice_sp(3, 3)) + lattice(:, :) = cell%hmat/pot%set(1)%allegro%unit_cell_val + lattice_sp(:, :) = REAL(lattice, kind=sp) iat_use = 0 ALLOCATE (pos(3, n_atoms_use), atom_types(n_atoms_use)) DO iat = 1, n_atoms_use diff --git a/src/manybody_nequip.F b/src/manybody_nequip.F index a0632427de..1d53414990 100644 --- a/src/manybody_nequip.F +++ b/src/manybody_nequip.F @@ -229,15 +229,14 @@ SUBROUTINE nequip_energy_store_force_virial(nonbonded, particle_set, cell, atomi edge_count_cell, work_list INTEGER, DIMENSION(:, :), POINTER :: list, sort_list LOGICAL, ALLOCATABLE :: use_atom(:) - REAL(kind=dp) :: drij, lattice(3, 3), rab2_max, rij(3) - REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts, pos, & + REAL(kind=dp) :: drij, rab2_max, rij(3) + REAL(kind=dp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts, lattice, pos, & temp_edge_cell_shifts REAL(kind=dp), DIMENSION(3) :: cell_v, cvi REAL(kind=dp), DIMENSION(:, :), POINTER :: atomic_energy, forces, total_energy, & virial REAL(kind=dp), DIMENSION(:, :, :), POINTER :: virial3d - REAL(kind=sp) :: lattice_sp(3, 3) - REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts_sp, pos_sp + REAL(kind=sp), ALLOCATABLE, DIMENSION(:, :) :: edge_cell_shifts_sp, lattice_sp, pos_sp REAL(kind=sp), DIMENSION(:, :), POINTER :: atomic_energy_sp, forces_sp, & total_energy_sp TYPE(neighbor_kind_pairs_type), POINTER :: neighbor_kind_pair @@ -403,8 +402,9 @@ SUBROUTINE nequip_energy_store_force_virial(nonbonded, particle_set, cell, atomi t_edge_index(:, :) = TRANSPOSE(edge_index) DEALLOCATE (temp_edge_index, temp_edge_cell_shifts, edge_index) - lattice = cell%hmat/nequip%unit_cell_val - lattice_sp = REAL(lattice, kind=sp) + ALLOCATE (lattice(3, 3), lattice_sp(3, 3)) + lattice(:, :) = cell%hmat/nequip%unit_cell_val + lattice_sp(:, :) = REAL(lattice, kind=sp) iat_use = 0 ALLOCATE (pos(3, n_atoms_use), atom_types(n_atoms_use)) diff --git a/src/nequip_unittest.F b/src/nequip_unittest.F index 26ad9e6b6e..153e14be3a 100644 --- a/src/nequip_unittest.F +++ b/src/nequip_unittest.F @@ -29,10 +29,10 @@ PROGRAM nequip_unittest ! Inputs. INTEGER, PARAMETER :: natoms = 96 INTEGER :: iatom, nedges - REAL(sp), DIMENSION(3, natoms) :: pos - REAL(dp), DIMENSION(3, 3):: cell, hinv - INTEGER(kind=int_8), DIMENSION(natoms):: atom_types - INTEGER(kind=int_8), DIMENSION(:, :), ALLOCATABLE:: edge_index + REAL(sp), DIMENSION(:, :), ALLOCATABLE :: pos, cell + REAL(dp), DIMENSION(3, 3) :: hinv + INTEGER(kind=int_8), DIMENSION(:), ALLOCATABLE :: atom_types + INTEGER(kind=int_8), DIMENSION(:, :), ALLOCATABLE :: edge_index REAL(sp), DIMENSION(:, :), ALLOCATABLE:: edge_cell_shift ! Torch objects. @@ -44,110 +44,113 @@ PROGRAM nequip_unittest NULLIFY (total_energy, atomic_energy, forces) ! A box with 32 water molecules. - pos = RESHAPE(REAL([ & - 42.8861696_dp, -0.0556816_dp, 38.3291611_dp, & - 34.2025887_dp, -0.6185484_dp, 37.3655680_dp, & - 30.0803925_dp, -2.0124176_dp, 36.4807960_dp, & - 28.7057911_dp, -2.6880392_dp, 36.6020983_dp, & - 36.2479426_dp, -0.5163484_dp, 34.4923596_dp, & - 37.6964724_dp, -0.0410872_dp, 35.0140735_dp, & - 27.7606699_dp, 7.4854206_dp, 33.9276919_dp, & - 28.8160999_dp, 6.4985777_dp, 34.2163608_dp, & - 37.1576372_dp, 9.0188280_dp, 31.9265812_dp, & - 38.6063816_dp, 9.5820079_dp, 32.3435972_dp, & - 34.3031959_dp, 2.2195014_dp, 45.9880451_dp, & - 33.2444139_dp, 1.3025332_dp, 46.4698427_dp, & - 38.7286174_dp, -5.0541897_dp, 26.0743968_dp, & - 38.3483921_dp, -6.2832846_dp, 26.9867253_dp, & - 32.8642520_dp, 3.2060632_dp, 30.8971160_dp, & - 31.2904088_dp, 3.0871834_dp, 30.6273977_dp, & - 33.7519869_dp, -3.1383262_dp, 39.6727607_dp, & - 34.6642979_dp, -3.6643859_dp, 38.6466027_dp, & - 42.7173214_dp, 5.1246883_dp, 32.5883401_dp, & - 41.5627455_dp, 5.5893544_dp, 33.4174902_dp, & - 32.4283800_dp, 9.1182520_dp, 30.5477678_dp, & - 32.6432407_dp, 10.770683_dp, 30.4842778_dp, & - 31.4848670_dp, 4.6777144_dp, 37.3957194_dp, & - 32.3171882_dp, -6.2287496_dp, 36.4671864_dp, & - 26.6621340_dp, 3.1708123_dp, 35.6820146_dp, & - 26.5271367_dp, 1.6039040_dp, 35.4883482_dp, & - 32.0238236_dp, 16.918208_dp, 31.6883569_dp, & - 31.4006579_dp, 7.0315610_dp, 30.2394554_dp, & - 33.5264253_dp, -3.5594808_dp, 34.2636830_dp, & - 34.6404855_dp, -3.2653833_dp, 35.4971482_dp, & - 40.0564375_dp, -0.3054386_dp, 29.8312074_dp, & - 39.4784464_dp, -1.0948314_dp, 38.3101140_dp, & - 39.7040761_dp, 1.9584631_dp, 33.3902375_dp, & - 38.3338570_dp, 2.6967178_dp, 42.9261945_dp, & - 40.1820455_dp, -7.2199289_dp, 27.6580390_dp, & - 39.3204431_dp, -8.4564252_dp, 28.1319658_dp, & - 36.3876963_dp, 8.8117085_dp, 38.3545362_dp, & - 36.3205637_dp, 9.0063075_dp, 36.7526001_dp, & - 29.9991583_dp, -5.5637817_dp, 33.9295050_dp, & - 30.7728545_dp, -5.0385870_dp, 35.1998067_dp, & - 40.0592517_dp, 6.3305279_dp, 28.2579461_dp, & - 40.2398360_dp, 5.1745923_dp, 29.2962956_dp, & - 26.3320911_dp, 2.4393638_dp, 33.5653868_dp, & - 26.9606971_dp, 1.2711078_dp, 32.5923884_dp, & - 34.8372697_dp, -0.4722708_dp, 30.3824362_dp, & - 35.3968813_dp, -1.9268483_dp, 30.3081837_dp, & - 32.1217607_dp, -0.7333429_dp, 36.5104382_dp, & - 32.2180843_dp, 7.8454304_dp, 35.6671967_dp, & - 36.3780998_dp, -4.3048878_dp, 36.4539793_dp, & - 35.8119275_dp, -3.0013928_dp, 27.0348937_dp, & - 29.6452491_dp, 1.0652123_dp, 35.7143653_dp, & - 30.3794654_dp, -0.0668146_dp, 34.9882468_dp, & - 34.2149336_dp, -1.6559120_dp, 33.8876437_dp, & - 34.7842435_dp, -1.0252141_dp, 32.5034832_dp, & - 40.4649954_dp, 1.1467825_dp, 31.3073503_dp, & - 41.3262469_dp, 0.6550803_dp, 32.4555882_dp, & - 29.0210859_dp, 3.5038194_dp, 39.9087702_dp, & - 29.4945426_dp, 3.7276637_dp, 41.3766138_dp, & - 34.1359664_dp, -6.7533422_dp, 32.3568410_dp, & - 34.9546570_dp, -5.7704242_dp, 31.4571066_dp, & - 33.2532356_dp, 1.5268048_dp, 44.0562171_dp, & - 33.7931669_dp, 0.5014632_dp, 43.0597590_dp, & - 36.8205409_dp, 2.6214681_dp, 40.6834006_dp, & - 37.5552706_dp, 1.5649832_dp, 39.7648935_dp, & - 43.2099087_dp, -0.0628456_dp, 47.2593155_dp, & - 29.3940583_dp, -2.3133019_dp, 37.1407883_dp, & - 36.7415708_dp, -0.0838710_dp, 35.2591783_dp, & - 27.9424776_dp, 6.7622961_dp, 34.5648384_dp, & - 37.6812656_dp, 9.4216399_dp, 32.6478643_dp, & - 33.3171290_dp, 2.0951401_dp, 45.8722265_dp, & - 37.9951355_dp, 4.3611431_dp, 26.5571819_dp, & - 32.1824670_dp, 2.6611503_dp, 30.4577248_dp, & - 34.6538012_dp, -3.4374573_dp, 39.5889245_dp, & - 42.2929833_dp, 5.9471069_dp, 32.8460995_dp, & - 32.9604690_dp, 9.9050313_dp, 30.1587306_dp, & - 31.4281886_dp, -5.8338304_dp, 36.6738743_dp, & - 26.0563730_dp, 2.4973869_dp, 35.3486870_dp, & - 32.0334927_dp, 17.3252289_dp, 30.8116013_dp, & - 33.8252182_dp, -2.9520949_dp, 35.0220460_dp, & - 39.4569981_dp, -0.3072759_dp, 38.9347829_dp, & - 29.4846708_dp, 2.8692561_dp, 43.0061868_dp, & - 39.2864184_dp, -7.6206103_dp, 27.6271147_dp, & - 35.8797502_dp, 8.6515870_dp, 37.5221734_dp, & - 30.3582543_dp, -4.7607656_dp, 34.3355645_dp, & - 40.7098956_dp, 5.8331250_dp, 28.7558375_dp, & - 26.7179083_dp, 2.2415138_dp, 32.6577297_dp, & - 35.6589256_dp, -0.9968903_dp, 30.5749530_dp, & - 31.5851602_dp, -1.3121804_dp, 35.9011109_dp, & - 35.5489386_dp, -3.9056138_dp, 26.8214490_dp, & - 29.5656616_dp, 0.4681794_dp, 34.9670711_dp, & - 34.7615128_dp, -0.9569680_dp, 33.4891367_dp, & - 40.4853406_dp, 0.4023620_dp, 31.9425416_dp, & - 29.6728289_dp, 4.0134825_dp, 40.4505780_dp, & - 34.1272286_dp, -5.8796882_dp, 31.8925999_dp, & - 33.1168884_dp, 1.2338084_dp, 43.1127997_dp, & - 37.1996993_dp, 2.5049007_dp, 39.7917126_dp], kind=sp), shape=[3, natoms]) - - cell(1, :) = [9.85_dp, 0.0_dp, 0.0_dp] - cell(2, :) = [0.0_dp, 9.85_dp, 0.0_dp] - cell(3, :) = [0.0_dp, 0.0_dp, 9.85_dp] - - hinv = inv_3x3(cell) - + ALLOCATE (pos(3, natoms)) + pos(:, :) = RESHAPE(REAL([ & + 42.8861696_dp, -0.0556816_dp, 38.3291611_dp, & + 34.2025887_dp, -0.6185484_dp, 37.3655680_dp, & + 30.0803925_dp, -2.0124176_dp, 36.4807960_dp, & + 28.7057911_dp, -2.6880392_dp, 36.6020983_dp, & + 36.2479426_dp, -0.5163484_dp, 34.4923596_dp, & + 37.6964724_dp, -0.0410872_dp, 35.0140735_dp, & + 27.7606699_dp, 7.4854206_dp, 33.9276919_dp, & + 28.8160999_dp, 6.4985777_dp, 34.2163608_dp, & + 37.1576372_dp, 9.0188280_dp, 31.9265812_dp, & + 38.6063816_dp, 9.5820079_dp, 32.3435972_dp, & + 34.3031959_dp, 2.2195014_dp, 45.9880451_dp, & + 33.2444139_dp, 1.3025332_dp, 46.4698427_dp, & + 38.7286174_dp, -5.0541897_dp, 26.0743968_dp, & + 38.3483921_dp, -6.2832846_dp, 26.9867253_dp, & + 32.8642520_dp, 3.2060632_dp, 30.8971160_dp, & + 31.2904088_dp, 3.0871834_dp, 30.6273977_dp, & + 33.7519869_dp, -3.1383262_dp, 39.6727607_dp, & + 34.6642979_dp, -3.6643859_dp, 38.6466027_dp, & + 42.7173214_dp, 5.1246883_dp, 32.5883401_dp, & + 41.5627455_dp, 5.5893544_dp, 33.4174902_dp, & + 32.4283800_dp, 9.1182520_dp, 30.5477678_dp, & + 32.6432407_dp, 10.770683_dp, 30.4842778_dp, & + 31.4848670_dp, 4.6777144_dp, 37.3957194_dp, & + 32.3171882_dp, -6.2287496_dp, 36.4671864_dp, & + 26.6621340_dp, 3.1708123_dp, 35.6820146_dp, & + 26.5271367_dp, 1.6039040_dp, 35.4883482_dp, & + 32.0238236_dp, 16.918208_dp, 31.6883569_dp, & + 31.4006579_dp, 7.0315610_dp, 30.2394554_dp, & + 33.5264253_dp, -3.5594808_dp, 34.2636830_dp, & + 34.6404855_dp, -3.2653833_dp, 35.4971482_dp, & + 40.0564375_dp, -0.3054386_dp, 29.8312074_dp, & + 39.4784464_dp, -1.0948314_dp, 38.3101140_dp, & + 39.7040761_dp, 1.9584631_dp, 33.3902375_dp, & + 38.3338570_dp, 2.6967178_dp, 42.9261945_dp, & + 40.1820455_dp, -7.2199289_dp, 27.6580390_dp, & + 39.3204431_dp, -8.4564252_dp, 28.1319658_dp, & + 36.3876963_dp, 8.8117085_dp, 38.3545362_dp, & + 36.3205637_dp, 9.0063075_dp, 36.7526001_dp, & + 29.9991583_dp, -5.5637817_dp, 33.9295050_dp, & + 30.7728545_dp, -5.0385870_dp, 35.1998067_dp, & + 40.0592517_dp, 6.3305279_dp, 28.2579461_dp, & + 40.2398360_dp, 5.1745923_dp, 29.2962956_dp, & + 26.3320911_dp, 2.4393638_dp, 33.5653868_dp, & + 26.9606971_dp, 1.2711078_dp, 32.5923884_dp, & + 34.8372697_dp, -0.4722708_dp, 30.3824362_dp, & + 35.3968813_dp, -1.9268483_dp, 30.3081837_dp, & + 32.1217607_dp, -0.7333429_dp, 36.5104382_dp, & + 32.2180843_dp, 7.8454304_dp, 35.6671967_dp, & + 36.3780998_dp, -4.3048878_dp, 36.4539793_dp, & + 35.8119275_dp, -3.0013928_dp, 27.0348937_dp, & + 29.6452491_dp, 1.0652123_dp, 35.7143653_dp, & + 30.3794654_dp, -0.0668146_dp, 34.9882468_dp, & + 34.2149336_dp, -1.6559120_dp, 33.8876437_dp, & + 34.7842435_dp, -1.0252141_dp, 32.5034832_dp, & + 40.4649954_dp, 1.1467825_dp, 31.3073503_dp, & + 41.3262469_dp, 0.6550803_dp, 32.4555882_dp, & + 29.0210859_dp, 3.5038194_dp, 39.9087702_dp, & + 29.4945426_dp, 3.7276637_dp, 41.3766138_dp, & + 34.1359664_dp, -6.7533422_dp, 32.3568410_dp, & + 34.9546570_dp, -5.7704242_dp, 31.4571066_dp, & + 33.2532356_dp, 1.5268048_dp, 44.0562171_dp, & + 33.7931669_dp, 0.5014632_dp, 43.0597590_dp, & + 36.8205409_dp, 2.6214681_dp, 40.6834006_dp, & + 37.5552706_dp, 1.5649832_dp, 39.7648935_dp, & + 43.2099087_dp, -0.0628456_dp, 47.2593155_dp, & + 29.3940583_dp, -2.3133019_dp, 37.1407883_dp, & + 36.7415708_dp, -0.0838710_dp, 35.2591783_dp, & + 27.9424776_dp, 6.7622961_dp, 34.5648384_dp, & + 37.6812656_dp, 9.4216399_dp, 32.6478643_dp, & + 33.3171290_dp, 2.0951401_dp, 45.8722265_dp, & + 37.9951355_dp, 4.3611431_dp, 26.5571819_dp, & + 32.1824670_dp, 2.6611503_dp, 30.4577248_dp, & + 34.6538012_dp, -3.4374573_dp, 39.5889245_dp, & + 42.2929833_dp, 5.9471069_dp, 32.8460995_dp, & + 32.9604690_dp, 9.9050313_dp, 30.1587306_dp, & + 31.4281886_dp, -5.8338304_dp, 36.6738743_dp, & + 26.0563730_dp, 2.4973869_dp, 35.3486870_dp, & + 32.0334927_dp, 17.3252289_dp, 30.8116013_dp, & + 33.8252182_dp, -2.9520949_dp, 35.0220460_dp, & + 39.4569981_dp, -0.3072759_dp, 38.9347829_dp, & + 29.4846708_dp, 2.8692561_dp, 43.0061868_dp, & + 39.2864184_dp, -7.6206103_dp, 27.6271147_dp, & + 35.8797502_dp, 8.6515870_dp, 37.5221734_dp, & + 30.3582543_dp, -4.7607656_dp, 34.3355645_dp, & + 40.7098956_dp, 5.8331250_dp, 28.7558375_dp, & + 26.7179083_dp, 2.2415138_dp, 32.6577297_dp, & + 35.6589256_dp, -0.9968903_dp, 30.5749530_dp, & + 31.5851602_dp, -1.3121804_dp, 35.9011109_dp, & + 35.5489386_dp, -3.9056138_dp, 26.8214490_dp, & + 29.5656616_dp, 0.4681794_dp, 34.9670711_dp, & + 34.7615128_dp, -0.9569680_dp, 33.4891367_dp, & + 40.4853406_dp, 0.4023620_dp, 31.9425416_dp, & + 29.6728289_dp, 4.0134825_dp, 40.4505780_dp, & + 34.1272286_dp, -5.8796882_dp, 31.8925999_dp, & + 33.1168884_dp, 1.2338084_dp, 43.1127997_dp, & + 37.1996993_dp, 2.5049007_dp, 39.7917126_dp], kind=sp), shape=[3, natoms]) + + ALLOCATE (cell(3, 3)) + cell(1, :) = [9.85_sp, 0.0_sp, 0.0_sp] + cell(2, :) = [0.0_sp, 9.85_sp, 0.0_sp] + cell(3, :) = [0.0_sp, 0.0_sp, 9.85_sp] + + hinv(:, :) = inv_3x3(REAL(cell, kind=dp)) + + ALLOCATE (atom_types(natoms)) atom_types(:64) = 0 ! Hydrogen atom_types(65:) = 1 ! Oxygen @@ -172,7 +175,7 @@ PROGRAM nequip_unittest CALL torch_dict_insert(inputs, "pos", pos) CALL torch_dict_insert(inputs, "edge_index", edge_index) CALL torch_dict_insert(inputs, "edge_cell_shift", edge_cell_shift) - CALL torch_dict_insert(inputs, "cell", REAL(cell, kind=sp)) + CALL torch_dict_insert(inputs, "cell", cell) CALL torch_dict_insert(inputs, "atom_types", atom_types) CALL torch_dict_create(outputs) @@ -191,7 +194,7 @@ PROGRAM nequip_unittest CALL torch_dict_release(inputs) CALL torch_dict_release(outputs) CALL torch_model_release(model) - DEALLOCATE (edge_index, edge_cell_shift, total_energy, atomic_energy, forces) + DEALLOCATE (edge_index, edge_cell_shift, total_energy, atomic_energy, forces, pos, cell, atom_types) WRITE (*, *) "NequIP unittest was successfully :-)" diff --git a/src/torch_api.F b/src/torch_api.F index d3990e0aea..bfc86aef29 100644 --- a/src/torch_api.F +++ b/src/torch_api.F @@ -69,13 +69,14 @@ MODULE torch_api ! ************************************************************************************************** !> \brief Inserts array into Torch dictionary. The passed array has to outlive the dictionary! +!> The source must be an ALLOCATABLE to prevent passing a temporary array. !> \author Ole Schuett ! ************************************************************************************************** SUBROUTINE torch_dict_insert_${typename}$_${ndims}$d(dict, key, source) TYPE(torch_dict_type), INTENT(INOUT) :: dict CHARACTER(len=*), INTENT(IN) :: key #:set arraydims = ", ".join(":" for i in range(ndims)) - ${type_f}$, CONTIGUOUS, DIMENSION(${arraydims}$), INTENT(IN) :: source + ${type_f}$, DIMENSION(${arraydims}$), ALLOCATABLE, INTENT(IN) :: source #if defined(__LIBTORCH) INTEGER(kind=int_8), DIMENSION(${ndims}$) :: sizes_c