
experimental support to LOMO optimizer #681

Open
wants to merge 10 commits into base: llm
Conversation

zhenqincn

@CLAassistant

CLAassistant commented Aug 17, 2023

CLA assistant check
All committers have signed the CLA.

@yxdyc yxdyc requested review from rayrayraykk, qbc2016 and yxdyc August 17, 2023 07:22
@zhenqincn zhenqincn changed the title Zhenqin/llm experimental support to LOMO optimizer Aug 22, 2023
# SOFTWARE.

import torch
from torch.optim import Optimizer
Collaborator

Since FS has a TF backend, we should use try-except to avoid errors:

try:
    import torch
except ImportError:
    torch = None

Author

Currently, LOMO does not support TF. A try-except statement has been added following this suggestion:

try:
    import torch
except ImportError:
    torch = None
    raise ImportError('Currently, LOMO optimizer is only implemented with `pytorch`')

Collaborator

Please don't raise ImportError (or anything else) at module level (do it in __init__ or somewhere similar), since TF-backend users might run into an error when this file is imported.

Author

OK, in the latest commit, this error raising has been moved into __init__ of the corresponding optimizer.

try:
    import torch
except ImportError:
    torch = None

class LOMO(Optimizer):
    """
    an optimizer for LOMOTrainer
    """

    def __init__(self, model, lr=1e-3, clip_grad_norm=None, clip_grad_value=None):
        if torch is None:
            raise ImportError('Currently, LOMO optimizer is only implemented with `pytorch`')
        self.model = model

@@ -160,6 +161,9 @@ def get_trainer(model=None,
dict_path = "federatedscope.nlp.hetero_tasks.trainer"
elif config.trainer.type.lower() in ['llmtrainer']:
dict_path = "federatedscope.llm.trainer.trainer"
elif config.trainer.type.lower() in ['lomotrainer']:
print('in type')
Collaborator

Use logger instead of print.
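
A minimal sketch of the suggested change, assuming the standard logging module is used here as it is elsewhere in FederatedScope (the function name below is hypothetical, not from the PR):

import logging

logger = logging.getLogger(__name__)


def dispatch_example(trainer_type):
    # Illustrative stand-in for the branch in get_trainer():
    # prefer the module-level logger over a bare print for debug output.
    if trainer_type.lower() in ['lomotrainer']:
        logger.debug('Selected LOMOTrainer')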

Author

This was an oversight during development; it has been removed in a new commit.


class LOMOTrainer(LLMTrainer):
    def _hook_on_epoch_start(self, ctx):
        if not isinstance(ctx.optimizer, LOMO):
Collaborator

this check should be in _hook_on_fit_start_init

Author

This check has been moved into _hook_on_fit_start_init:

    def _hook_on_fit_start_init(self, ctx):
        ret = super()._hook_on_fit_start_init(ctx)
        if not isinstance(ctx.optimizer, LOMO):
            raise AttributeError('"lomo" must be set as the type of '
                                 '`train.optimizer` if the trainer is LOMOTrainer')
        return ret

        return super()._hook_on_epoch_start(ctx)


    def _hook_on_batch_forward(self, ctx):
Collaborator

Since train and eval use the same hook function, we should add an if-else, as eval only needs one forward pass, right?

Author

Sure, an additional check has been added as follows:

if ctx.cur_mode in [MODE.TRAIN, MODE.FINETUNE] \
    and (
            not ctx.skip_this_batch 
            and ctx.optimizer.clip_grad_norm is not None 
            and ctx.optimizer.clip_grad_norm > 0
        ):

@rayrayraykk
Collaborator

Another minor issue: since FS-LLM is publicly available now, the dev/llm branch is deprecated. You can change the target branch of this PR to llm, thanks!

@zhenqincn zhenqincn changed the base branch from dev/llm to llm September 5, 2023 09:32
@zhenqincn
Author

All suggestions above have been adopted, and a review has been re-requested. Many thanks.

@zhenqincn
Author

In the latest commit, the code has been formatted and the pre-commit checks pass.

try:
    import torch
except ImportError:
    torch = None
from torch.optim import Optimizer
Collaborator

torch.optim.Optimizer should also be inside the try-except.

Author

Many thanks. This has been fixed.
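
For reference, a sketch of one way to guard both imports so the module stays importable under a TF-only backend; the `object` fallback for the base class is an assumption, not necessarily what the final commit does:

try:
    import torch
    from torch.optim import Optimizer
except ImportError:
    torch = None
    # Fallback base class so that `class LOMO(Optimizer)` still parses;
    # the ImportError raised in LOMO.__init__ guards any actual use.
    Optimizer = object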

def func(x):
    with torch.no_grad():
        for n, p in self.model.named_parameters():
            if p.requires_grad and p.grad is not None:
Collaborator

Lines 87~115: there are too many judgement branches and they are too deeply nested. I suggest changing them to a series of self-explanatory boolean variables to increase readability.

Author

Thanks for this suggestion. The mentioned lines have been reformatted and code comments have been added in the latest commit.
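
For illustration, a sketch of the kind of refactor the reviewer asks for, applied to the condition quoted earlier (the helper name is hypothetical and not taken from the commit):

def should_clip_grad_norm(ctx, train_modes):
    """Illustrative helper: name each condition instead of nesting branches."""
    is_train_or_finetune = ctx.cur_mode in train_modes
    batch_not_skipped = not ctx.skip_this_batch
    clip_norm_enabled = (ctx.optimizer.clip_grad_norm is not None
                         and ctx.optimizer.clip_grad_norm > 0)
    return is_train_or_finetune and batch_not_skipped and clip_norm_enabled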


# check if zero3 is enabled
p0 = list(self.model.parameters())[0]
if hasattr(p0, 'ds_tensor'): # zero3 is enabled
Collaborator

Perhaps add some reference to justify this check, e.g., stage3_code.
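
For example, the reference could take the form of a short code comment pointing at DeepSpeed's ZeRO stage-3 partitioning code, which is where the ds_tensor attribute gets attached to parameters (a sketch with a hypothetical helper name; the exact wording in the commit may differ):

def zero3_is_enabled(model):
    """Sketch of the check with an explanatory reference comment."""
    # DeepSpeed ZeRO stage 3 partitions every parameter and attaches a
    # `ds_tensor` attribute to it (see DeepSpeed's ZeRO-3 partitioning code),
    # so its presence on the first parameter signals that ZeRO-3 is enabled.
    first_param = next(iter(model.parameters()))
    return hasattr(first_param, 'ds_tensor')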

Author

Many thanks. This reference has been added in the latest commit.
