Commit 2b0c9b8

…into 13b

ymcui committed Aug 14, 2023
2 parents 6cfafa9 + 8125a66
Showing 9 changed files with 554 additions and 81 deletions.
50 changes: 48 additions & 2 deletions scripts/inference/gradio_demo.py
@@ -16,6 +16,7 @@
import requests
from typing import Iterable, List
import subprocess
import re

DEFAULT_SYSTEM_PROMPT = """You are a helpful assistant. 你是一个乐于助人的助手。"""

@@ -87,6 +88,17 @@
    default=8000,
    help="Port of vLLM service.")
args = parser.parse_args()

ENABLE_CFG_SAMPLING = True
try:
    from transformers.generation import UnbatchedClassifierFreeGuidanceLogitsProcessor
except ImportError:
    ENABLE_CFG_SAMPLING = False
    print("Install the latest transformers (commit equal or later than d533465) to enable CFG sampling.")
if args.use_vllm is True:
    print("CFG sampling is disabled when using vLLM.")
    ENABLE_CFG_SAMPLING = False

if args.only_cpu is True:
    args.gpus = ""
    if args.load_in_8bit or args.load_in_4bit:
@@ -333,19 +345,23 @@ def get_streaming_response(response: requests.Response) -> Iterable[List[str]]:
def predict(
    history,
    system_prompt,
    negative_prompt,
    max_new_tokens=128,
    top_p=0.9,
    temperature=0.2,
    top_k=40,
    do_sample=True,
    repetition_penalty=1.1,
    guidance_scale=1.0,
    presence_penalty=0.0,
):
    if len(system_prompt) == 0:
        system_prompt = DEFAULT_SYSTEM_PROMPT
    while True:
        print("len(history):", len(history))
        print("history: ", history)
        history[-1][1] = ""
        if len(history)==1:
        if len(history) == 1:
            input = history[0][0]
            prompt = generate_prompt(input,response="", with_system_prompt=True, system_prompt=system_prompt)
        else:
@@ -390,9 +406,18 @@ def predict(
                    yield history

        else:
            negative_text = None
            if len(negative_prompt) != 0:
                negative_text = re.sub(r"<<SYS>>\n(.*)\n<</SYS>>", f"<<SYS>>\n{negative_prompt}\n<</SYS>>", prompt)
            inputs = tokenizer(prompt, return_tensors="pt")
            input_ids = inputs["input_ids"].to(device)

            if negative_text is None:
                negative_prompt_ids = None
                negative_prompt_attention_mask = None
            else:
                negative_inputs = tokenizer(negative_text,return_tensors="pt")
                negative_prompt_ids = negative_inputs["input_ids"].to(device)
                negative_prompt_attention_mask = negative_inputs["attention_mask"].to(device)
            generate_params = {
                'input_ids': input_ids,
                'max_new_tokens': max_new_tokens,
@@ -402,6 +427,10 @@
                'do_sample': do_sample,
                'repetition_penalty': repetition_penalty,
            }
            if ENABLE_CFG_SAMPLING is True:
                generate_params['guidance_scale'] = guidance_scale
                generate_params['negative_prompt_ids'] = negative_prompt_ids
                generate_params['negative_prompt_attention_mask'] = negative_prompt_attention_mask

            def generate_with_callback(callback=None, **kwargs):
                if 'stopping_criteria' in kwargs:
@@ -450,6 +479,13 @@ def generate_with_streaming(**kwargs):
                    placeholder=DEFAULT_SYSTEM_PROMPT,
                    lines=1).style(
                        container=True)
                negative_prompt_input = gr.Textbox(
                    show_label=True,
                    label="反向提示语(仅在对话开始前或清空历史后修改有效,对话过程中修改无效)",
                    placeholder="(可选,默认为空)",
                    lines=1,
                    visible=ENABLE_CFG_SAMPLING).style(
                        container=True)
            with gr.Column(scale=12):
                user_input = gr.Textbox(
                    show_label=True,
@@ -492,6 +528,14 @@ def generate_with_streaming(**kwargs):
label="Repetition Penalty",
interactive=True,
visible=False if args.use_vllm else True)
guidance_scale = gr.Slider(
1.0,
3.0,
value=1.0,
step=0.1,
label="Guidance Scale",
interactive=True,
visible=ENABLE_CFG_SAMPLING)
presence_penalty = gr.Slider(
-2.0,
2.0,
@@ -505,12 +549,14 @@
    predict_params = [
        chatbot,
        system_prompt_input,
        negative_prompt_input,
        max_new_token,
        top_p,
        temperature,
        top_k,
        do_sample,
        repetition_penalty,
        guidance_scale,
        presence_penalty]

    submitBtn.click(
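For readers skimming the gradio_demo.py changes above, here is a condensed, hedged sketch of what the new CFG path does: it gates CFG sampling on whether the installed transformers build exposes `UnbatchedClassifierFreeGuidanceLogitsProcessor`, derives a negative prompt by swapping the text inside the Llama-2 `<<SYS>>` block, and only then attaches the extra `generate()` keyword arguments. The helper name `build_cfg_params` and the standalone packaging are illustrative, not part of the commit.

```python
import re

# Mirror the demo's startup check: CFG is only enabled when transformers ships the processor.
ENABLE_CFG_SAMPLING = True
try:
    from transformers.generation import UnbatchedClassifierFreeGuidanceLogitsProcessor  # noqa: F401
except ImportError:
    ENABLE_CFG_SAMPLING = False


def build_cfg_params(prompt, negative_prompt, tokenizer, device, guidance_scale):
    """Return the extra generate() kwargs for CFG, or {} when CFG is unavailable or unused."""
    if not ENABLE_CFG_SAMPLING or not negative_prompt:
        return {}
    # Same regex as the diff: replace the system prompt inside <<SYS>> ... <</SYS>>
    # with the negative prompt, keeping the rest of the conversation unchanged.
    negative_text = re.sub(
        r"<<SYS>>\n(.*)\n<</SYS>>",
        f"<<SYS>>\n{negative_prompt}\n<</SYS>>",
        prompt,
    )
    negative_inputs = tokenizer(negative_text, return_tensors="pt")
    return {
        "guidance_scale": guidance_scale,
        "negative_prompt_ids": negative_inputs["input_ids"].to(device),
        "negative_prompt_attention_mask": negative_inputs["attention_mask"].to(device),
    }
```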
94 changes: 74 additions & 20 deletions scripts/inference/inference_hf.py
@@ -25,14 +25,25 @@
parser.add_argument('--load_in_4bit', action='store_true', help="Load the LLM in the 4bit mode")
parser.add_argument("--use_vllm", action='store_true', help="Use vLLM as back-end LLM service.")
parser.add_argument('--system_prompt', type=str, default=DEFAULT_SYSTEM_PROMPT, help="The system prompt of the prompt template.")
parser.add_argument('--negative_prompt', type=str, default=None, help="Negative prompt in CFG sampling.")
parser.add_argument('--guidance_scale', type=float, default=1.0, help="The guidance scale for CFG sampling. CFG is enabled by setting `guidance_scale > 1`.")
args = parser.parse_args()

if args.guidance_scale > 1:
    try:
        from transformers.generation import UnbatchedClassifierFreeGuidanceLogitsProcessor
    except ImportError:
        raise ImportError("Please install the latest transformers (commit equal or later than d533465) to enable CFG sampling.")

if args.use_vllm:
    if args.lora_model is not None:
        raise ValueError("vLLM currently does not support LoRA, please merge the LoRA weights to the base model.")
    if args.load_in_8bit or args.load_in_4bit:
        raise ValueError("vLLM currently does not support quantization, please use fp16 (default) or unuse --use_vllm.")
    if args.only_cpu:
        raise ValueError("vLLM requires GPUs with compute capability not less than 7.0. If you want to run only on CPU, please unuse --use_vllm.")
    if args.guidance_scale > 1:
        raise ValueError("guidance_scale > 1, but vLLM does not support CFG sampling. Please unset guidance_scale. ")
if args.load_in_8bit and args.load_in_4bit:
    raise ValueError("Only one quantization method can be chosen for inference. Please check your arguments")
if args.only_cpu is True:
@@ -76,8 +87,7 @@

sample_data = ["为什么要减少污染,保护环境?"]

def generate_prompt(instruction):
    system_prompt = args.system_prompt or DEFAULT_SYSTEM_PROMPT
def generate_prompt(instruction, system_prompt=DEFAULT_SYSTEM_PROMPT):
    return TEMPLATE.format_map({'instruction': instruction,'system_prompt': system_prompt})

if __name__ == '__main__':
@@ -156,22 +166,44 @@ def generate_prompt(instruction):
if len(raw_input_text.strip())==0:
    break
if args.with_prompt:
    input_text = generate_prompt(instruction=raw_input_text)
    input_text = generate_prompt(instruction=raw_input_text, system_prompt=args.system_prompt)
    negative_text = None if args.negative_prompt is None \
        else generate_prompt(instruction=raw_input_text, system_prompt=args.negative_prompt)
else:
    input_text = raw_input_text
    negative_text = args.negative_prompt

if args.use_vllm:
    output = model.generate([input_text], SamplingParams(**generation_config), use_tqdm=False)
    response = output[0].outputs[0].text
else:
    inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ?
    generation_output = model.generate(
        input_ids = inputs["input_ids"].to(device),
        attention_mask = inputs['attention_mask'].to(device),
        eos_token_id=tokenizer.eos_token_id,
        pad_token_id=tokenizer.pad_token_id,
        generation_config = generation_config
    )
    if args.guidance_scale ==1:
        generation_output = model.generate(
            input_ids = inputs["input_ids"].to(device),
            attention_mask = inputs['attention_mask'].to(device),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            generation_config = generation_config
        )
    else: # enable CFG sampling
        if negative_text is None:
            negative_prompt_ids = None
            negative_prompt_attention_mask = None
        else:
            negative_inputs = tokenizer(negative_text,return_tensors="pt")
            negative_prompt_ids = negative_inputs["input_ids"].to(device)
            negative_prompt_attention_mask = negative_inputs["attention_mask"].to(device)
        generation_output = model.generate(
            input_ids = inputs["input_ids"].to(device),
            attention_mask = inputs['attention_mask'].to(device),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            generation_config = generation_config,
            guidance_scale = args.guidance_scale,
            negative_prompt_ids = negative_prompt_ids,
            negative_prompt_attention_mask = negative_prompt_attention_mask
        )
    s = generation_output[0]
    output = tokenizer.decode(s,skip_special_tokens=True)
    if args.with_prompt:
@@ -185,7 +217,7 @@ def generate_prompt(instruction):
results = []
if args.use_vllm:
    if args.with_prompt is True:
        inputs = [generate_prompt(example) for example in examples]
        inputs = [generate_prompt(example, system_prompt=args.system_prompt) for example in examples]
    else:
        inputs = examples
    outputs = model.generate(inputs, SamplingParams(**generation_config))
@@ -201,18 +233,40 @@ def generate_prompt(instruction):

else:
    for index, example in enumerate(examples):
        if args.with_prompt is True:
            input_text = generate_prompt(instruction=example)
        if args.with_prompt:
            input_text = generate_prompt(instruction=example, system_prompt=args.system_prompt)
            negative_text = None if args.negative_prompt is None else \
                generate_prompt(instruction=example, system_prompt=args.negative_prompt)
        else:
            input_text = example
            negative_text = args.negative_prompt
        inputs = tokenizer(input_text,return_tensors="pt") #add_special_tokens=False ?
        generation_output = model.generate(
            input_ids = inputs["input_ids"].to(device),
            attention_mask = inputs['attention_mask'].to(device),
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            generation_config = generation_config
        )
        if args.guidance_scale == 1:
            generation_output = model.generate(
                input_ids = inputs["input_ids"].to(device),
                attention_mask = inputs['attention_mask'].to(device),
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                generation_config = generation_config
            )
        else: # enable CFG sampling
            if negative_text is None:
                negative_prompt_ids = None
                negative_prompt_attention_mask = None
            else:
                negative_inputs = tokenizer(negative_text,return_tensors="pt")
                negative_prompt_ids = negative_inputs["input_ids"].to(device)
                negative_prompt_attention_mask = negative_inputs["attention_mask"].to(device)
            generation_output = model.generate(
                input_ids = inputs["input_ids"].to(device),
                attention_mask = inputs['attention_mask'].to(device),
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.pad_token_id,
                generation_config = generation_config,
                guidance_scale = args.guidance_scale,
                negative_prompt_ids = negative_prompt_ids,
                negative_prompt_attention_mask = negative_prompt_attention_mask
            )
        s = generation_output[0]
        output = tokenizer.decode(s,skip_special_tokens=True)
        if args.with_prompt:
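To see the new inference_hf.py flags in isolation, the sketch below reproduces the CFG branch as a standalone script. The checkpoint path, prompts, and `guidance_scale` value are placeholders; the `generate()` keyword arguments are the ones the commit adds, and CFG only takes effect when `guidance_scale > 1` (which is why the script requires a transformers build that ships the CFG logits processor).

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, GenerationConfig

device = "cuda" if torch.cuda.is_available() else "cpu"
model_path = "/path/to/merged-model"   # placeholder: merged Chinese-Alpaca-2 weights

tokenizer = AutoTokenizer.from_pretrained(model_path)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16).to(device)
generation_config = GenerationConfig(do_sample=True, temperature=0.2, top_p=0.9, max_new_tokens=400)

input_text = "..."     # prompt built with the user's system prompt (--system_prompt)
negative_text = "..."  # same instruction rendered with the negative prompt (--negative_prompt)
guidance_scale = 1.5   # values > 1 enable CFG sampling

inputs = tokenizer(input_text, return_tensors="pt")
negative_inputs = tokenizer(negative_text, return_tensors="pt")

generation_output = model.generate(
    input_ids=inputs["input_ids"].to(device),
    attention_mask=inputs["attention_mask"].to(device),
    eos_token_id=tokenizer.eos_token_id,
    pad_token_id=tokenizer.pad_token_id,
    generation_config=generation_config,
    guidance_scale=guidance_scale,
    negative_prompt_ids=negative_inputs["input_ids"].to(device),
    negative_prompt_attention_mask=negative_inputs["attention_mask"].to(device),
)
print(tokenizer.decode(generation_output[0], skip_special_tokens=True))
```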
12 changes: 7 additions & 5 deletions scripts/openai_server_demo/README.md
@@ -8,7 +8,7 @@

Install dependencies
``` shell
pip install fastapi uvicorn shortuuid
pip install fastapi uvicorn shortuuid sse_starlette
```

Launch the script
@@ -137,7 +137,7 @@ curl http://localhost:19327/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user","message": "给我讲一些有关杭州的故事吧"}
{"role": "user","content": "给我讲一些有关杭州的故事吧"}
],
"repetition_penalty": 1.0
}'
@@ -179,9 +179,9 @@ curl http://localhost:19327/v1/chat/completions \
-H "Content-Type: application/json" \
-d '{
"messages": [
{"role": "user","message": "给我讲一些有关杭州的故事吧"},
{"role": "assistant","message": "好的,请问您对杭州有什么特别的偏好吗?"},
{"role": "user","message": "我比较喜欢和西湖,可以给我讲一下西湖吗"}
{"role": "user","content": "给我讲一些有关杭州的故事吧"},
{"role": "assistant","content": "好的,请问您对杭州有什么特别的偏好吗?"},
{"role": "user","content": "我比较喜欢和西湖,可以给我讲一下西湖吗"}
],
"repetition_penalty": 1.0
}'
@@ -246,6 +246,8 @@ json返回体:

`do_sample`: Enable random sampling. Defaults to true.

`stream`: Stream the response in the OpenAI format. Defaults to false; when set to true, data is streamed back chunk by chunk in the OpenAI format, so the service can be used as a backend for any ChatGPT-based application (a minimal client sketch follows below).
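
As a quick illustration of `stream`, the hedged Python sketch below consumes the streaming endpoint with `requests`. It assumes the service runs on the default `localhost:19327` and that chunks follow the OpenAI SSE convention (`data: {...}` lines terminated by `data: [DONE]`); field names may need adjusting if the server's payload differs.

```python
import json
import requests

# Hypothetical streaming client for the chat completion endpoint described above.
url = "http://localhost:19327/v1/chat/completions"
payload = {
    "messages": [{"role": "user", "content": "给我讲一些有关杭州的故事吧"}],
    "stream": True,
}

with requests.post(url, json=payload, stream=True) as resp:
    for line in resp.iter_lines(decode_unicode=True):
        if not line or not line.startswith("data: "):
            continue
        data = line[len("data: "):]
        if data.strip() == "[DONE]":
            break
        chunk = json.loads(data)
        # Field layout assumed to follow the OpenAI streaming schema; adjust if the server differs.
        delta = chunk["choices"][0].get("delta", {})
        print(delta.get("content", ""), end="", flush=True)
```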

### Text embedding

Text embeddings have many uses, including but not limited to question answering over large documents, summarizing the contents of a book, and finding the stored memories most relevant to the current user input for a large language model.
18 changes: 11 additions & 7 deletions scripts/openai_server_demo/README_vllm.md
@@ -32,7 +32,7 @@ python scripts/openai_server_demo/openai_api_server_vllm.py --model /path/to/bas

`--host {host_name}`: Host name of the deployed service. Defaults to `localhost`

`--prot {port}`: Port of the deployed service. Defaults to `8000`
`--port {port}`: Port of the deployed service. Defaults to `8000`

## API documentation

@@ -118,7 +118,7 @@ json返回体:

`temperature`: Sampling temperature between 0 and 2. Higher values such as 0.8 make the output more random, while lower values such as 0.2 make it more deterministic. The higher the temperature, the more likely random sampling is used for decoding.

`use_beam_search`: Use beam search. Defaults to `False`, i.e. random sampling is enabled.
`use_beam_search`: Use beam search. Defaults to `false`, i.e. random sampling is enabled.

`n`: Number of output sequences to return. Defaults to 1.

@@ -130,6 +130,8 @@

`presence_penalty`: Repetition penalty, ranging from -2 to 2, with a default of 0. Values greater than 0 encourage the model to use new tokens; values below 0 encourage repetition.

`stream`: When set to `true`, the response is returned as a stream. Defaults to `false`.


### Chat completion

@@ -145,7 +147,7 @@ curl http://localhost:8000/v1/chat/completions \
-d '{
"model": "chinese-llama-alpaca-2",
"messages": [
{"role": "user","message": "给我讲一些有关杭州的故事吧"}
{"role": "user","content": "给我讲一些有关杭州的故事吧"}
]
}'
```
@@ -180,9 +182,9 @@ curl http://localhost:8000/v1/chat/completions \
-d '{
"model": "chinese-llama-alpaca-2",
"messages": [
{"role": "user","message": "给我讲一些有关杭州的故事吧"},
{"role": "assistant","message": "好的,请问您对杭州有什么特别的偏好吗?"},
{"role": "user","message": "我比较喜欢和西湖,可以给我讲一下西湖吗"}
{"role": "user","content": "给我讲一些有关杭州的故事吧"},
{"role": "assistant","content": "好的,请问您对杭州有什么特别的偏好吗?"},
{"role": "user","content": "我比较喜欢和西湖,可以给我讲一下西湖吗"}
],
"repetition_penalty": 1.0
}'
@@ -216,7 +218,7 @@ json返回体:

`temperature`: Sampling temperature between 0 and 2. Higher values such as 0.8 make the output more random, while lower values such as 0.2 make it more deterministic. The higher the temperature, the more likely random sampling is used for decoding.

`use_beam_search`: Use beam search. Defaults to `False`, i.e. random sampling is enabled.
`use_beam_search`: Use beam search. Defaults to `false`, i.e. random sampling is enabled.

`n`: Number of output sequences to return. Defaults to 1.

@@ -227,3 +229,5 @@
`top_p`: During random sampling, only the smallest set of tokens whose cumulative probability exceeds top_p is kept as candidates, so lower values make the output less random. For example, with top_p set to 0.6 and the top-5 token probabilities being {0.23, 0.20, 0.18, 0.11, 0.10}, the cumulative probability of the first three tokens is 0.61, so the fourth token is filtered out and only the first three tokens are sampled as candidates (see the short sketch after this parameter list).

`presence_penalty`: Repetition penalty, ranging from -2 to 2, with a default of 0. Values greater than 0 encourage the model to use new tokens; values below 0 encourage repetition.

`stream`: When set to `true`, the response is returned as a stream. Defaults to `false`.
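
Since the `top_p` description above leans on a worked example, here is a tiny self-contained sketch of that nucleus-filtering step (an illustration of the general idea, not vLLM's implementation):

```python
# Top-5 token probabilities from the example above, sorted in descending order.
probs = [0.23, 0.20, 0.18, 0.11, 0.10]
top_p = 0.6

kept, cumulative = [], 0.0
for p in probs:
    kept.append(p)
    cumulative += p
    if cumulative > top_p:  # stop once the cumulative probability exceeds top_p
        break

# kept == [0.23, 0.20, 0.18]: cumulative 0.61 > 0.6, so the 4th and 5th tokens are filtered out
print(kept)
```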

