Commit 925dca9

add option to clip grad by value
misko committed Aug 14, 2024
1 parent 917056a commit 925dca9
Showing 2 changed files with 15 additions and 8 deletions.
3 changes: 2 additions & 1 deletion

@@ -59,7 +59,8 @@ def multiply(obj, num):
         ]
         self.scheduler = LRScheduler(self.optimizer, self.config["optim"])

-        self.clip_grad_norm = self.config["optim"].get("clip_grad_norm")
+        self.clip_grad_norm = self.config["optim"].get("clip_grad_norm", None)
+        self.clip_grad_value = self.config["optim"].get("clip_grad_value", None)
         self.ema_decay = self.config["optim"].get("ema_decay")
         if self.ema_decay:
             self.ema = ExponentialMovingAverage(
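The new key is read from the same optim section of the trainer config as the existing clip_grad_norm. A minimal sketch of how the two options would be picked up, assuming a plain dict in place of the trainer's full config (the 0.5 threshold is an illustrative value, not one taken from this commit):

# Illustrative optim config; only the key names come from this commit,
# the concrete numbers are made up for the example.
optim_config = {
    "clip_grad_norm": None,   # leave norm-based clipping disabled
    "clip_grad_value": 0.5,   # clamp each gradient element to [-0.5, 0.5]
}

# Mirrors how the trainer reads the options in the hunk above.
clip_grad_norm = optim_config.get("clip_grad_norm", None)
clip_grad_value = optim_config.get("clip_grad_value", None)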
20 changes: 13 additions & 7 deletions src/fairchem/core/trainers/base_trainer.py
@@ -702,6 +702,9 @@ def load_extras(self) -> None:
         self.clip_grad_norm = aii(
             self.config["optim"].get("clip_grad_norm", None), (int, float)
         )
+        self.clip_grad_value = aii(
+            self.config["optim"].get("clip_grad_value", None), (int, float)
+        )
         self.ema_decay = aii(self.config["optim"].get("ema_decay"), float)
         if self.ema_decay:
             self.ema = ExponentialMovingAverage(
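In base_trainer.py the same key is additionally passed through aii together with (int, float), which indicates a type check on the configured value. A rough stand-in for that check (a hypothetical helper of my own, not fairchem's aii):

def check_numeric(value, name):
    # Stand-in for the kind of validation the aii call above appears to perform:
    # allow None (option unset) or a numeric value, reject anything else.
    if value is not None and not isinstance(value, (int, float)):
        raise TypeError(f"{name} must be an int or float, got {type(value).__name__}")
    return value

check_numeric(0.5, "clip_grad_value")    # ok
check_numeric(None, "clip_grad_value")   # ok, clipping simply stays disabled
check_numeric("0.5", "clip_grad_value")  # raises TypeError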
@@ -886,15 +889,18 @@ def _backward(self, loss) -> None:
                             "Please check if all shared parameters are used "
                             "and point to PyTorch parameters."
                         )
-        if self.clip_grad_norm:
+        if self.clip_grad_norm or self.clip_grad_value:
             if self.scaler:
                 self.scaler.unscale_(self.optimizer)
-            grad_norm = torch.nn.utils.clip_grad_norm_(
-                self.model.parameters(),
-                max_norm=self.clip_grad_norm,
-            )
-            if self.logger is not None:
-                self.logger.log({"grad_norm": grad_norm}, step=self.step, split="train")
+            if self.clip_grad_norm:
+                grad_norm = torch.nn.utils.clip_grad_norm_(
+                    self.model.parameters(),
+                    max_norm=self.clip_grad_norm,
+                )
+                if self.logger is not None:
+                    self.logger.log({"grad_norm": grad_norm}, step=self.step, split="train")
+            if self.clip_grad_value:
+                torch.nn.utils.clip_grad_value_(self.model.parameters(), clip_value=self.clip_grad_value)
         if self.scaler:
             self.scaler.step(self.optimizer)
             self.scaler.update()
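To illustrate the behaviour the new branch enables: clip_grad_norm_ rescales the whole gradient vector so its total norm does not exceed max_norm, while clip_grad_value_ clamps each gradient element independently to [-clip_value, clip_value]; with both options set, the hunk above applies norm clipping first and value clipping second. A self-contained sketch of that pattern outside the trainer (the toy model, loss, and thresholds are illustrative only):

import torch

# Toy model and loss, purely for illustration.
model = torch.nn.Linear(4, 1)
loss = model(torch.randn(8, 4)).pow(2).mean()
loss.backward()

clip_grad_norm = 10.0   # stands in for optim["clip_grad_norm"]
clip_grad_value = 0.5   # stands in for the new optim["clip_grad_value"]

if clip_grad_norm:
    # Rescales all gradients together so their combined norm is at most max_norm;
    # the returned value is the norm measured before clipping.
    grad_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=clip_grad_norm)
    print("grad norm before clipping:", float(grad_norm))

if clip_grad_value:
    # Clamps every gradient element independently to [-clip_value, clip_value].
    torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=clip_grad_value)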
