From 287b15232edbb14f7fc487b43ec03151c4685135 Mon Sep 17 00:00:00 2001 From: zhenqincn_intern Date: Thu, 17 Aug 2023 15:00:24 +0800 Subject: [PATCH] release LOMO optimizer --- LICENSE | 25 ++++++++ federatedscope/contrib/optimizer/lomo.py | 74 ++++++++++------------ federatedscope/llm/baseline/lomo.yaml | 9 ++- federatedscope/llm/trainer/lomo_trainer.py | 64 +++++++++---------- 4 files changed, 94 insertions(+), 78 deletions(-) diff --git a/LICENSE b/LICENSE index b4d15e39c..5236fe4c1 100644 --- a/LICENSE +++ b/LICENSE @@ -703,3 +703,28 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + + +--------------------------------------------------------------------------------- +Code in federatedscope/llm/trainer/lomo_trainer.py and federatedscope/contrib/optimizer/lomo.py +is adapted from https://github.com/OpenLMLab/LOMO (MIT License) + +Copyright (c) 2023 OpenLMLab + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. \ No newline at end of file diff --git a/federatedscope/contrib/optimizer/lomo.py b/federatedscope/contrib/optimizer/lomo.py index 4e09d441f..94895418e 100644 --- a/federatedscope/contrib/optimizer/lomo.py +++ b/federatedscope/contrib/optimizer/lomo.py @@ -1,33 +1,39 @@ -import os -import torch -from torch.optim import Optimizer -import torch.distributed as dist +# This implementation is adapted from https://github.com/OpenLMLab/LOMO (MIT License) -from federatedscope.register import register_optimizer +# Copyright (c) 2023 OpenLMLab +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: -class LOMO(Optimizer): - """ - 一个自定义的优化器类LOMO,用于在分布式训练中的梯度更新。 +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. - 该类实现两个梯度更新函数 :meth:`fuse_update` 和 :meth:`fuse_update_zero3`,分别用于非ZeRO和ZeRO模式下的梯度更新。 +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. - :param model: 待优化的模型 - :param lr: 学习率,默认值为1e-3 - :param clip_grad_norm: 梯度裁剪的范数阈值 +import torch +from torch.optim import Optimizer - .. note:: +from federatedscope.register import register_optimizer - clip_grad_norm须为正数 - :param clip_grad_value: 梯度裁剪的值域阈值 +class LOMO(Optimizer): + """ + an optimizer for LOMOTrainer """ def __init__(self, model, lr=1e-3, clip_grad_norm=None, clip_grad_value=None): self.model = model self.lr = lr - # self.local_rank = int(os.environ["LOCAL_RANK"]) - # self.world_size = dist.get_world_size() self.local_rank = 0 self.world_size = 1 self.clip_grad_norm = clip_grad_norm @@ -69,15 +75,12 @@ def __init__(self, model, lr=1e-3, clip_grad_norm=None, clip_grad_value=None): def fuse_update(self): """ - 在非ZeRO模式下更新模型参数的梯度。 + update model parameters in non-ZeRO mode - :return: func,一个闭包函数,用于更新模型参数的梯度 + :return: a closure function used for updating model parameters """ def func(x): - """ - 闭包函数,用于更新模型参数的梯度。 - """ with torch.no_grad(): for n, p in self.model.named_parameters(): if p.requires_grad and p.grad is not None: @@ -111,9 +114,9 @@ def func(x): def fuse_update_zero3(self): """ - 在ZeRO模式下更新模型参数的梯度。 + update model parameters in ZeRO mode - :return: func,一个闭包函数,用于更新模型参数的梯度。 + :return: a closure function used for updating model parameters """ def func(x): with torch.no_grad(): @@ -159,10 +162,10 @@ def func(x): def fused_backward(self, loss, lr): """ - 执行一步反向传播并更新模型的梯度。 + update the model parameters based on the gradient by a step of backpropagation - :param loss: 模型的loss值 - :param lr: 学习率 + :param loss: loss value, scalar + :param lr: learning rate """ self.lr = lr # Users need call grad_norm themselves and then call backward_step @@ -180,9 +183,9 @@ def fused_backward(self, loss, lr): def grad_norm(self, loss): """ - 计算梯度的范数。 + calculate the norm of gradients - :param loss: 模型的loss值 + :param loss: loss value, scala """ self.gather_norm = True self.grad_norms = [] @@ -241,19 +244,11 @@ def __init__(self, def loss_scale(self): return self.cur_scale - # `x` is a torch.Tensor + def _has_inf_or_nan(self, x): try: - # if x is half, the .float() incurs an additional deep copy, but it's necessary if - # Pytorch's .sum() creates a one-element tensor of the same type as x - # (which is true for some recent version of pytorch). cpu_sum = float(x.float().sum()) - # More efficient version that can be used if .sum() returns a Python scalar - # cpu_sum = float(x.sum()) except RuntimeError as instance: - # We want to check if inst is actually an overflow exception. - # RuntimeError could come from a different error. - # If so, we still want the exception to propagate. if "value cannot be converted" not in instance.args[0]: raise return True @@ -262,10 +257,9 @@ def _has_inf_or_nan(self, x): return True return False - # `overflow` is boolean indicating whether the gradient overflowed + def update_scale(self, overflow): if overflow: - # self.cur_scale /= self.scale_factor if self.delayed_shift == 1 or self.cur_hysteresis == 1: if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale: raise Exception( diff --git a/federatedscope/llm/baseline/lomo.yaml b/federatedscope/llm/baseline/lomo.yaml index 382fd2269..fefe81dd5 100644 --- a/federatedscope/llm/baseline/lomo.yaml +++ b/federatedscope/llm/baseline/lomo.yaml @@ -1,5 +1,5 @@ use_gpu: True -device: 2 +device: 0 early_stop: patience: 0 federate: @@ -17,13 +17,12 @@ llm: chat: max_len: 2000 adapter: - use: True + use: False args: [ { 'adapter_package': 'peft', 'adapter_method': 'lora', 'r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05 } ] dataloader: batch_size: 1 model: - # type: 'decapoda-research/llama-7b-hf@huggingface_llm' - type: 'openlm-research/open_llama_7b@huggingface_llm' + type: 'decapoda-research/llama-7b-hf@huggingface_llm' train: local_update_steps: 30 batch_or_epoch: batch @@ -37,6 +36,6 @@ criterion: trainer: type: lomotrainer eval: - freq: 1 + freq: 10 metrics: ['loss'] count_flops: False \ No newline at end of file diff --git a/federatedscope/llm/trainer/lomo_trainer.py b/federatedscope/llm/trainer/lomo_trainer.py index f561a8c03..da25f7867 100644 --- a/federatedscope/llm/trainer/lomo_trainer.py +++ b/federatedscope/llm/trainer/lomo_trainer.py @@ -1,23 +1,50 @@ +# This implementation is adapted from https://github.com/OpenLMLab/LOMO (MIT License) + +# Copyright (c) 2023 OpenLMLab + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + import torch import logging from federatedscope.register import register_trainer -# from federatedscope.core.trainers import GeneralTorchTrainer from federatedscope.llm.trainer.trainer import LLMTrainer from federatedscope.core.trainers.context import CtxVar from federatedscope.core.trainers.enums import LIFECYCLE -from federatedscope.core.monitors.monitor import Monitor -from federatedscope.llm.model.adapter_builder import AdapterModel +from federatedscope.contrib.optimizer.lomo import LOMO logger = logging.getLogger(__name__) class LOMOTrainer(LLMTrainer): + def _hook_on_epoch_start(self, ctx): + if not isinstance(ctx.optimizer, LOMO): + raise AttributeError('"LOMO" must be set as the type of `train.optimizer` if the trainer is LOMOTrainer') + return super()._hook_on_epoch_start(ctx) + + def _hook_on_batch_forward(self, ctx): input_ids = ctx.data_batch['input_ids'].to(ctx.device) labels = ctx.data_batch['labels'].to(ctx.device) attention_mask = ctx.data_batch['attention_mask'].to(ctx.device) + # the first forward outputs = ctx.model(input_ids=input_ids, labels=labels, attention_mask=attention_mask) @@ -33,26 +60,9 @@ def _hook_on_batch_forward(self, ctx): else: ctx.skip_this_batch = CtxVar(False, LIFECYCLE.BATCH) - # if self.training_args.clip_grad_norm is not None and self.training_args.clip_grad_norm > 0: if not ctx.skip_this_batch and ctx.optimizer.clip_grad_norm is not None and ctx.optimizer.clip_grad_norm > 0: ctx.optimizer.grad_norm(loss) - # TODO check how to implement this - # if ctx.optimizer.loss_scaler and ctx.optimizer.loss_scaler.has_overflow_serial: - # # print(f"Gradient overflow, skipping step {self.global_step}") - # ctx.optimizer.get_param_coordinator(training=True).reset_step() - # # if self.allow_print: - # # self.wandb.log( - # # { - # # 'train/loss': loss.item(), - # # 'train/learning_rate': self.lr, - # # 'train/global_step': self.global_step, - # # }, - # # step=self.global_step - # # ) - - # else: - # ctx.optimizer.get_param_coordinator(training=True).reset_step() - # 第二次forward + # the second forward input_ids = ctx.data_batch['input_ids'].to(ctx.device) labels = ctx.data_batch['labels'].to(ctx.device) attention_mask = ctx.data_batch['attention_mask'].to(ctx.device) @@ -70,26 +80,14 @@ def _hook_on_batch_forward(self, ctx): ctx.batch_size = CtxVar(len(labels), LIFECYCLE.BATCH) def _hook_on_batch_backward(self, ctx): - # ctx.optimizer.zero_grad() if ctx.skip_this_batch: return - - # scaled_loss = loss * self.loss_scaler.loss_scale - # - # scaled_loss.backward() - # # update the last one since the hook function will not be called for the last parameter - # self.grad_func(0) - # self.loss_scaler.update_scale(overflow=False) ctx.optimizer.fused_backward(ctx.loss_task, ctx.optimizer.lr) - # TODO check how to implement this - # ctx.optimizer.get_param_coordinator(training=True).reset_step() - # ctx.loss_task.backward() if ctx.grad_clip > 0: torch.nn.utils.clip_grad_norm_(ctx.model.parameters(), ctx.grad_clip) - # ctx.optimizer.step() if ctx.scheduler is not None: ctx.scheduler.step()