add support for models with fp16
zhenqincn committed Aug 15, 2023
1 parent f6c46ac commit 0cfb86d
Showing 2 changed files with 25 additions and 19 deletions.
16 changes: 9 additions & 7 deletions federatedscope/contrib/optimizer/lomo.py
@@ -48,13 +48,15 @@ def __init__(self, model, lr=1e-3, clip_grad_norm=None, clip_grad_value=None):
         self.grad_func = self.fuse_update()
         # check if fp16 is enabled
         if p0.dtype == torch.float16:
-            self.loss_scaler = DynamicLossScaler(
-                init_scale=2 ** 16,
-            )  # TODO: add args
-            if self.clip_grad_norm is None:
-                raise ValueError(
-                    "Loss scaling is recommended to be used with grad norm to get better performance."
-                )
+            # TODO temporarily removed for test lomo with llama
+            self.loss_scaler = None
+            # self.loss_scaler = DynamicLossScaler(
+            #     init_scale=2 ** 16,
+            # )  # TODO: add args
+            # if self.clip_grad_norm is None:
+            #     raise ValueError(
+            #         "Loss scaling is recommended to be used with grad norm to get better performance."
+            #     )
         else:
             self.loss_scaler = None

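The scaler that this commit comments out is the usual dynamic loss-scaling scheme for fp16 training. As a point of reference, the sketch below shows the typical mechanics (scale the loss before backward, check gradients for overflow, unscale before the update, grow or shrink the scale); the class name, defaults, and helper methods are illustrative assumptions, not the DynamicLossScaler implementation used in this repository.

# A minimal sketch of dynamic loss scaling, assuming the commented-out
# DynamicLossScaler behaves like typical fp16 loss scalers; everything here
# is illustrative, not the repo's API.
import torch


class SimpleDynamicLossScaler:
    def __init__(self, init_scale=2 ** 16, scale_factor=2.0, scale_window=1000):
        self.scale = float(init_scale)    # current loss scale
        self.scale_factor = scale_factor  # factor used when growing/shrinking
        self.scale_window = scale_window  # consecutive good steps before growing
        self._good_steps = 0

    def scale_loss(self, loss):
        # Scale the loss so small fp16 gradients do not underflow to zero.
        return loss * self.scale

    def has_overflow(self, grads):
        # Check the (still scaled) gradients for inf/nan.
        return any(g is not None and not torch.isfinite(g).all() for g in grads)

    def update(self, overflow):
        if overflow:
            # Overflow: shrink the scale and reset the good-step counter.
            self.scale = max(1.0, self.scale / self.scale_factor)
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps % self.scale_window == 0:
                self.scale *= self.scale_factor


if __name__ == "__main__":
    # Usage sketch (fp32 tensors for portability; real fp16 training would
    # keep model and activations in half precision).
    model = torch.nn.Linear(4, 1)
    scaler = SimpleDynamicLossScaler()
    loss = model(torch.randn(8, 4)).pow(2).mean()
    scaler.scale_loss(loss).backward()
    grads = [p.grad for p in model.parameters()]
    overflow = scaler.has_overflow(grads)
    if not overflow:
        with torch.no_grad():
            for p in model.parameters():
                # Unscale the gradient, then take a plain SGD step.
                p.add_(p.grad / scaler.scale, alpha=-1e-3)
    scaler.update(overflow)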
28 changes: 16 additions & 12 deletions federatedscope/llm/baseline/lomo.yaml
@@ -1,14 +1,12 @@
 use_gpu: True
-device: 3
+device: 2
 early_stop:
-  patience: 10
+  patience: 0
 federate:
   mode: standalone
   client_num: 1
-  total_round_num: 200
-  save_to: "gpt2.ckpt"
-  share_local_model: False
-  online_aggr: False
+  total_round_num: 500
+  save_to: "llama.ckpt"
 data:
   root: data/
   type: 'alpaca@llm'
@@ -17,22 +15,28 @@ data:
 llm:
   tok_len: 1000
   chat:
-    max_len: 1000
+    max_len: 2000
   adapter:
     use: True
     args: [ { 'adapter_package': 'peft', 'adapter_method': 'lora', 'r': 8, 'lora_alpha': 16, 'lora_dropout': 0.05 } ]
 dataloader:
   batch_size: 1
 model:
-  type: 'gpt2@huggingface_llm'
+  # type: 'decapoda-research/llama-7b-hf@huggingface_llm'
+  type: 'openlm-research/open_llama_7b@huggingface_llm'
 train:
-  local_update_steps: 10
+  local_update_steps: 30
   batch_or_epoch: batch
   optimizer:
     type: LOMO
-    lr: 0.0001
+    lr: 0.0003
     weight_decay: 0.0
+    is_enable_half: True
 criterion:
   type: CrossEntropyLoss
 trainer:
-  type: llmtrainer
+  type: lomotrainer
 eval:
   freq: 1
-  metrics: ['loss']
+  metrics: ['loss']
+  count_flops: False
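The new train.optimizer.is_enable_half flag pairs with the fp16 check in lomo.py above (p0.dtype == torch.float16). As a rough illustration of how a 7B model ends up with fp16 parameters in the first place, the sketch below loads the configured checkpoint in half precision via Hugging Face transformers; whether FederatedScope wires is_enable_half to exactly this call is an assumption.

# Illustrative only: load the configured model with fp16 weights so that the
# optimizer's p0.dtype == torch.float16 branch is exercised. The mapping from
# train.optimizer.is_enable_half to this call is an assumption, not the
# repository's actual loading path.
import torch
from transformers import AutoModelForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "openlm-research/open_llama_7b",   # model.type from the config above
    torch_dtype=torch.float16,         # keep parameters in half precision
)
print(next(model.parameters()).dtype)  # torch.float16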
