Skip to content

Commit

Permalink
Reduce memory consumption (#48)
Browse files Browse the repository at this point in the history
  • Loading branch information
marvinfriede authored Mar 22, 2024
1 parent d001822 commit d0b9f8e
Show file tree
Hide file tree
Showing 3 changed files with 22 additions and 19 deletions.
1 change: 1 addition & 0 deletions setup.cfg
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ project_urls =
packages = find:
install_requires =
numpy
opt-einsum
tad-mctc
torch
python_requires = >=3.8
Expand Down
29 changes: 20 additions & 9 deletions src/tad_dftd3/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
"""
import torch
from tad_mctc.batch import real_atoms
from tad_mctc.math import einsum

from .reference import Reference
from .typing import Any, Tensor, WeightingFunction
Expand All @@ -54,24 +55,34 @@ def atomic_c6(numbers: Tensor, weights: Tensor, reference: Reference) -> Tensor:
Parameters
----------
numbers : Tensor
The atomic numbers of the atoms in the system.
The atomic numbers of the atoms in the system of shape `(..., nat)`.
weights : Tensor
Weights of all reference systems.
Weights of all reference systems of shape `(..., nat, 7)`.
reference : Reference
Reference systems for D3 model.
Reference systems for D3 model. Contains the reference C6 coefficients
of shape `(..., nelements, nelements, 7, 7)`.
Returns
-------
Tensor
Atomic dispersion coefficients.
Atomic dispersion coefficients of shape `(..., nat, nat)`.
"""

c6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)]
gw = torch.mul(
weights.unsqueeze(-1).unsqueeze(-3), weights.unsqueeze(-2).unsqueeze(-4)
# (..., nel, nel, 7, 7) -> (..., nat, nat, 7, 7)
rc6 = reference.c6[numbers.unsqueeze(-1), numbers.unsqueeze(-2)]

# The default einsum path is fastest if the large tensors comes first.
# (..., n1, n2, r1, r2) * (..., n1, r1) * (..., n2, r2) -> (..., n1, n2)
return einsum(
"...ijab,...ia,...jb->...ij",
*(rc6, weights, weights),
optimize=[(0, 1), (0, 1)],
)

return torch.sum(torch.sum(torch.mul(gw, c6), dim=-1), dim=-1)
# NOTE: This old version creates large intermediate tensors and builds the
# full matrix before the sum reduction, which requires a lot of memory.
#
# gw = w.unsqueeze(-1).unsqueeze(-3) * w.unsqueeze(-2).unsqueeze(-4)
# c6 = torch.sum(torch.sum(torch.mul(gw, rc6), dim=-1), dim=-1)


def gaussian_weight(dcn: Tensor, factor: float = 4.0) -> Tensor:
Expand Down
11 changes: 1 addition & 10 deletions test/test_model/test_reference.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import pytest
import torch
from tad_mctc.convert import str_to_device
from tad_mctc.typing import MockTensor

from tad_dftd3 import reference
from tad_dftd3.typing import DD, Any, Tensor, TypedDict
Expand Down Expand Up @@ -69,16 +70,6 @@ def test_reference_device(device_str: str, device_str2: str) -> None:


def test_reference_different_devices() -> None:
# Custom Tensor class with overridable device property
class MockTensor(Tensor):
@property
def device(self) -> Any:
return self._device

@device.setter
def device(self, value: Any) -> None:
self._device = value

# Custom mock functions
def mock_load_cn(*_: Any, **__: Any) -> Tensor:
tensor = MockTensor([1, 2, 3])
Expand Down

0 comments on commit d0b9f8e

Please sign in to comment.