Commit
fix linting and remove unused
kyonofx committed Nov 18, 2024
1 parent b91157b commit c258300
Showing 3 changed files with 12 additions and 25 deletions.
6 changes: 4 additions & 2 deletions src/fairchem/core/_cli_hydra.py
@@ -18,14 +18,16 @@

from omegaconf import DictConfig

from fairchem.core.components.runner import Runner


from submitit import AutoExecutor
from submitit.helpers import Checkpointable, DelayedSubmission
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from fairchem.core.common import distutils
from fairchem.core.common.flags import flags
from fairchem.core.common.utils import get_timestamp_uid, setup_env_vars, setup_imports
from fairchem.core.components.runner import Runner

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
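The hunk above is the linting half of the commit: the `Runner` import, previously on its own line above the third-party imports, is now grouped with the other first-party `fairchem` imports. A minimal sketch of the grouping that import linters such as isort or ruff enforce (which linter this repository actually uses is an assumption):

# Standard library, third-party, and first-party imports in separate groups,
# each group separated by a blank line and sorted alphabetically.
from __future__ import annotations

import logging

from omegaconf import DictConfig
from submitit import AutoExecutor
from torch.distributed.launcher.api import LaunchConfig, elastic_launch

from fairchem.core.common import distutils
from fairchem.core.components.runner import Runner  # moved into the first-party group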
@@ -121,7 +123,7 @@ def main(
logging.info(
"WARNING: running in local mode, setting dataloading num_workers to 0, see https://github.com/pytorch/examples/issues/526"
)

launch_config = LaunchConfig(
min_nodes=1,
max_nodes=1,
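For context on the hunk above: in local mode the CLI builds a single-node LaunchConfig and hands the work to elastic_launch. A minimal, self-contained sketch of that launch pattern; the parameter values and the worker function here are illustrative assumptions, not fairchem's actual settings:

import os

from torch.distributed.launcher.api import LaunchConfig, elastic_launch


def worker(message: str) -> None:
    # each elastic worker process receives its rank via environment variables
    print(f"rank {os.environ.get('RANK')}: {message}")


if __name__ == "__main__":
    launch_config = LaunchConfig(
        min_nodes=1,
        max_nodes=1,
        nproc_per_node=2,          # number of local worker processes
        rdzv_backend="c10d",       # rendezvous entirely on this machine
        rdzv_endpoint="localhost:0",
        run_id="local-debug",
        max_restarts=0,
    )
    elastic_launch(launch_config, worker)("hello from a local run")
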
29 changes: 7 additions & 22 deletions src/fairchem/core/models/base.py
@@ -233,8 +233,6 @@ def forward(self, data: Batch) -> dict[str, torch.Tensor]:
return


from torch.profiler import record_function

@registry.register_model("hydra")
class HydraModel(nn.Module, GraphModelMixin):
def __init__(
@@ -321,15 +319,6 @@ def __init__(
raise RuntimeError(
"Heads not specified and not found in the starting checkpoint"
)

if hasattr(self.backbone, 'torch_compile') and self.backbone.torch_compile:
logging.info("use torch compile")
torch._dynamo.config.optimize_ddp = False
torch._dynamo.config.assume_static_by_default = False
torch._dynamo.config.automatic_dynamic_shapes = True
self.backbone = torch.compile(self.backbone, dynamic=True)
for k, v in self.output_heads.items():
self.output_heads[k] = torch.compile(v, dynamic=True)

def forward(self, data: Batch):
# lazily get device from input to use with amp; at least one input must be a tensor to figure out its device
@@ -342,20 +331,16 @@ def forward(self, data: Batch):
), f"all inputs must be on the same device, found the following devices {device_from_tensors}"
self.device = device_from_tensors.pop()

if hasattr(self.backbone, 'torch_compile') and self.backbone.torch_compile:
data = dict(data)
emb = self.backbone(data)

# Predict all output properties for all structures in the batch for now.
out = {}
for k in self.output_heads:
with record_function(f"{k} head"):
with torch.autocast(
device_type=self.device, enabled=self.output_heads[k].use_amp
):
if self.pass_through_head_outputs:
out.update(self.output_heads[k](data, emb))
else:
out[k] = self.output_heads[k](data, emb)
with torch.autocast(
device_type=self.device, enabled=self.output_heads[k].use_amp
):
if self.pass_through_head_outputs:
out.update(self.output_heads[k](data, emb))
else:
out[k] = self.output_heads[k](data, emb)

return out
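After this commit the forward pass simply runs the backbone once and then evaluates every output head under its own autocast context; the record_function profiling wrapper and the torch.compile branch shown as removed above are gone. A toy, self-contained sketch of that backbone-to-heads dispatch (module names and shapes are illustrative assumptions, not fairchem APIs):

import torch
from torch import nn


class ToyHead(nn.Module):
    """Stand-in for a fairchem output head (name and shapes are assumptions)."""

    def __init__(self, use_amp: bool = False):
        super().__init__()
        self.use_amp = use_amp
        self.linear = nn.Linear(8, 1)

    def forward(self, data, emb):
        return {"energy": self.linear(emb).squeeze(-1)}


class ToyHydra(nn.Module):
    """Backbone computed once, then each head run under its own autocast."""

    def __init__(self):
        super().__init__()
        self.backbone = nn.Linear(4, 8)
        self.output_heads = nn.ModuleDict({"energy": ToyHead(use_amp=True)})

    def forward(self, data: torch.Tensor) -> dict:
        device = data.device.type
        emb = self.backbone(data)  # shared embedding for all heads
        out = {}
        for name, head in self.output_heads.items():
            # per-head mixed precision, mirroring HydraModel.forward above
            with torch.autocast(device_type=device, enabled=head.use_amp):
                out[name] = head(data, emb)
        return out


print(ToyHydra()(torch.randn(2, 4)))
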
2 changes: 1 addition & 1 deletion src/fairchem/core/trainers/ocp_trainer.py
@@ -372,7 +372,7 @@ def _compute_metrics(self, out, batch, evaluator, metrics=None):
### Retrieve free atoms
# fixed = batch.fixed
fixed = batch.get("fixed", torch.zeros(batch.natoms.sum())).to(batch.natoms.device).long()

mask = fixed == 0

s_idx = 0
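The trainer change touches the free-atom masking in _compute_metrics: when the batch carries no `fixed` tensor, every atom is treated as free. A small sketch of that masking logic, with a plain dict standing in for the real Batch object (an assumption):

import torch

# toy batch: two structures with 3 and 2 atoms; the second atom of each is fixed
batch = {
    "natoms": torch.tensor([3, 2]),
    "fixed": torch.tensor([0, 1, 0, 0, 1]),  # 1 marks a constrained atom
    "forces": torch.randn(5, 3),
}

# fall back to "no atoms fixed" when the key is absent, as the trainer does
fixed = batch.get("fixed", torch.zeros(int(batch["natoms"].sum()))).long()
mask = fixed == 0                # keep free atoms only
free_forces = batch["forces"][mask]
print(free_forces.shape)         # torch.Size([3, 3])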
