Skip to content

Commit

Permalink
Simplified calling dbcsr_init_lib
Browse files Browse the repository at this point in the history
- Rely on POINTER to optionally initialize accdrv_active_device_id
  • Loading branch information
hfp committed Oct 20, 2024
1 parent 3a17d66 commit 634e43d
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 31 deletions.
23 changes: 9 additions & 14 deletions src/f77_interface.F
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,9 @@ SUBROUTINE init_cp2k(init_mpi, ierr)
LOGICAL, INTENT(in) :: init_mpi
INTEGER, INTENT(out) :: ierr

INTEGER :: offload_device_count, unit_nr
INTEGER :: unit_nr, offload_device_count
INTEGER, TARGET :: offload_chosen_device
INTEGER, POINTER :: active_device_id
TYPE(cp_logger_type), POINTER :: logger

IF (.NOT. module_initialized) THEN
Expand Down Expand Up @@ -275,28 +277,21 @@ SUBROUTINE init_cp2k(init_mpi, ierr)
! *** init the bibliography ***
CALL add_all_references()

NULLIFY (active_device_id)
offload_device_count = offload_get_device_count()

! Select active offload device when available.
IF (offload_device_count > 0) THEN
CALL offload_set_chosen_device(MOD(default_para_env%mepos, offload_device_count))
offload_chosen_device = MOD(default_para_env%mepos, offload_device_count)
CALL offload_set_chosen_device(offload_chosen_device)
active_device_id => offload_chosen_device
END IF

! Initialize the DBCSR configuration
! Attach the time handler hooks to DBCSR
#if defined __DBCSR_ACC
IF (offload_device_count > 0) THEN
CALL dbcsr_init_lib(default_para_env%get_handle(), timeset_hook, timestop_hook, &
cp_abort_hook, cp_warn_hook, io_unit=unit_nr, &
accdrv_active_device_id=offload_get_chosen_device())
ELSE
CALL dbcsr_init_lib(default_para_env%get_handle(), timeset_hook, timestop_hook, &
cp_abort_hook, cp_warn_hook, io_unit=unit_nr)
END IF
#else
CALL dbcsr_init_lib(default_para_env%get_handle(), timeset_hook, timestop_hook, &
cp_abort_hook, cp_warn_hook, io_unit=unit_nr)
#endif
cp_abort_hook, cp_warn_hook, io_unit=unit_nr, &
accdrv_active_device_id=active_device_id)
CALL cp_sirius_init() ! independent of method_name_id == do_sirius
CALL cp_dlaf_initialize()
CALL pw_fpga_init()
Expand Down
26 changes: 9 additions & 17 deletions src/start/cp2k_runs.F
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,8 @@ RECURSIVE SUBROUTINE cp2k_run(input_declaration, input_file_name, output_unit, m
INTEGER :: f_env_handle, grid_backend, ierr, &
iter_level, method_name_id, &
new_env_id, prog_name_id, run_type_id
INTEGER, TARGET :: offload_chosen_device
INTEGER, POINTER :: active_device_id
INTEGER(KIND=int_8) :: m_memory_max_mpi
LOGICAL :: echo_input, grid_apply_cutoff, &
grid_validate, I_was_ionode
Expand All @@ -186,20 +188,18 @@ RECURSIVE SUBROUTINE cp2k_run(input_declaration, input_file_name, output_unit, m
TYPE(global_environment_type), POINTER :: globenv
TYPE(section_vals_type), POINTER :: glob_section, input_file, root_section

NULLIFY (para_env, f_env, dft_control)
NULLIFY (para_env, f_env, dft_control, active_device_id)
ALLOCATE (para_env)
para_env = mpi_comm

#if defined(__DBCSR_ACC)
IF (offload_get_device_count() > 0) THEN
CALL dbcsr_init_lib(mpi_comm%get_handle(), io_unit=output_unit, &
accdrv_active_device_id=offload_get_chosen_device())
ELSE
CALL dbcsr_init_lib(mpi_comm%get_handle(), io_unit=output_unit)
offload_chosen_device = offload_get_chosen_device()
active_device_id => offload_chosen_device
END IF
#else
CALL dbcsr_init_lib(mpi_comm%get_handle(), io_unit=output_unit)
#endif
CALL dbcsr_init_lib(mpi_comm%get_handle(), io_unit=output_unit, &
accdrv_active_device_id=active_device_id)

NULLIFY (globenv, force_env)

Expand Down Expand Up @@ -264,16 +264,8 @@ RECURSIVE SUBROUTINE cp2k_run(input_declaration, input_file_name, output_unit, m
CASE (do_farming) ! TODO: refactor cp2k's startup code
CALL dbcsr_finalize_lib()
CALL farming_run(input_declaration, root_section, para_env, initial_variables)
#if defined(__DBCSR_ACC)
IF (offload_get_device_count() > 0) THEN
CALL dbcsr_init_lib(mpi_comm%get_handle(), io_unit=output_unit, &
accdrv_active_device_id=offload_get_chosen_device())
ELSE
CALL dbcsr_init_lib(mpi_comm%get_handle(), io_unit=output_unit)
END IF
#else
CALL dbcsr_init_lib(mpi_comm%get_handle(), io_unit=output_unit)
#endif
CALL dbcsr_init_lib(mpi_comm%get_handle(), io_unit=output_unit, &
accdrv_active_device_id=active_device_id)
CASE (do_opt_basis)
CALL run_optimize_basis(input_declaration, root_section, para_env)
globenv%run_type_id = none_run
Expand Down

0 comments on commit 634e43d

Please sign in to comment.