Skip to content

Commit

Permalink
PAO: Do not freeze model to avoid torch memory leaks
Browse files Browse the repository at this point in the history
  • Loading branch information
oschuett committed Oct 26, 2024
1 parent b15be84 commit e1ffbe3
Showing 1 changed file with 11 additions and 4 deletions.
15 changes: 11 additions & 4 deletions src/pao_model.F
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,14 @@ MODULE pao_model
qs_environment_type
USE qs_kind_types, ONLY: get_qs_kind,&
qs_kind_type
USE torch_api, ONLY: &
torch_dict_create, torch_dict_get, torch_dict_insert, torch_dict_release, torch_dict_type, &
torch_model_eval, torch_model_freeze, torch_model_get_attr, torch_model_load
USE torch_api, ONLY: torch_dict_create,&
torch_dict_get,&
torch_dict_insert,&
torch_dict_release,&
torch_dict_type,&
torch_model_eval,&
torch_model_get_attr,&
torch_model_load
USE util, ONLY: sort
#include "./base/base_uses.f90"

Expand Down Expand Up @@ -90,7 +95,9 @@ SUBROUTINE pao_model_load(pao, qs_env, ikind, pao_model_file, model)
CALL torch_model_get_attr(model%torch_model, "feature_kind_names", feature_kind_names)

! Freeze model after all attributes have been read.
CALL torch_model_freeze(model%torch_model)
! TODO Re-enable once the memory leaks of torch::jit::freeze() are fixed.
! https://github.com/pytorch/pytorch/issues/96726
! CALL torch_model_freeze(model%torch_model)

! For each feature kind name lookup its corresponding atomic kind number.
ALLOCATE (model%feature_kinds(SIZE(feature_kind_names)))
Expand Down

0 comments on commit e1ffbe3

Please sign in to comment.