Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Offsite-tuning model generation #676

Open
wants to merge 21 commits into
base: dev/llm
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 48 additions & 18 deletions federatedscope/llm/misc/fschat.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@
import logging
import torch
import transformers
import os
import gc

transformers.logging.set_verbosity(40)

Expand All @@ -18,38 +20,66 @@

class FSChatBot(object):
def __init__(self, config):
model_name, _ = config.model.type.split('@')
self.tokenizer, _ = get_tokenizer(model_name, config.data.root,
config.llm.tok_len)
self.model = get_llm(config)
self.config = config

self.device = f'cuda:{config.device}'
self.add_special_tokens = True

if config.llm.offsite_tuning.use:
from federatedscope.llm.offsite_tuning.utils import \
wrap_offsite_tuning_for_eval
self.model = wrap_offsite_tuning_for_eval(self.model, config)
else:
try:
ckpt = torch.load(config.federate.save_to, map_location='cpu')
self.prefix = ['']
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should prefix be passed by the config?

self.dirname, self.filename = os.path.split(config.federate.save_to)
self.next_model()

def next_model(self):
if hasattr(self, 'model'):
delattr(self, 'model')
gc.collect()

model_name, _ = self.config.model.type.split('@')
self.tokenizer, _ = get_tokenizer(model_name, self.config.data.root,
self.config.llm.tok_len)
self.model = get_llm(self.config)

self.curpfx = None
for pre in self.prefix:
if os.path.exists(os.path.join(self.dirname, pre + self.filename)):
self.curpfx = pre
break

# Load model from the checkpoints
if self.curpfx is not None:
ckpt_path = os.path.join(self.dirname, self.curpfx + self.filename)
if self.config.llm.offsite_tuning.use:
from federatedscope.llm.offsite_tuning.utils import \
wrap_offsite_tuning_for_eval
self.model = wrap_offsite_tuning_for_eval(
self.model, self.config, ckpt_path)
else:
ckpt = torch.load(ckpt_path, map_location='cpu')
if 'model' and 'cur_round' in ckpt:
self.model.load_state_dict(ckpt['model'])
logger.info(
f"Load with the model of Round {ckpt['cur_round']}")
else:
self.model.load_state_dict(ckpt)
except Exception as error:
print(f"{error}, will use raw model.")
logger.info(f'Model loads from the checkpoint {ckpt_path}')

# remove the prefix up to the current one
self.prefix = self.prefix[self.prefix.index(self.curpfx) + 1:]
elif len(self.prefix) > 1:
logger.info("will use raw model.")
else:
raise ValueError('No more model is able to us')

if config.train.is_enable_half:
if self.config.train.is_enable_half:
self.model.half()

self.model = self.model.to(self.device)
self.model = self.model.eval()
if torch.__version__ >= "2" and sys.platform != "win32":
self.model = torch.compile(self.model)

self.max_history_len = config.llm.chat.max_history_len
self.max_len = config.llm.chat.max_len
self.max_history_len = self.config.llm.chat.max_history_len
self.max_len = self.config.llm.chat.max_len
self.history = []

def _build_prompt(self, input_text):
Expand Down Expand Up @@ -123,8 +153,8 @@ def main():
setup_seed(init_cfg.seed)

chat_bot = FSChatBot(init_cfg)
welcome = "Welcome to FSChatBot" \
"`clear` to clear history" \
welcome = "Welcome to FSChatBot, " \
"`clear` to clear history, " \
"`quit` to end chat."
print(welcome)
while True:
Expand Down
17 changes: 6 additions & 11 deletions federatedscope/llm/offsite_tuning/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from federatedscope.core.workers.server import Server

from federatedscope.llm.offsite_tuning.utils import \
generate_emulator_and_adapter, align_student_with_teacher
generate_adap_model, align_student_with_teacher

logger = logging.getLogger(__name__)

Expand All @@ -30,17 +30,8 @@ def __init__(self,
device='cpu',
strategy=None,
**kwargs):
compress_strategy = config.llm.offsite_tuning.strategy
emulator_l = config.llm.offsite_tuning.emu_l
emulator_r = config.llm.offsite_tuning.emu_r
offsite_tuning_kwargs = config.llm.offsite_tuning.kwargs[0]
logger.info('Server: Generating emulator and adapter...')
adap_model = \
generate_emulator_and_adapter(model,
strategy=compress_strategy,
emulator_l=emulator_l,
emulator_r=emulator_r,
**offsite_tuning_kwargs)
adap_model = generate_adap_model(model, config.llm.offsite_tuning)
# Emulator alignment
if config.llm.offsite_tuning.emu_align.use:
adap_model = align_student_with_teacher(raw_model=model,
Expand All @@ -54,7 +45,11 @@ def __init__(self,
os._exit(0)
# No need for this attr
if hasattr(adap_model, 'teacher'):
import gc
import torch
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about move line 48-49 to top:

try:
    import gc
    import torch
except ImportError:
    gc=None
    torch=None

del adap_model.teacher
gc.collect()
torch.cuda.empty_cache()

self.raw_model = model
super(OffsiteTuningServer,
Expand Down
60 changes: 37 additions & 23 deletions federatedscope/llm/offsite_tuning/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def get_layers(adapter_model):
return layers


def set_layers(adapter_model, layers, emu_l=0, emu_r=-1):
def set_layers(adapter_model, layers):
if isinstance(adapter_model.model, OPTForCausalLM):
adapter_model.model.model.decoder.layers = layers
elif isinstance(adapter_model.model, GPT2LMHeadModel):
Expand All @@ -109,12 +109,6 @@ def set_layers(adapter_model, layers, emu_l=0, emu_r=-1):
logger.warning(f'Model {type(adapter_model.model)} not support, '
f'use default setting.')
adapter_model.model.transformer.h = layers
adapter_model.student = layers[emu_l:emu_r]
adapter_model.adapter = layers[:emu_l] + layers[emu_r:]
add_prologue(adapter_model.student[0], None)
add_epilogue(adapter_model.student[-1], None)
adapter_model.student_l = adapter_model.student[0]
adapter_model.student_r = adapter_model.student[-1]
return adapter_model


Expand Down Expand Up @@ -152,13 +146,31 @@ def model_distillation(model, **kwargs):
}


def generate_adap_model(model: AdapterModel, offsite_tuning_cfg):
if offsite_tuning_cfg.strategy in COMP_FUNC_MAPPING.keys():
compress_strategy = offsite_tuning_cfg.strategy
emulator_l = offsite_tuning_cfg.emu_l
emulator_r = offsite_tuning_cfg.emu_r
emu_align = offsite_tuning_cfg.emu_align.use
offsite_tuning_kwargs = offsite_tuning_cfg.kwargs[0]
return generate_emulator_and_adapter(model,
strategy=compress_strategy,
emulator_l=emulator_l,
emulator_r=emulator_r,
emulator_alignment=emu_align,
**offsite_tuning_kwargs)
else:
raise NotImplementedError


def generate_emulator_and_adapter(model: AdapterModel,
strategy='drop_layer',
emulator_l=1,
emulator_l=0,
emulator_r=1000,
emulator_alignment=False,
**kwargs):
layers = get_layers(model)
l, r = max(emulator_l, 1), min(emulator_r, len(layers) - 1)
l, r = max(emulator_l, 0), min(emulator_r, len(layers) - 1)

# Set the to-compress part untrainable
for layer in layers[l:r]:
Expand Down Expand Up @@ -186,7 +198,14 @@ def generate_emulator_and_adapter(model: AdapterModel,

new_model = copy.deepcopy(model)
# Set student model
new_model = set_layers(new_model, emulator_and_adapter, l, r)
new_model = set_layers(new_model, emulator_and_adapter)

if emulator_alignment:
new_model.student = layers
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please merge the latest commits in which bugs are fixed. (layers should be detached from new_model)

add_prologue(new_model.student[0], None)
add_epilogue(new_model.student[-1], None)
new_model.student_l = new_model.student[0]
new_model.student_r = new_model.student[-1]

gc.collect()
torch.cuda.empty_cache()
Expand Down Expand Up @@ -303,20 +322,11 @@ def build_cfg_for_alignment(config):
return adap_model


def wrap_offsite_tuning_for_eval(model, config):
def wrap_offsite_tuning_for_eval(model, config, ckpt_path=None):
logger.info('===============use offsite tuning===============')
# We use offsite-tuning in this experiment
# Use adapter model instead
compress_strategy = config.llm.offsite_tuning.strategy
emulator_l = config.llm.offsite_tuning.emu_l
emulator_r = config.llm.offsite_tuning.emu_r
offsite_tuning_kwargs = config.llm.offsite_tuning.kwargs[0]
adap_model = \
generate_emulator_and_adapter(model,
strategy=compress_strategy,
emulator_l=emulator_l,
emulator_r=emulator_r,
**offsite_tuning_kwargs)
adap_model = generate_adap_model(model, config.llm.offsite_tuning)
# Load kd model if ckpt exits
if config.llm.offsite_tuning.emu_align.use and \
config.llm.offsite_tuning.eval_type == 'emu':
Expand All @@ -333,17 +343,21 @@ def wrap_offsite_tuning_for_eval(model, config):

# Load ckpt for eval
try:
ckpt = torch.load(config.federate.save_to, map_location='cpu')
if ckpt_path is None:
ckpt_path = config.federate.save_to
ckpt = torch.load(ckpt_path, map_location='cpu')
if 'model' and 'cur_round' in ckpt:
adap_model.load_state_dict(ckpt['model'])
logger.info(f"Load with the model of Round {ckpt['cur_round']}")
else:
adap_model.load_state_dict(ckpt)
except Exception as error:
logger.warning(f"{error}, will use raw model.")

if config.llm.offsite_tuning.eval_type == 'emu':
model = adap_model
del model.teacher
if hasattr(model, 'teacher'):
del model.teacher
elif config.llm.offsite_tuning.eval_type == 'full':
# Raw model load adapter from adapter_and_emulator
new_model_state_dict = model.state_dict()
Expand Down