From cba05f3fb406e21217fd9514d9aacb13dbda4e00 Mon Sep 17 00:00:00 2001 From: shibing624 Date: Thu, 2 Nov 2023 14:35:13 +0800 Subject: [PATCH] update new version 1.1.2 --- README.md | 4 +- examples/gpt/inference_demo.py | 50 +++++-- examples/gpt/training_chatglm_demo.py | 9 +- examples/gpt/training_llama_demo.py | 9 +- .../seq2seq/training_seq2seq_model_demo.py | 5 +- tests/test_benchmark.py | 6 +- tests/test_chatglm.py | 29 +--- tests/test_chatglm_training.py | 7 +- tests/test_dataset.py | 4 +- tests/test_llama.py | 73 +--------- textgen/__init__.py | 2 +- textgen/config/model_args.py | 1 + textgen/gpt/gpt_model.py | 135 ++++++++++-------- 13 files changed, 146 insertions(+), 188 deletions(-) diff --git a/README.md b/README.md index 8847f92..8b2dfd9 100644 --- a/README.md +++ b/README.md @@ -23,6 +23,8 @@ ## 🔥 News +[2023/11/02] v1.1.2版本: GPT模型支持了[NEFTune](https://github.com/neelsjain/NEFTune)给embedding加噪SFT训练方法,SFT中使用 `--neft_alpha` 参数启用 NEFTune,例如 `--neft_alpha 5`。详见[Release-v1.1.2](https://github.com/shibing624/textgen/releases/tag/1.1.2) + [2023/09/05] v1.1.1版本: 支持多卡推理,推理速度加倍,调库textgen做batch推理,多卡推理更方便、快速。详见[Release-v1.1.1](https://github.com/shibing624/textgen/releases/tag/1.1.1) [2023/08/23] v1.1.0版本: 发布基于ShareGPT4数据集微调的中英文Vicuna-13B模型[shibing624/vicuna-baichuan-13b-chat](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat),和对应的LoRA模型[shibing624/vicuna-baichuan-13b-chat-lora](https://huggingface.co/shibing624/vicuna-baichuan-13b-chat-lora),支持多轮对话,评测效果有提升,详见[Release-v1.1.0](https://github.com/shibing624/textgen/releases/tag/1.1.0) @@ -36,7 +38,7 @@ ## 😊 Feature -- [GPT](textgen/gpt):本项目基于PyTorch实现了ChatGLM-6B/Baichuan/LLaMA2/BLOOM等GPT模型LoRA微调训练和预测,可以用于对话生成任务和领域微调训练 +- [GPT](textgen/gpt):本项目基于PyTorch实现了 ChatGLM-6B 1,2,3 / Baichuan 1,2 / LLaMA 1,2 / BLOOM / Mistral / QWen 等GPT模型LoRA微调训练和预测,可以用于对话生成任务和领域微调训练 - [UDA/EDA](textgen/augment/word_level_augment.py):本项目实现了UDA(非核心词替换)、EDA和Back Translation(回译)算法,基于TF-IDF将句子中部分不重要词替换为同义词,随机词插入、删除、替换等方法,产生新的文本,实现了文本扩增 - [Seq2Seq](textgen/seq2seq):本项目基于PyTorch实现了Seq2Seq、ConvSeq2Seq、BART模型的训练和预测,可以用于文本翻译、对话生成、摘要生成等文本生成任务 - [T5](textgen/t5):本项目基于PyTorch实现了T5和CopyT5模型训练和预测,可以用于文本翻译、对话生成、对联生成、文案撰写等文本生成任务 diff --git a/examples/gpt/inference_demo.py b/examples/gpt/inference_demo.py index f580e5e..aedba1f 100644 --- a/examples/gpt/inference_demo.py +++ b/examples/gpt/inference_demo.py @@ -22,11 +22,9 @@ def main(): parser.add_argument('--prompt_template_name', default="vicuna", type=str, help="Prompt template name, eg: alpaca, vicuna, baichuan-chat, chatglm2 etc.") parser.add_argument('--interactive', action='store_true', help="run in the instruction mode") - parser.add_argument('--single_round', action='store_true', - help="Whether to generate single round dialogue, default is multi-round dialogue") parser.add_argument('--data_file', default=None, type=str, help="A file that contains instructions (one instruction per line)") - parser.add_argument('--predictions_file', default='./predictions_result.jsonl', type=str) + parser.add_argument('--output_file', default='./predictions_result.jsonl', type=str) parser.add_argument('--batch_size', default=8, type=int, help='Batch size') args = parser.parse_args() print(args) @@ -49,19 +47,41 @@ def main(): for example in examples[:10]: print(example) if args.interactive: - print(f"Start inference with interactive mode. enable multi round: {not args.single_round}") + print(f"Start inference with interactive mode.") history = [] while True: - raw_input_text = input("Input:") - if len(raw_input_text.strip()) == 0: + try: + query = input("Input:") + except UnicodeDecodeError: + print("Detected decoding error at the inputs, please try again.") + continue + except Exception: + raise + if query == "": + print("Please input text, try again.") + continue + if query.strip() == "clear": + history = [] + print("history cleared.") + continue + if query.strip() == 'exit': break - if args.single_round: - response = model.predict([raw_input_text], prompt_template_name=args.prompt_template_name)[0] - else: - response, history = model.chat( - raw_input_text, history=history, prompt_template_name=args.prompt_template_name) - print("Response: ", response) - print("\n") + print("Response:", end='', flush=True) + try: + response = "" + for new_token in model.chat( + query, + history=history, + prompt_template_name=args.prompt_template_name, + stream=True + ): + print(new_token, end='', flush=True) + response += new_token + history = history + [[query, response]] + except KeyboardInterrupt: + print("KeyboardInterrupt detected, stop.") + continue + print() else: print("Start inference.") results = [] @@ -75,11 +95,11 @@ def main(): print(f"Input: {example}\n") print(f"Output: {response}\n") results.append({"Input": example, "Output": response}) - with open(args.predictions_file, 'w', encoding='utf-8') as f: + with open(args.output_file, 'w', encoding='utf-8') as f: for entry in results: json.dump(entry, f, ensure_ascii=False) f.write('\n') - print(f'save to {args.predictions_file}, size: {len(results)}') + print(f'save to {args.output_file}, size: {len(results)}') if __name__ == '__main__': diff --git a/examples/gpt/training_chatglm_demo.py b/examples/gpt/training_chatglm_demo.py index b09f3a8..2719824 100644 --- a/examples/gpt/training_chatglm_demo.py +++ b/examples/gpt/training_chatglm_demo.py @@ -67,10 +67,13 @@ def main(): print(response) # Chat model with multi turns conversation - response, history = model.chat('请问1加2等于多少?') + history = [] + query = "简单介绍下北京" + response = model.chat(query, history=history) + print(response) + history.append([query, response]) + response = model.chat('继续', history=history) print(response) - response, history = model.chat('两数相乘呢?', history=history) - print(response, history) if __name__ == '__main__': diff --git a/examples/gpt/training_llama_demo.py b/examples/gpt/training_llama_demo.py index 10e5e27..81bb437 100644 --- a/examples/gpt/training_llama_demo.py +++ b/examples/gpt/training_llama_demo.py @@ -67,10 +67,13 @@ def main(): print(response) # Chat model with multi turns conversation - response, history = model.chat('请问1加2等于多少?') + history = [] + query = "简单介绍下北京" + response = model.chat(query, history=history) + print(response) + history.append([query, response]) + response = model.chat('继续', history=history) print(response) - response, history = model.chat('两数相乘呢?', history=history) - print(response, history) if __name__ == '__main__': diff --git a/examples/seq2seq/training_seq2seq_model_demo.py b/examples/seq2seq/training_seq2seq_model_demo.py index 3eff49d..c7e6283 100644 --- a/examples/seq2seq/training_seq2seq_model_demo.py +++ b/examples/seq2seq/training_seq2seq_model_demo.py @@ -4,11 +4,10 @@ @description: """ import argparse -import pandas as pd -from loguru import logger -import os import sys +from loguru import logger + sys.path.append('../..') from textgen.seq2seq import Seq2SeqModel diff --git a/tests/test_benchmark.py b/tests/test_benchmark.py index 1116deb..5ecc4df 100644 --- a/tests/test_benchmark.py +++ b/tests/test_benchmark.py @@ -10,7 +10,7 @@ import pandas as pd sys.path.append('..') -from textgen import GptModel, ChatGlmModel +from textgen import GptModel pwd_path = os.path.abspath(os.path.dirname(__file__)) @@ -76,7 +76,7 @@ def test_llama_13b_alpaca_plus(): def test_chatglm_6b(): - m = ChatGlmModel('chatglm', "THUDM/chatglm-6b", peft_name=None, args={'use_peft': False}) + m = GptModel('chatglm', "THUDM/chatglm-6b", peft_name=None, args={'use_peft': False}) predict_sentences = [get_chatglm_prompt(s) for s in sentences] res = m.predict(predict_sentences) for s, i in zip(sentences, res): @@ -91,7 +91,7 @@ def test_chatglm_6b(): def test_chatglm_6b_lora(): - m = ChatGlmModel('chatglm', "THUDM/chatglm-6b", peft_name='shibing624/chatglm-6b-belle-zh-lora', + m = GptModel('chatglm', "THUDM/chatglm-6b", peft_name='shibing624/chatglm-6b-belle-zh-lora', args={'use_peft': True}, ) predict_sentences = [get_chatglm_prompt(s) for s in sentences] res = m.predict(predict_sentences) diff --git a/tests/test_chatglm.py b/tests/test_chatglm.py index ad3bfdc..7b0563f 100644 --- a/tests/test_chatglm.py +++ b/tests/test_chatglm.py @@ -5,15 +5,14 @@ """ import sys -import pytest sys.path.append('..') -from textgen import ChatGlmArgs, ChatGlmModel +from textgen import GptModel def test_csc(): from pycorrector.utils import eval - model = ChatGlmModel( + model = GptModel( 'chatglm', "THUDM/chatglm-6b", peft_name="shibing624/chatglm-6b-csc-zh-lora", args={'use_peft': True, 'eval_batch_size': 8, "max_length": 128} ) @@ -35,27 +34,11 @@ def batch_correct(sentences): def test_origin(): - m = ChatGlmModel('chatglm', "THUDM/chatglm-6b", args={'use_peft': False}) - response, history = m.chat("你好", history=[]) + m = GptModel('chatglm', "THUDM/chatglm-6b", args={'use_peft': False}) + response = m.chat("你好", history=[]) print(response) assert len(response) > 0 - response, history = m.chat("晚上睡不着应该怎么办", history=history) - print(response) - assert len(response) > 0 - - -def test_origin_int4(): - m = ChatGlmModel('chatglm', "THUDM/chatglm-6b-int4", args={'use_peft': False, "quantization_bit": None}, - cuda_device=0) - response, history = m.chat("你好", history=[], max_length=20) - print(response) - assert len(response) > 0 - - -def test_origin_int4_cpu(): - m = ChatGlmModel('chatglm', "THUDM/chatglm-6b-int4", use_cuda=False, - args={'use_peft': False, "quantization_bit": None}, - cuda_device=0) - response, history = m.chat("你好", history=[], max_length=20) + history = ["你好", response] + response = m.chat("晚上睡不着应该怎么办", history=history) print(response) assert len(response) > 0 diff --git a/tests/test_chatglm_training.py b/tests/test_chatglm_training.py index 72fad22..071c2ea 100644 --- a/tests/test_chatglm_training.py +++ b/tests/test_chatglm_training.py @@ -4,12 +4,13 @@ @description: """ import sys +import os import pytest from torch.utils.data import Dataset from datasets import load_dataset, load_from_disk sys.path.append('..') -from textgen import ChatGlmModel +from textgen import GptModel def preprocess_batch_for_hf_dataset(example, tokenizer, args): @@ -51,7 +52,7 @@ def __getitem__(self, index): def test_train_name(): - model = ChatGlmModel( + model = GptModel( "chatglm", "THUDM/chatglm-6b", args={ "dataset_class": MyDataset, @@ -74,7 +75,7 @@ def test_train_name(): def test_second_predict(): - model = ChatGlmModel("chatglm", "THUDM/chatglm-6b", + model = GptModel("chatglm", "THUDM/chatglm-6b", args={"use_peft": True}, peft_name='tmp_outputs') # load model from peft_name is equal to load model from output_dir sents = ['我要开一家美妆店,帮我起一个店铺名\n答:'] diff --git a/tests/test_dataset.py b/tests/test_dataset.py index 15887a3..2d75298 100644 --- a/tests/test_dataset.py +++ b/tests/test_dataset.py @@ -11,7 +11,7 @@ from transformers import AutoTokenizer sys.path.append('..') -from textgen.llama.llama_utils import LlamaPretrainingDataset +from textgen import GptSupervisedDataset from textgen import GptArgs @@ -29,7 +29,7 @@ def test_data(): train_data = load_data('../examples/data/pt.txt') train_df = pd.DataFrame(train_data, columns=["text"]) eval_df = train_df[:10] - ds = LlamaPretrainingDataset( + ds = GptSupervisedDataset( tokenizer, args, train_df, diff --git a/tests/test_llama.py b/tests/test_llama.py index c9ff2a2..dd5cd27 100644 --- a/tests/test_llama.py +++ b/tests/test_llama.py @@ -38,77 +38,6 @@ def test_origin_7b(): r = m.predict([predict_sentence]) print(r) assert len(r) > 0 - response, history = m.chat("你好", history=[]) + response = m.chat("你好", history=[]) print(response) assert len(response) > 0 - response, history = m.chat("晚上睡不着应该怎么办", history=history) - print(response) - assert len(response) > 0 - - predict_sentences = [generate_prompt(s) for s in sents] - res = m.predict(predict_sentences) - for s, i in zip(sents, res): - print(s, i) - print() - - -def test_lora_7b(): - m = GptModel('llama', "decapoda-research/llama-7b-hf", peft_name='ziqingyang/chinese-alpaca-lora-7b', - args={'use_peft': True}, ) - predict_sentence = generate_prompt("失眠怎么办?") - r = m.predict([predict_sentence]) - print(r) - assert len(r) > 0 - response, history = m.chat("你好", history=[]) - print(response) - assert len(response) > 0 - response, history = m.chat("晚上睡不着应该怎么办", history=history) - print(response) - assert len(response) > 0 - - predict_sentences = [generate_prompt(s) for s in sents] - res = m.predict(predict_sentences) - for s, i in zip(sents, res): - print(s, i) - print() - - -def test_origin_13b(): - m = GptModel('llama', "decapoda-research/llama-13b-hf", args={'use_peft': False}) - predict_sentence = generate_prompt("失眠怎么办?") - r = m.predict([predict_sentence]) - print(r) - assert len(r) > 0 - response, history = m.chat("你好", history=[]) - print(response) - assert len(response) > 0 - response, history = m.chat("晚上睡不着应该怎么办", history=history) - print(response) - assert len(response) > 0 - - predict_sentences = [generate_prompt(s) for s in sents] - res = m.predict(predict_sentences) - for s, i in zip(sents, res): - print(s, i) - print() - - -def test_lora_13b(): - m = GptModel('llama', "decapoda-research/llama-13b-hf", peft_name='shibing624/llama-13b-belle-zh-lora', - args={'use_peft': True}, ) - predict_sentence = generate_prompt("失眠怎么办?") - r = m.predict([predict_sentence]) - print(r) - assert len(r) > 0 - response, history = m.chat("你好", history=[]) - print(response) - assert len(response) > 0 - response, history = m.chat("晚上睡不着应该怎么办", history=history) - print(response) - assert len(response) > 0 - - predict_sentences = [generate_prompt(s) for s in sents] - res = m.predict(predict_sentences) - for s, i in zip(sents, res): - print(s, i) - print() diff --git a/textgen/__init__.py b/textgen/__init__.py index cf9c85b..caa7cc4 100644 --- a/textgen/__init__.py +++ b/textgen/__init__.py @@ -4,7 +4,7 @@ @description: """ -__version__ = '1.1.1' +__version__ = '1.1.2' from textgen.augment.text_augment import TextAugment diff --git a/textgen/config/model_args.py b/textgen/config/model_args.py index ca7f614..477dab8 100644 --- a/textgen/config/model_args.py +++ b/textgen/config/model_args.py @@ -392,3 +392,4 @@ class GptArgs(ModelArgs): qlora: bool = False preprocessing_num_workers: int = 4 prompt_template_name: str = "vicuna" + neft_alpha: int = 0 # 5 diff --git a/textgen/gpt/gpt_model.py b/textgen/gpt/gpt_model.py index a4d8402..fa84306 100644 --- a/textgen/gpt/gpt_model.py +++ b/textgen/gpt/gpt_model.py @@ -7,8 +7,8 @@ import os import random from threading import Thread -from typing import List, Tuple, Optional - +from typing import List, Tuple, Optional, Union +from types import MethodType import numpy as np import torch from loguru import logger @@ -140,23 +140,6 @@ def __init__( torch_dtype=self.torch_dtype, **kwargs ) - self.model = model_class.from_pretrained( - model_name, - config=self.config, - load_in_8bit=self.args.int8, - load_in_4bit=self.args.int4, - torch_dtype=self.torch_dtype, - low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), - device_map=self.device_map, - trust_remote_code=self.args.trust_remote_code, - quantization_config=BitsAndBytesConfig( - load_in_4bit=True, - bnb_4bit_use_double_quant=True, - bnb_4bit_quant_type="nf4", - bnb_4bit_compute_dtype=self.torch_dtype, - ) if self.args.qlora else None, - ) - self.tokenizer_class = tokenizer_class if self.args.tokenizer_name: self.tokenizer = tokenizer_class.from_pretrained( @@ -174,6 +157,38 @@ def __init__( else: self.tokenizer.pad_token = self.tokenizer.eos_token logger.debug("Add pad token: {}".format(self.tokenizer.pad_token)) + # Load model + self.model = model_class.from_pretrained( + model_name, + config=self.config, + load_in_8bit=self.args.int8, + load_in_4bit=self.args.int4, + torch_dtype=self.torch_dtype, + low_cpu_mem_usage=(not is_deepspeed_zero3_enabled()), + device_map=self.device_map, + trust_remote_code=self.args.trust_remote_code, + quantization_config=BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type="nf4", + bnb_4bit_compute_dtype=self.torch_dtype, + ) if self.args.qlora else None, + ) + # Set NEFTune trick for fine-tuning + if self.args.neft_alpha > 0: + input_embed = self.model.get_input_embeddings() + if isinstance(input_embed, torch.nn.Embedding): + def noisy_forward(self: torch.nn.Embedding, x: torch.Tensor) -> torch.Tensor: + embeddings = input_embed.__class__.forward(self, x) + dims = self.num_embeddings * self.embedding_dim + mag_norm = self.args.neft_alpha / (dims ** 0.5) + embeddings += torch.zeros_like(embeddings).uniform_(-mag_norm, mag_norm) + return embeddings + + input_embed.forward = MethodType(noisy_forward, input_embed) + logger.info("Using noisy embedding with alpha={:.2f}".format(self.args.neft_alpha)) + else: + logger.warning("Input embeddings are not normal nn.Embedding, cannot transform into noisy embedding.") self.args.model_type = model_type if model_name is None: @@ -183,18 +198,14 @@ def __init__( self.peft_name = peft_name if self.args.use_peft and self.peft_name: - self.load_peft_model() - - def load_peft_model(self): - """Load peft model""" - self.model = PeftModel.from_pretrained( - self.model, - self.peft_name, - torch_dtype=self.torch_dtype, - device_map=self.device_map, - ) - self.model = self.model.merge_and_unload() - logger.info(f"Loaded peft model from {self.peft_name}") + """Load peft model""" + self.model = PeftModel.from_pretrained( + self.model, + self.peft_name, + torch_dtype=self.torch_dtype, + device_map=self.device_map, + ) + logger.info(f"Loaded peft model from {self.peft_name}") def find_all_linear_names(self, int4=False, int8=False): cls = torch.nn.Linear @@ -585,7 +596,8 @@ def predict( def chat( self, query: str, - history: List[Tuple[str, str]] = None, + history: Union[List, List[Tuple[str, str]]] = None, + stream: bool = False, skip_prompt: bool = True, prompt_template_name: str = "vicuna", max_new_tokens: int = None, @@ -602,37 +614,42 @@ def chat( history = [] history.append([query, '']) prompt = prompt_template.get_prompt(messages=history) - streamer = TextIteratorStreamer( - self.tokenizer, timeout=60.0, skip_prompt=skip_prompt, skip_special_tokens=True) + input_ids = self.tokenizer(prompt).input_ids max_new_tokens = max_new_tokens if max_new_tokens is not None else self.args.max_length max_src_len = context_len - max_new_tokens - 8 input_ids = input_ids[-max_src_len:] - generation_kwargs = dict( - input_ids=torch.as_tensor([input_ids]).to(self.device), - max_new_tokens=max_new_tokens, - do_sample=do_sample if do_sample is not None else self.args.do_sample, - temperature=temperature if temperature is not None else self.args.temperature, - repetition_penalty=repetition_penalty if repetition_penalty is not None else self.args.repetition_penalty, - streamer=streamer, - **kwargs, - ) - thread = Thread(target=self.model.generate, kwargs=generation_kwargs) - thread.start() - stop_str = self.tokenizer.eos_token or prompt_template.stop_str - generated_text = "" - for new_text in streamer: - stop = False - pos = new_text.find(stop_str) - if pos != -1: - new_text = new_text[:pos] - stop = True - generated_text += new_text - if stop: - break - response = generated_text.strip() - history = history + [[query, response]] - return response, history + if stream: + streamer = TextIteratorStreamer( + self.tokenizer, timeout=60.0, skip_prompt=skip_prompt, skip_special_tokens=True + ) + generation_kwargs = dict( + input_ids=torch.as_tensor([input_ids]).to(self.device), + max_new_tokens=max_new_tokens, + do_sample=do_sample if do_sample is not None else self.args.do_sample, + temperature=temperature if temperature is not None else self.args.temperature, + repetition_penalty=repetition_penalty if repetition_penalty is not None else self.args.repetition_penalty, + streamer=streamer, + **kwargs, + ) + thread = Thread(target=self.model.generate, kwargs=generation_kwargs) + thread.start() + yield from streamer + else: + generation_kwargs = dict( + max_new_tokens=max_new_tokens if max_new_tokens is not None else self.args.max_length, + do_sample=do_sample if do_sample is not None else self.args.do_sample, + temperature=temperature if temperature is not None else self.args.temperature, + repetition_penalty=repetition_penalty if repetition_penalty is not None else self.args.repetition_penalty, + ) + outputs = self.model.generate( + input_ids=torch.as_tensor([input_ids]).to(self.device), + **generation_kwargs, + **kwargs, + ) + output_tensor = outputs[0][len(input_ids[0]):] if skip_prompt else outputs[0] + response = self.tokenizer.decode(output_tensor, skip_special_tokens=True) + return response def load_and_cache_examples( self, data, evaluate=False, no_cache=False, verbose=True, silent=False