-
Notifications
You must be signed in to change notification settings - Fork 42
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Perform gradient clipping on global batch when using gradient accumulation #6
base: main
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
thanks Anna!
praxis/optimizers.py
Outdated
|
||
raw_grad_norm = _compute_grad_norm(raw_grads) | ||
|
||
grads, grad_scale = clip_grads(raw_grads, raw_grad_norm) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
do we need to compute and return grad_scale
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not needed. I no longer return grad_scale
with the latest commit
praxis/optimizers.py
Outdated
grad_scale = jnp.array(1.0) | ||
return grads, grad_scale | ||
|
||
raw_grad_norm = _compute_grad_norm(raw_grads) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
iiuc, if clip_grad_single_norm_to_value
is True, then raw_grad_norm
is not used and we have to compute grad_single_norm separately anyways?
can we move the if-elif-else statement inside out and avoid redundant computation?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definitely. I have addressed this with my latest commit
praxis/optimizers.py
Outdated
|
||
def scale_gradients( | ||
raw_grads: NestedMap, | ||
clip_grad_norm_to_value: Optional[float] = None, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
looking at praxis optimizers, clip_gradient_norm_to_value and clip_gradient_single_norm_to_value default are 0.0 and not None right?
so perhaps the types here should be float and default 0.0 instead of Optional?
clip_grad_single_norm_to_value: Optional[float] = None): | ||
|
||
def clip_grads(grads): | ||
if clip_grad_norm_to_value: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
maybe assert only one of them is true?
Refactoring to allow gradient clipping to be performed on full batch rather than subbatches when using
ShardedStaticAccumulator
. Note that this refactor allows us to maintain support forenable_skip_step_on_gradient_anomalies
and requiresx+1
grad norm calculations per global batch when usingShardedStaticAccumulator
withx
subbatches (once per subbatch to determine whether step should be skipped, once when applying gradient clipping in base optimizer update) and requires one grad clip per global batch.This PR should be taken together with the corresponding Paxml PR.