Skip to content

Commit

Permalink
update
Browse files Browse the repository at this point in the history
  • Loading branch information
rayg1234 committed Aug 19, 2024
1 parent 01ad88c commit 8daa8ca
Showing 1 changed file with 7 additions and 5 deletions.
12 changes: 7 additions & 5 deletions src/fairchem/core/models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -244,6 +244,7 @@ def __init__(
pass_through_head_outputs: bool = False,
):
super().__init__()
self.device = None
self.otf_graph = otf_graph
# This is required for hydras with models that have multiple outputs per head, since we will deprecate
# the old config system at some point, this will prevent the need to make major modifications to the trainer
Expand Down Expand Up @@ -299,17 +300,18 @@ def __init__(
raise RuntimeError("Heads not specified and not found in the starting checkpoint")

def forward(self, data: Batch):
# get device from input, at least one input must be a tensor to figure out its device
device_from_tensors = {x.device.type for x in data.values() if isinstance(x, torch.Tensor)}
assert len(device_from_tensors) == 1, f"all inputs must be on the same device, found the following devices {device_from_tensors}"
device = device_from_tensors.pop()
# lazily get device from input to use with amp, at least one input must be a tensor to figure out its device
if not self.device:
device_from_tensors = {x.device.type for x in data.values() if isinstance(x, torch.Tensor)}
assert len(device_from_tensors) == 1, f"all inputs must be on the same device, found the following devices {device_from_tensors}"
self.device = device_from_tensors.pop()

emb = self.backbone(data)
# Predict all output properties for all structures in the batch for now.
out = {}
for k in self.output_heads:
with torch.autocast(
device_type=device, enabled=self.output_heads[k].use_amp
device_type=self.device, enabled=self.output_heads[k].use_amp
):
if self.pass_through_head_outputs:
out.update(self.output_heads[k](data, emb))
Expand Down

0 comments on commit 8daa8ca

Please sign in to comment.