Experimental support for LOMO optimizer #681
base: llm
Conversation
# SOFTWARE.

import torch
from torch.optim import Optimizer
Since FS has a TF backend, we should use a try-catch to avoid an import error:

try:
    import torch
except ImportError:
    torch = None
Currently, LOMO has no support for TF. A try-catch statement has been added following this suggestion:

try:
    import torch
except ImportError:
    torch = None
    raise ImportError('Currently, LOMO optimizer is only implemented with `pytorch`')
Please don't raise ImportError or anything else at module level (do it in __init__ or somewhere similar), since a TF-backend user might run into an error because this file will be imported.
OK, in the latest commit, this error raising has been moved into __init__ of the corresponding optimizer:

try:
    import torch
except ImportError:
    torch = None


class LOMO(Optimizer):
    """
    An optimizer for LOMOTrainer.
    """
    def __init__(self, model, lr=1e-3, clip_grad_norm=None, clip_grad_value=None):
        if torch is None:
            raise ImportError(
                'Currently, LOMO optimizer is only implemented with `pytorch`')
        self.model = model
@@ -160,6 +161,9 @@ def get_trainer(model=None,
        dict_path = "federatedscope.nlp.hetero_tasks.trainer"
    elif config.trainer.type.lower() in ['llmtrainer']:
        dict_path = "federatedscope.llm.trainer.trainer"
    elif config.trainer.type.lower() in ['lomotrainer']:
        print('in type')
use logger
This was an omission during development; it has been removed in a new commit.
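If a log message were still wanted at that branch, a sketch along these lines (using Python's standard logging module, per the suggestion above; the logger name and message are illustrative) could replace the print call:

import logging

logger = logging.getLogger(__name__)

# in place of print('in type'):
logger.debug("trainer type 'lomotrainer' matched; dispatching to LOMOTrainer")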
class LOMOTrainer(LLMTrainer):
    def _hook_on_epoch_start(self, ctx):
        if not isinstance(ctx.optimizer, LOMO):
this check should be in _hook_on_fit_start_init
This check has been moved to _hook_on_fit_start_init:

def _hook_on_fit_start_init(self, ctx):
    ret = super()._hook_on_fit_start_init(ctx)
    if not isinstance(ctx.optimizer, LOMO):
        raise AttributeError('"lomo" must be set as the type of '
                             '`train.optimizer` if the trainer is LOMOTrainer')
    return ret
        return super()._hook_on_epoch_start(ctx)

    def _hook_on_batch_forward(self, ctx):
Since train and eval will use the same hook function, we should add an if-else, since eval only needs one forward pass, right?
Sure, an additional check has been added as follows:

if ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE] \
        and (not ctx.skip_this_batch
             and ctx.optimizer.clip_grad_norm is not None
             and ctx.optimizer.clip_grad_norm > 0):
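For context, a rough sketch of how such a check can gate the extra pass so that evaluation runs only a single forward. The helper name _run_lomo_step and the inputs and lr arguments are placeholders, the MODE enum is assumed to be the one already used in the trainer, and the grad_norm / fused_backward calls follow the public LOMO reference implementation rather than necessarily this PR's exact code:

def _run_lomo_step(ctx, inputs, lr):
    # Only training/finetuning with gradient clipping enabled needs the extra pass.
    need_clip_pass = (ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE]
                      and not ctx.skip_this_batch
                      and ctx.optimizer.clip_grad_norm is not None
                      and ctx.optimizer.clip_grad_norm > 0)
    if need_clip_pass:
        # First forward/backward only collects gradient norms for clipping.
        loss = ctx.model(**inputs).loss
        ctx.optimizer.grad_norm(loss)
    # Second (or only) forward pass; evaluation stops after this one.
    loss = ctx.model(**inputs).loss
    if ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE]:
        # Fused backward that updates parameters in place, LOMO-style.
        ctx.optimizer.fused_backward(loss, lr)
    return loss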
Another minor issue is that, since FS-LLM is publicly available now, the ...
All suggestions provided above have been adopted. A re-review is requested. Many thanks.
In the latest commit, the code has been formatted with ...
    import torch
except ImportError:
    torch = None
from torch.optim import Optimizer
torch.Optimizer should be in the try-catch.
Many thanks. It has been solved.
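For reference, a minimal sketch of the guarded import after moving Optimizer inside the try block (the object fallback is only illustrative so that the module remains importable under the TF backend; the PR may handle the fallback differently):

try:
    import torch
    from torch.optim import Optimizer
except ImportError:
    torch = None
    Optimizer = object  # illustrative fallback so the module can still be imported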
def func(x):
    with torch.no_grad():
        for n, p in self.model.named_parameters():
            if p.requires_grad and p.grad is not None:
Lines 87 ~ 115: there are too many judgement branches and they are too deeply nested. I suggest changing them into a series of self-explanatory boolean variables to increase readability.
Thanks for this suggestion. The mentioned lines have been reformatted with code comments added in the latest commit.
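As an illustration of the suggested style (a hypothetical fragment; the variable names are invented and the actual condition in the PR may differ):

# Name each condition so the combined check reads like prose.
is_training = ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE]
clipping_enabled = (ctx.optimizer.clip_grad_norm is not None
                    and ctx.optimizer.clip_grad_norm > 0)
should_clip_grads = is_training and not ctx.skip_this_batch and clipping_enabled

if should_clip_grads:
    # run the extra pass that collects gradient norms before the fused update
    ...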
# check if zero3 is enabled
p0 = list(self.model.parameters())[0]
if hasattr(p0, 'ds_tensor'):  # zero3 is enabled
Perhaps add some reference to justify this check, e.g., stage3_code.
Many thanks. This reference has been added in the latest commit.
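For readers following along, the check relies on DeepSpeed ZeRO stage-3 attaching a ds_tensor attribute (the parameter's local shard) to each partitioned parameter. A small sketch of such a detection helper (the function name is illustrative):

def _zero3_enabled(model):
    # Under DeepSpeed ZeRO stage-3, parameters are partitioned across ranks and
    # each parameter object carries a `ds_tensor` attribute holding its shard.
    first_param = next(iter(model.parameters()))
    return hasattr(first_param, 'ds_tensor')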
Add support for the LOMO optimizer for LLM fine-tuning, from "Full Parameter Fine-tuning for Large Language Models with Limited Resources".