release LOMO optimizer
zhenqincn committed Aug 17, 2023
1 parent 0cfb86d commit 287b152
Showing 4 changed files with 94 additions and 78 deletions.
25 changes: 25 additions & 0 deletions LICENSE
@@ -703,3 +703,28 @@ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.


---------------------------------------------------------------------------------
Code in federatedscope/llm/trainer/lomo_trainer.py and federatedscope/contrib/optimizer/lomo.py
is adapted from https://github.com/OpenLMLab/LOMO (MIT License)

Copyright (c) 2023 OpenLMLab

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
74 changes: 34 additions & 40 deletions federatedscope/contrib/optimizer/lomo.py
@@ -1,33 +1,39 @@
import os
import torch
from torch.optim import Optimizer
import torch.distributed as dist
# This implementation is adapted from https://github.com/OpenLMLab/LOMO (MIT License)

from federatedscope.register import register_optimizer
# Copyright (c) 2023 OpenLMLab

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

class LOMO(Optimizer):
"""
A custom optimizer class, LOMO, used for gradient updates in distributed training.
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

This class implements two gradient update functions, :meth:`fuse_update` and :meth:`fuse_update_zero3`, for gradient updates in non-ZeRO and ZeRO modes respectively.
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

:param model: the model to be optimized
:param lr: learning rate, defaults to 1e-3
:param clip_grad_norm: norm threshold for gradient clipping
import torch
from torch.optim import Optimizer

.. note::
from federatedscope.register import register_optimizer

clip_grad_norm must be a positive number

:param clip_grad_value: value threshold for gradient clipping
class LOMO(Optimizer):
"""
an optimizer for LOMOTrainer
"""

def __init__(self, model, lr=1e-3, clip_grad_norm=None, clip_grad_value=None):
self.model = model
self.lr = lr
# self.local_rank = int(os.environ["LOCAL_RANK"])
# self.world_size = dist.get_world_size()
self.local_rank = 0
self.world_size = 1
self.clip_grad_norm = clip_grad_norm
@@ -69,15 +75,12 @@ def __init__(self, model, lr=1e-3, clip_grad_norm=None, clip_grad_value=None):

def fuse_update(self):
"""
Update the gradients of model parameters in non-ZeRO mode.
update model parameters in non-ZeRO mode
:return: func, a closure function that updates the gradients of model parameters
:return: a closure function used for updating model parameters
"""

def func(x):
"""
A closure function that updates the gradients of model parameters.
"""
with torch.no_grad():
for n, p in self.model.named_parameters():
if p.requires_grad and p.grad is not None:
@@ -111,9 +114,9 @@ def func(x):

def fuse_update_zero3(self):
"""
Update the gradients of model parameters in ZeRO mode.
update model parameters in ZeRO mode
:return: func, a closure function that updates the gradients of model parameters.
:return: a closure function used for updating model parameters
"""
def func(x):
with torch.no_grad():
@@ -159,10 +162,10 @@ def func(x):

def fused_backward(self, loss, lr):
"""
Perform one step of backpropagation and update the model's gradients.
update the model parameters based on the gradient by a step of backpropagation
:param loss: the loss value of the model
:param lr: learning rate
:param loss: loss value, scalar
:param lr: learning rate
"""
self.lr = lr
# Users need to call grad_norm themselves and then call backward_step

def grad_norm(self, loss):
"""
Calculate the norm of the gradients.
calculate the norm of gradients
:param loss: the loss value of the model
:param loss: loss value, scalar
"""
self.gather_norm = True
self.grad_norms = []
@@ -241,19 +244,11 @@ def __init__(self,
def loss_scale(self):
return self.cur_scale

# `x` is a torch.Tensor

def _has_inf_or_nan(self, x):
try:
# if x is half, the .float() incurs an additional deep copy, but it's necessary if
# Pytorch's .sum() creates a one-element tensor of the same type as x
# (which is true for some recent version of pytorch).
cpu_sum = float(x.float().sum())
# More efficient version that can be used if .sum() returns a Python scalar
# cpu_sum = float(x.sum())
except RuntimeError as instance:
# We want to check if inst is actually an overflow exception.
# RuntimeError could come from a different error.
# If so, we still want the exception to propagate.
if "value cannot be converted" not in instance.args[0]:
raise
return True
@@ -262,10 +257,9 @@ def _has_inf_or_nan(self, x):
return True
return False

# `overflow` is boolean indicating whether the gradient overflowed

def update_scale(self, overflow):
if overflow:
# self.cur_scale /= self.scale_factor
if self.delayed_shift == 1 or self.cur_hysteresis == 1:
if (self.cur_scale == self.min_scale) and self.raise_error_at_min_scale:
raise Exception(
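
The lomo.py changes above port the LOMO optimizer adapted from OpenLMLab, whose key trick is fusing the parameter update into the backward pass so that full gradients never need to be materialized at once. The snippet below is a minimal sketch of that idea only, not the code in this commit: it assumes a plain SGD update with no ZeRO partitioning, no loss scaling, and no gradient clipping, and it relies on PyTorch's register_post_accumulate_grad_hook (available since PyTorch 2.1); the helper name attach_fused_sgd_hooks is hypothetical.

import torch
import torch.nn as nn


def attach_fused_sgd_hooks(model: nn.Module, lr: float = 1e-3) -> None:
    """Hypothetical helper sketching the fused-update idea (not the LOMO class)."""
    def sgd_step(param: torch.Tensor) -> None:
        # Fires right after this parameter's gradient has been accumulated
        # during backward(): apply the SGD update in place, then drop the
        # gradient so that at most one parameter's gradient is alive at a time.
        if param.grad is not None:
            param.data.add_(param.grad, alpha=-lr)
            param.grad = None

    for p in model.parameters():
        if p.requires_grad:
            # register_post_accumulate_grad_hook requires PyTorch >= 2.1
            # (an assumption of this sketch).
            p.register_post_accumulate_grad_hook(sgd_step)

With these hooks attached, a plain loss.backward() both computes gradients and updates the parameters, with no optimizer.step() and no full gradient buffer; the LOMO class in this commit additionally handles ZeRO-3 partitioned parameters, dynamic loss scaling, and gradient clipping.
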
9 changes: 4 additions & 5 deletions federatedscope/llm/baseline/lomo.yaml
@@ -1,5 +1,5 @@
use_gpu: True
device: 2
device: 0
early_stop:
patience: 0
federate:
@@ -17,13 +17,12 @@ llm:
chat:
max_len: 2000
adapter:
use: True
use: False
args: [ { 'adapter_package': 'peft', 'adapter_method': 'lora', 'r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05 } ]
dataloader:
batch_size: 1
model:
# type: 'decapoda-research/llama-7b-hf@huggingface_llm'
type: 'openlm-research/open_llama_7b@huggingface_llm'
type: 'decapoda-research/llama-7b-hf@huggingface_llm'
train:
local_update_steps: 30
batch_or_epoch: batch
@@ -37,6 +36,6 @@ criterion:
trainer:
type: lomotrainer
eval:
freq: 1
freq: 10
metrics: ['loss']
count_flops: False
64 changes: 31 additions & 33 deletions federatedscope/llm/trainer/lomo_trainer.py
@@ -1,23 +1,50 @@
# This implementation is adapted from https://github.com/OpenLMLab/LOMO (MIT License)

# Copyright (c) 2023 OpenLMLab

# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:

# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.

# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.

import torch
import logging
from federatedscope.register import register_trainer
# from federatedscope.core.trainers import GeneralTorchTrainer
from federatedscope.llm.trainer.trainer import LLMTrainer
from federatedscope.core.trainers.context import CtxVar
from federatedscope.core.trainers.enums import LIFECYCLE
from federatedscope.core.monitors.monitor import Monitor
from federatedscope.llm.model.adapter_builder import AdapterModel
from federatedscope.contrib.optimizer.lomo import LOMO


logger = logging.getLogger(__name__)


class LOMOTrainer(LLMTrainer):
def _hook_on_epoch_start(self, ctx):
if not isinstance(ctx.optimizer, LOMO):
raise AttributeError('"LOMO" must be set as the type of `train.optimizer` if the trainer is LOMOTrainer')
return super()._hook_on_epoch_start(ctx)


def _hook_on_batch_forward(self, ctx):
input_ids = ctx.data_batch['input_ids'].to(ctx.device)
labels = ctx.data_batch['labels'].to(ctx.device)
attention_mask = ctx.data_batch['attention_mask'].to(ctx.device)

# the first forward
outputs = ctx.model(input_ids=input_ids,
labels=labels,
attention_mask=attention_mask)
Expand All @@ -33,26 +60,9 @@ def _hook_on_batch_forward(self, ctx):
else:
ctx.skip_this_batch = CtxVar(False, LIFECYCLE.BATCH)

# if self.training_args.clip_grad_norm is not None and self.training_args.clip_grad_norm > 0:
if not ctx.skip_this_batch and ctx.optimizer.clip_grad_norm is not None and ctx.optimizer.clip_grad_norm > 0:
ctx.optimizer.grad_norm(loss)
# TODO check how to implement this
# if ctx.optimizer.loss_scaler and ctx.optimizer.loss_scaler.has_overflow_serial:
# # print(f"Gradient overflow, skipping step {self.global_step}")
# ctx.optimizer.get_param_coordinator(training=True).reset_step()
# # if self.allow_print:
# # self.wandb.log(
# # {
# # 'train/loss': loss.item(),
# # 'train/learning_rate': self.lr,
# # 'train/global_step': self.global_step,
# # },
# # step=self.global_step
# # )

# else:
# ctx.optimizer.get_param_coordinator(training=True).reset_step()
# the second forward pass
# the second forward
input_ids = ctx.data_batch['input_ids'].to(ctx.device)
labels = ctx.data_batch['labels'].to(ctx.device)
attention_mask = ctx.data_batch['attention_mask'].to(ctx.device)
@@ -70,26 +80,14 @@ def _hook_on_batch_forward(self, ctx):
ctx.batch_size = CtxVar(len(labels), LIFECYCLE.BATCH)

def _hook_on_batch_backward(self, ctx):
# ctx.optimizer.zero_grad()
if ctx.skip_this_batch:
return

# scaled_loss = loss * self.loss_scaler.loss_scale
#
# scaled_loss.backward()
# # update the last one since the hook function will not be called for the last parameter
# self.grad_func(0)
# self.loss_scaler.update_scale(overflow=False)
ctx.optimizer.fused_backward(ctx.loss_task, ctx.optimizer.lr)
# TODO check how to implement this
# ctx.optimizer.get_param_coordinator(training=True).reset_step()
# ctx.loss_task.backward()

if ctx.grad_clip > 0:
torch.nn.utils.clip_grad_norm_(ctx.model.parameters(),
ctx.grad_clip)

# ctx.optimizer.step()
if ctx.scheduler is not None:
ctx.scheduler.step()

Expand Down
