
[Bad Case]: Can function calling be done without vllm? #203

Open
cristianohello opened this issue Sep 6, 2024 · 4 comments
Labels
badcase Bad cases

Comments

@cristianohello

Description / 描述

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import time

s = time.time()
path = "D:/MiniCPM3-4B"
device = "cpu"

tokenizer = AutoTokenizer.from_pretrained(path, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(path, torch_dtype=torch.bfloat16, device_map=device, trust_remote_code=True)

messages2 = [
{"role": "user", "content": "推荐5个北京的景点。"},
]

messages = [
{"role": "user", "content": "汽车保险有哪些"},
]
model_inputs = tokenizer.apply_chat_template(messages, return_tensors="pt", add_generation_prompt=True).to(device)

model_outputs = model.generate(
    model_inputs,
    max_new_tokens=1024,
    top_p=0.7,
    temperature=0.7
)

output_token_ids = [
    model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs))
]

responses = tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0]
print(responses)
e = time.time()
print(e-s)

Case Explanation / 案例解释

(Same code as in the description above.)

@cristianohello cristianohello added the badcase Bad cases label Sep 6, 2024
@cristianohello

For example, with the code above, how do I use function calling? I don't want to use vllm.

@Cppowboy

Cppowboy commented Sep 6, 2024

You can refer to the example code: first format your messages into a prompt in the expected format, then run inference with the model. Either transformers or vllm works.

@cristianohello

You can refer to the example code: first format your messages into a prompt in the expected format, then run inference with the model. Either transformers or vllm works.

Could you give the concrete code? I'm a complete beginner.

@Cppowboy

Cppowboy commented Sep 12, 2024

You can use tokenizer.apply_chat_template to handle the prompt templating. For reference:

from transformers import AutoModelForCausalLM, AutoTokenizer
import torch
import json


tools = [
    {
        "type": "function",
        "function": {
            "name": "get_delivery_date",
            "description": "Get the delivery date for a customer's order. Call this whenever you need to know the delivery date, for example when a customer asks 'Where is my package'",
            "parameters": {
                "type": "object",
                "properties": {
                    "order_id": {
                        "type": "string",
                        "description": "The customer's order ID.",
                    },
                },
                "required": ["order_id"],
                "additionalProperties": False,
            },
        },
    }
]

messages = [
    {
        "role": "system",
        "content": "You are a helpful customer support assistant. Use the supplied tools to assist the user.",
    },
    {
        "role": "user",
        "content": "Hi, can you tell me the delivery date for my order? The order id is 1234 and 4321.",
    },
    #{
    #    "content": "",
    #    "tool_calls": [
    #        {
    #            "type": "function",
    #            "function": {
    #                "name": "get_delivery_date",
    #                "arguments": {"order_id": "1234"},
    #            },
    #            "id": "call_b4ab0b4ec4b5442e86f017fe0385e22e",
    #        },
    #        {
    #            "type": "function",
    #            "function": {
    #                "name": "get_delivery_date",
    #                "arguments": {"order_id": "4321"},
    #            },
    #            "id": "call_628965479dd84794bbb72ab9bdda0c39",
    #        },
    #    ],
    #    "role": "assistant",
    #},
    #{
    #    "role": "tool",
    #    "content": '{"delivery_date": "2024-09-05", "order_id": "1234"}',
    #    "tool_call_id": "call_b4ab0b4ec4b5442e86f017fe0385e22e",
    #},
    #{
    #    "role": "tool",
    #    "content": '{"delivery_date": "2024-09-05", "order_id": "4321"}',
    #    "tool_call_id": "call_628965479dd84794bbb72ab9bdda0c39",
    #},
    #{
    #    "content": "Both your orders will be delivered on 2024-09-05.",
    #    "role": "assistant",
    #    "thought": "\nI have the information you need, both orders will be delivered on the same date, 2024-09-05.\n",
    #},
]

tokenizer = AutoTokenizer.from_pretrained(
    "openbmb/MiniCPM3-4B", trust_remote_code=True
)
model = AutoModelForCausalLM.from_pretrained("openbmb/MiniCPM3-4B", torch_dtype=torch.bfloat16, device_map="cuda:0", trust_remote_code=True)
model_inputs = tokenizer.apply_chat_template(
    messages, tools=tools, tokenize=True, add_generation_prompt=True, return_tensors="pt"
).to("cuda:0")

model_outputs = model.generate(
    model_inputs,
    max_new_tokens=1024,
    top_p=0.7,
    temperature=0.7
)

output_token_ids = [
    model_outputs[i][len(model_inputs[i]):] for i in range(len(model_inputs))
]

responses = tokenizer.batch_decode(output_token_ids, skip_special_tokens=True)[0]
print(responses)
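
The decoded response should contain the model's tool call to get_delivery_date, in whatever format MiniCPM3's chat template defines. As a purely hypothetical follow-up (the exact output format may differ, so inspect the raw response first), one way to pull a JSON-style tool call out of the printed text:

import json

def try_parse_tool_call(text):
    # Hypothetical parser: assumes the model emits a JSON object such as
    # {"name": "get_delivery_date", "arguments": {"order_id": "1234"}}.
    # Adapt this to the format the model actually prints.
    start, end = text.find("{"), text.rfind("}")
    if start == -1 or end == -1:
        return None
    try:
        return json.loads(text[start:end + 1])
    except json.JSONDecodeError:
        return None

tool_call = try_parse_tool_call(responses)
if tool_call is not None:
    print("tool call:", tool_call)
else:
    print("plain text reply:", responses)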
