diff --git a/llm/server/.dockerignore b/llm/server/.dockerignore
new file mode 100644
index 000000000000..96dbf3cb734a
--- /dev/null
+++ b/llm/server/.dockerignore
@@ -0,0 +1,11 @@
+README.md
+requirements-dev.txt
+pyproject.toml
+Makefile
+
+dockerfiles/
+docs/
+server/__pycache__
+server/http_server
+server/engine
+server/data
diff --git a/llm/server/README.md b/llm/server/README.md
new file mode 100644
index 000000000000..b521644e5769
--- /dev/null
+++ b/llm/server/README.md
@@ -0,0 +1,39 @@
+
+# LLM Serving Deployment
+
+*This deployment tool is built on NVIDIA's Triton framework and designed for serving large language models in server scenarios. It provides gRPC and HTTP service interfaces as well as streaming token output. The underlying inference engine supports acceleration and optimization strategies such as continuous batching, weight-only INT8, and post-training quantization (PTQ), delivering an easy-to-use and high-performance deployment experience.*
+
+# Quick Start
+
+ Deploy with the prebuilt image. This section uses Meta-Llama-3-8B-Instruct-A8W8C8 as an example; for more models, see [LLaMA](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/llama.md), [Qwen](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/qwen.md), and [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/mixtral.md). For more detailed inference and quantization tutorials, refer to the [LLM Inference Tutorial](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/inference.md):
+
+ ```
+ # Download the model
+ wget https://paddle-qa.bj.bcebos.com/inference_model/Meta-Llama-3-8B-Instruct-A8W8C8.tar
+ mkdir Llama-3-8B-A8W8C8 && tar -xf Meta-Llama-3-8B-Instruct-A8W8C8.tar -C Llama-3-8B-A8W8C8
+
+ # Set the model path (mounted into the container below)
+ export MODEL_PATH=${PWD}/Llama-3-8B-A8W8C8
+
+ docker run --gpus all --shm-size 5G --network=host --privileged --cap-add=SYS_PTRACE \
+ -v ${MODEL_PATH}:/models/ \
+ -dit registry.baidubce.com/paddlepaddle/fastdeploy:llm-serving-cuda123-cudnn9-v1.2 \
+ bash -c 'export USE_CACHE_KV_INT8=1 && cd /opt/output/Serving && bash start_server.sh; exec bash'
+ ```
+
+ Wait for the service to start successfully (the first startup takes about 40s), then test it with the following command:
+
+ ```
+ curl 127.0.0.1:9965/v1/chat/completions \
+ -H 'Content-Type: application/json' \
+ -d '{"text": "hello, llm"}'
+ ```
+
+Note:
+1. Make sure shm-size >= 5G, otherwise the service may fail to start.
+
+For more details on using this deployment tool, see the [Serving Deployment Tutorial](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/server/docs/deploy_usage_tutorial.md).
+
+# License
+
+This project is licensed under the [Apache-2.0 License](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/LICENSE).
diff --git a/llm/server/client/README.md b/llm/server/client/README.md
new file mode 100644
index 000000000000..06b238d3e4d9
--- /dev/null
+++ b/llm/server/client/README.md
@@ -0,0 +1,110 @@
+# Client Usage
+
+## Introduction
+
+The serving client provides a command-line interface and a Python interface for quickly calling LLM model services deployed with the serving backend.
+
+## Installation
+
+Install from source:
+```
+pip install .
+```
+
+## Command-Line Interface
+
+First set the model service URL through an environment variable, then call the model service with the command-line interface.
+
+| Parameter | Description | Required | Default |
+| --- | --- | --- | --- |
+| FASTDEPLOY_MODEL_URL | IP address and port of the deployed model service, in the format `x.x.x.x:xxx`. | Yes | |
+
+```
+export FASTDEPLOY_MODEL_URL="x.x.x.x:xxx"
+
+# Streaming interface
+fdclient stream_generate "Hello!"
+
+# Non-streaming interface
+fdclient generate "Hello, who are you?"
+```
+
+## Python Interface
+
+First set the model service URL (hostname and port) in Python code, then call the model service through the Python interface.
+
+| Parameter | Description | Required | Default |
+| --- | --- | --- | --- |
+| hostname, port | IP address and port of the deployed model service; the hostname is in the format `x.x.x.x`. | Yes | |
+
+
+```
+from fastdeploy_client.chatbot import ChatBot
+
+hostname = "x.x.x.x"
+port = xxx
+
+# Streaming interface; see the parameter description below for stream_generate
+chatbot = ChatBot(hostname=hostname, port=port)
+stream_result = chatbot.stream_generate("Hello", topp=0.8)
+for res in stream_result:
+    print(res)
+
+# Non-streaming interface; see the parameter description below for generate
+chatbot = ChatBot(hostname=hostname, port=port)
+result = chatbot.generate("Hello", topp=0.8)
+print(result)
+```
+
+### API Description
+```
+ChatBot.stream_generate(message,
+ max_dec_len=1024,
+ min_dec_len=2,
+ topp=0.0,
+ temperature=1.0,
+ frequency_score=0.0,
+ penalty_score=1.0,
+ presence_score=0.0,
+ eos_token_ids=254186)
+
+# This function returns an iterator whose elements are dicts, e.g. {"token": "Hi", "is_end": 0},
+# where token is the generated text piece and is_end indicates whether it is the last one (0: no, 1: yes).
+# Note: if generation fails, an error message is returned; eos_token_ids differs between models.
+```
+
+```
+ChatBot.generate(message,
+ max_dec_len=1024,
+ min_dec_len=2,
+ topp=0.0,
+ temperature=1.0,
+ frequency_score=0.0,
+ penalty_score=1.0,
+ presence_score=0.0,
+ eos_token_ids=254186)
+
+# This function returns a dict, e.g. {"results": "OK, got it."}, where results is the generated text.
+# Note: if generation fails, an error message is returned; eos_token_ids differs between models.
+```
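+
+For multi-turn conversations, the `ChatMessage` helper from `fastdeploy_client.message` keeps the dialogue history and can be passed to the interfaces above in place of a plain string. The following is a minimal sketch; the hostname and port are placeholders that must match your deployment.
+
+```
+from fastdeploy_client.chatbot import ChatBot
+from fastdeploy_client.message import ChatMessage
+
+hostname = "x.x.x.x"  # placeholder, replace with the service address
+port = 8811           # placeholder, replace with the service GRPC_PORT
+chatbot = ChatBot(hostname=hostname, port=port)
+
+# ChatMessage stores the conversation history; after the stream finishes,
+# stream_generate appends the assistant reply to it automatically.
+msg = ChatMessage("Hello, who are you?")
+for res in chatbot.stream_generate(msg):
+    print(res)
+
+# Continue the conversation with the accumulated history.
+msg.add_user_message("List 3 countries and their capitals.")
+for res in chatbot.stream_generate(msg):
+    print(res)
+```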
+
+### Parameter Description
+
+| Field | Type | Description | Required | Default | Notes |
+| :---: | :-----: | :---: | :---: | :-----: | :----: |
+| req_id | str | Request ID used to identify a request. It is recommended to set req_id and keep it unique | No | random id | If two requests with the same req_id are in the inference service at the same time, a duplicate req_id error is returned |
+| text | str | Request text | Yes | None | |
+| max_dec_len | int | Maximum number of generated tokens. If the number of input tokens plus max_dec_len exceeds the model's max_seq_len, a length-exceeded error is returned | No | max_seq_len minus the number of input tokens | |
+| min_dec_len | int | Minimum number of generated tokens; the minimum is 1 | No | 1 | |
+| topp | float | Sampling randomness; the larger the value, the more random the output. Range: 0~1 | No | 0.7 | |
+| temperature | float | Sampling randomness; the smaller the value, the more random the output. Must be greater than 0 | No | 0.95 | |
+| frequency_score | float | Frequency score | No | 0 | |
+| penalty_score | float | Penalty score | No | 1 | |
+| presence_score | float | Presence score | No | 0 | |
+| stream | bool | Whether to return results in streaming mode | No | False | |
+| return_all_tokens | bool | Whether to return all results at once | No | False | see the notes below for the difference from stream |
+| timeout | int | Request timeout in seconds | No | 300 | |
+
+* With PUSH_MODE_HTTP_PORT correctly configured, the service accepts both GRPC and HTTP requests
+  * The stream parameter only takes effect for HTTP requests
+  * The return_all_tokens parameter takes effect for both GRPC and HTTP requests
diff --git a/llm/server/client/fastdeploy_client/__init__.py b/llm/server/client/fastdeploy_client/__init__.py
new file mode 100644
index 000000000000..83ae7a0036f3
--- /dev/null
+++ b/llm/server/client/fastdeploy_client/__init__.py
@@ -0,0 +1,20 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import sys
+
+__version__ = "4.4.0"
+
+logging.basicConfig(stream=sys.stdout, level=logging.DEBUG)
diff --git a/llm/server/client/fastdeploy_client/chatbot.py b/llm/server/client/fastdeploy_client/chatbot.py
new file mode 100644
index 000000000000..182667404251
--- /dev/null
+++ b/llm/server/client/fastdeploy_client/chatbot.py
@@ -0,0 +1,304 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import logging
+import queue
+import traceback
+import uuid
+from functools import partial
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+from fastdeploy_client.message import ChatMessage
+from fastdeploy_client.utils import is_enable_benchmark
+from tritonclient import utils as triton_utils
+
+
+class ChatBotClass(object):
+ """
+ initiating conversations through the tritonclient interface of the model service.
+ """
+ def __init__(self, hostname, port, timeout=120):
+ """
+ Initialization function
+
+ Args:
+ hostname (str): gRPC hostname
+ port (int): gRPC port
+ timeout (int): Request timeout, default is 120 seconds.
+
+ Returns:
+ None
+ """
+ self.url = f"{hostname}:{port}"
+ self.timeout = timeout
+
+ def stream_generate(self,
+ message,
+ max_dec_len=1024,
+ min_dec_len=1,
+ topp=0.7,
+ temperature=0.95,
+ frequency_score=0.0,
+ penalty_score=1.0,
+ presence_score=0.0,
+ system=None,
+ **kwargs):
+ """
+ Streaming interface
+
+ Args:
+ message (Union[str, List[str], ChatMessage]): message or ChatMessage object
+ max_dec_len (int, optional): max decoding length. Defaults to 1024.
+ min_dec_len (int, optional): min decoding length. Defaults to 1.
+ topp (float, optional): randomness of the generated tokens. Defaults to 0.7.
+ temperature (float, optional): temperature. Defaults to 0.95.
+ frequency_score (float, optional): frequency score. Defaults to 0.0.
+ penalty_score (float, optional): penalty score. Defaults to 1.0.
+ presence_score (float, optional): presence score. Defaults to 0.0.
+ system (str, optional): system settings. Defaults to None.
+ **kwargs: others
+
+ For more details, please refer to https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/server/docs/deploy_usage_tutorial.md#%E8%AF%B7%E6%B1%82%E5%8F%82%E6%95%B0%E4%BB%8B%E7%BB%8D
+
+ Returns:
+ return a generator object, which yields a dict.
+ Normal, return {'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
+ Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
+ """
+ try:
+ model_name = "model"
+ inputs = [grpcclient.InferInput("IN", [1], triton_utils.np_to_triton_dtype(np.object_))]
+ outputs = [grpcclient.InferRequestedOutput("OUT")]
+ output_data = OutputData()
+
+ msg = message.message if isinstance(message, ChatMessage) else message
+ input_data = self._prepare_input_data(msg, max_dec_len, min_dec_len,
+ topp, temperature, frequency_score,
+ penalty_score, presence_score, **kwargs)
+ req_id = input_data["req_id"]
+ inputs[0].set_data_from_numpy(np.array([json.dumps([input_data])], dtype=np.object_))
+ timeout = kwargs.get("timeout", self.timeout)
+
+ with grpcclient.InferenceServerClient(url=self.url, verbose=False) as triton_client:
+ triton_client.start_stream(callback=partial(triton_callback, output_data))
+ triton_client.async_stream_infer(model_name=model_name,
+ inputs=inputs,
+ request_id=req_id,
+ outputs=outputs)
+ answer_str = ""
+ enable_benchmark = is_enable_benchmark(**kwargs)
+ while True:
+ try:
+ response = output_data._completed_requests.get(timeout=timeout)
+ except queue.Empty:
+ yield {"req_id": req_id, "error_msg": f"Fetch response from server timeout ({timeout}s)"}
+ break
+ if type(response) == triton_utils.InferenceServerException:
+ yield {"req_id": req_id, "error_msg": f"InferenceServerException raised by inference: {response.message()}"}
+ break
+ else:
+ if enable_benchmark:
+ response = json.loads(response.as_numpy("OUT")[0])
+ if isinstance(response, (list, tuple)):
+ response = response[0]
+ else:
+ response = self._format_response(response, req_id)
+ token = response.get("token", "")
+ if isinstance(token, list):
+ token = token[0]
+ answer_str += token
+ yield response
+ if response.get("is_end") == 1 or response.get("error_msg") is not None:
+ break
+ triton_client.stop_stream(cancel_requests=True)
+ triton_client.close()
+
+ if isinstance(message, ChatMessage):
+ message.message.append({"role": "assistant", "content": answer_str})
+ except Exception as e:
+ yield {"error_msg": f"{e}, details={str(traceback.format_exc())}"}
+
+ def generate(self,
+ message,
+ max_dec_len=1024,
+ min_dec_len=1,
+ topp=0.7,
+ temperature=0.95,
+ frequency_score=0.0,
+ penalty_score=1.0,
+ presence_score=0.0,
+ system=None,
+ **kwargs):
+ """
+ Return the entire sentence using the streaming interface.
+
+ Args:
+ message (Union[str, List[str], ChatMessage]): message or ChatMessage object
+ max_dec_len (int, optional): max decoding length. Defaults to 1024.
+ min_dec_len (int, optional): min decoding length. Defaults to 1.
+ topp (float, optional): randomness of the generated tokens. Defaults to 0.7.
+ temperature (float, optional): temperature. Defaults to 0.95.
+ frequency_score (float, optional): frequency score. Defaults to 0.0.
+ penalty_score (float, optional): penalty score. Defaults to 1.0.
+ presence_score (float, optional): presence score. Defaults to 0.0.
+ system (str, optional): system settings. Defaults to None.
+ **kwargs: others
+
+ For more details, please refer to https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/server/docs/deploy_usage_tutorial.md#%E8%AF%B7%E6%B1%82%E5%8F%82%E6%95%B0%E4%BB%8B%E7%BB%8D
+
+ Returns:
+ return the entire sentence or error message.
+ Normal, return {'tokens_all': xxx, ..., 'error_msg': '', 'error_code': 0}
+ Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
+ """
+ stream_response = self.stream_generate(message, max_dec_len,
+ min_dec_len, topp, temperature,
+ frequency_score, penalty_score,
+ presence_score, system, **kwargs)
+ results = ""
+ token_ids = list()
+ error_msg = None
+ for res in stream_response:
+ if "token" not in res or "error_msg" in res:
+ error_msg = {"error_msg": f"response error, please check the info: {res}"}
+ elif isinstance(res["token"], list):
+ results = res["token"]
+ token_ids = res["token_ids"]
+ else:
+ results += res["token"]
+ token_ids += res["token_ids"]
+ if error_msg:
+ return {"req_id": res["req_id"], "error_msg": error_msg}
+ else:
+ return {"req_id": res["req_id"], "results": results, "token_ids": token_ids}
+
+ def _prepare_input_data(self,
+ message,
+ max_dec_len=1024,
+ min_dec_len=2,
+ topp=0.0,
+ temperature=1.0,
+ frequency_score=0.0,
+ penalty_score=1.0,
+ presence_score=0.0,
+ system=None,
+ **kwargs):
+ """
+        Prepare the input data
+ """
+ inputs = {
+ "max_dec_len": max_dec_len,
+ "min_dec_len": min_dec_len,
+ "topp": topp,
+ "temperature": temperature,
+ "frequency_score": frequency_score,
+ "penalty_score": penalty_score,
+ "presence_score": presence_score,
+ }
+
+ if system is not None:
+ inputs["system"] = system
+
+ inputs["req_id"] = kwargs.get("req_id", str(uuid.uuid4()))
+ if "eos_token_ids" in kwargs and kwargs["eos_token_ids"] is not None:
+ inputs["eos_token_ids"] = kwargs["eos_token_ids"]
+ inputs["response_timeout"] = kwargs.get("timeout", self.timeout)
+
+ if isinstance(message, str):
+ inputs["text"] = message
+ elif isinstance(message, list):
+ assert len(message) % 2 == 1, \
+ "The length of message should be odd while it's a list."
+ assert message[-1]["role"] == "user", \
+ "The {}-th element key should be 'user'".format(len(message) - 1)
+ for i in range(0, len(message) - 1, 2):
+ assert message[i]["role"] == "user", \
+ "The {}-th element key should be 'user'".format(i)
+ assert message[i + 1]["role"] == "assistant", \
+ "The {}-th element key should be 'assistant'".format(i + 1)
+ inputs["messages"] = message
+ else:
+            raise Exception(
+                "The message should be a string or a list of dicts like "
+                "[{'role': 'user', 'content': 'Hello, what is your name?'}]"
+            )
+
+ return inputs
+
+ def _format_response(self, response, req_id):
+ """
+ Format the service return fields
+ """
+ response = json.loads(response.as_numpy("OUT")[0])
+ if isinstance(response, (list, tuple)):
+ response = response[0]
+ is_end = response.get("is_end", False)
+
+ if "error_msg" in response:
+ return {"req_id": req_id, "error_msg": response["error_msg"]}
+ elif "choices" in response:
+ token = [x["token"] for x in response["choices"]]
+ token_ids = [x["token_ids"] for x in response["choices"]]
+ return {"req_id": req_id, "token": token, "token_ids": token_ids, "is_end": 1}
+ elif "token" not in response and "result" not in response:
+ return {"req_id": req_id, "error_msg": f"The response should contain 'token' or 'result', but got {response}"}
+ else:
+ token_ids = response.get("token_ids", [])
+ if "result" in response:
+ token = response["result"]
+ elif "token" in response:
+ token = response["token"]
+ return {"req_id": req_id, "token": token, "token_ids": token_ids, "is_end": is_end}
+
+
+class OutputData:
+ """
+ Receive data returned by Triton service
+ """
+ def __init__(self):
+ self._completed_requests = queue.Queue()
+
+
+def triton_callback(output_data, result, error):
+ """
+ callback function for Triton server
+ """
+ if error:
+ output_data._completed_requests.put(error)
+ else:
+ output_data._completed_requests.put(result)
+
+
+class ChatBot(object):
+ """
+    External interface that creates a ChatBotClass client object
+ """
+ def __new__(cls, hostname, port, timeout=120):
+ """
+ initialize a GRPCInferenceService client
+ Args:
+ hostname (str): server hostname
+ port (int): GRPC port
+ timeout (int): timeout(s), default 120 seconds
+ Returns:
+ ChatBotClass: BaseChatBot object
+ """
+ if not isinstance(hostname, str) or not hostname:
+ raise ValueError("Invalid hostname")
+ if not isinstance(port, int) or port <= 0 or port > 65535:
+ raise ValueError("Invalid port number")
+
+ return ChatBotClass(hostname, port, timeout)
diff --git a/llm/server/client/fastdeploy_client/command.py b/llm/server/client/fastdeploy_client/command.py
new file mode 100644
index 000000000000..8aaf9ab7c2a2
--- /dev/null
+++ b/llm/server/client/fastdeploy_client/command.py
@@ -0,0 +1,71 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import logging
+import os
+import sys
+
+from fastdeploy_client.chatbot import ChatBot
+
+
+def _get_service_configuration():
+ """
+ get service url from environment
+
+ Returns:
+ tuple: (hostname, port)
+ """
+ url = os.getenv("FASTDEPLOY_MODEL_URL")
+
+ if url is None:
+ raise ValueError("Please set service url by `export FASTDEPLOY_MODEL_URL`."
+ "For example: `export FASTDEPLOY_MODEL_URL=127.0.0.1:8500`")
+ hostname, port = url.strip().split(':')
+ port = int(port)
+ if port <= 0 or port > 65535:
+ raise ValueError("Invalid port number")
+
+ return hostname, port
+
+
+def stream_generate(prompt):
+ """
+ Streaming interface
+ """
+ hostname, port = _get_service_configuration()
+ chatbot = ChatBot(hostname=hostname, port=port)
+ stream_result = chatbot.stream_generate(prompt)
+ for res in stream_result:
+ print(res)
+
+
+def generate(prompt):
+ """
+ entire sentence interface
+ """
+ hostname, port = _get_service_configuration()
+ chatbot = ChatBot(hostname=hostname, port=port)
+ result = chatbot.generate(prompt)
+ print(result)
+
+
+def main():
+    if len(sys.argv) < 3 or sys.argv[1] not in ["generate", "stream_generate"]:
+        logging.error('Usage: fdclient generate "your prompt" or fdclient stream_generate "your prompt"')
+        return
+    prompt = sys.argv[2]
+ if sys.argv[1] == "generate":
+ return generate(prompt)
+ else:
+ return stream_generate(prompt)
diff --git a/llm/server/client/fastdeploy_client/message.py b/llm/server/client/fastdeploy_client/message.py
new file mode 100644
index 000000000000..7ce1b7b7b91d
--- /dev/null
+++ b/llm/server/client/fastdeploy_client/message.py
@@ -0,0 +1,50 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+class ChatMessage(object):
+ """
+ multi-turn chat message with ChatBot
+ """
+ def __init__(self, prompt=None):
+ if prompt is not None:
+ self.message = [{"role": "user", "content": prompt}]
+ else:
+ self.message = []
+
+ def add_user_message(self, content):
+ """
+ add user message
+ """
+ if len(self.message) > 0 and self.message[-1]["role"] != "assistant":
+ raise Exception("Cannot add user message, because the role of the "
+ f"last message is not assistant. The message is {self.message}")
+ self.message.append({"role": "user", "content": content})
+
+ def add_assistant_message(self, content):
+ """
+ add assistant message
+ """
+ if len(self.message) > 0 and self.message[-1]["role"] != "user":
+            raise Exception("Cannot add assistant message, because the role of the "
+ f"last message is not user. The message is {self.message}")
+ self.message.append({"role": "assistant", "content": content})
+
+ def next_prompt(self, content):
+ """
+        append a user message as the next prompt
+ """
+ self.add_user_message(content)
+
+ def __str__(self):
+ return str(self.message)
diff --git a/llm/server/client/fastdeploy_client/utils.py b/llm/server/client/fastdeploy_client/utils.py
new file mode 100644
index 000000000000..b5c3f33165fc
--- /dev/null
+++ b/llm/server/client/fastdeploy_client/utils.py
@@ -0,0 +1,19 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+def is_enable_benchmark(**kwargs):
+ """
+ Check if enable benchmark
+ """
+ return "benchmark" in kwargs and kwargs["benchmark"] == 1
diff --git a/llm/server/client/requirements.txt b/llm/server/client/requirements.txt
new file mode 100644
index 000000000000..132f7f2b0dae
--- /dev/null
+++ b/llm/server/client/requirements.txt
@@ -0,0 +1,5 @@
+grpcio
+streamlit<=1.33.0
+streamlit_chat<=0.1.1
+protobuf==3.20.0
+tritonclient[grpc]==2.41.1
diff --git a/llm/server/client/setup.py b/llm/server/client/setup.py
new file mode 100644
index 000000000000..9075c45ae05e
--- /dev/null
+++ b/llm/server/client/setup.py
@@ -0,0 +1,39 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import setuptools
+from fastdeploy_client import __version__ as version
+
+long_description = "No description"
+with open("requirements.txt") as fin:
+ REQUIRED_PACKAGES = fin.read()
+
+setuptools.setup(
+ name="fastdeploy-client",
+ version=version,
+ author="dltp-sz",
+ author_email="dltp-sz@baidu.com",
+ description="Client for fastdeploy llm serving",
+ long_description=long_description,
+ long_description_content_type="text/plain",
+ url="https://github.com/PaddlePaddle/Paddle",
+ packages=setuptools.find_packages(),
+ install_requires=REQUIRED_PACKAGES,
+ classifiers=[
+ "Programming Language :: Python :: 3",
+ "License :: OSI Approved :: Apache Software License",
+ "Operating System :: OS Independent",
+ ],
+ license='Apache 2.0',
+ entry_points={'console_scripts': ['fdclient=fastdeploy_client.command:main', ]})
diff --git a/llm/server/dockerfiles/Dockerfile_serving_cuda118_cudnn8 b/llm/server/dockerfiles/Dockerfile_serving_cuda118_cudnn8
new file mode 100644
index 000000000000..b7cd4205c4de
--- /dev/null
+++ b/llm/server/dockerfiles/Dockerfile_serving_cuda118_cudnn8
@@ -0,0 +1,34 @@
+FROM registry.baidubce.com/paddlepaddle/fastdeploy:llm-base-gcc12.3-cuda11.8-cudnn8-nccl2.15.5
+
+WORKDIR /opt/output/
+COPY ./server/ /opt/output/Serving/
+COPY ./client/ /opt/output/client/
+
+ENV LD_LIBRARY_PATH="/usr/local/cuda-11.8/compat/:$LD_LIBRARY_PATH"
+
+RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
+RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu118/ \
+ && python3 -m pip install paddlenlp==3.0.0b0 \
+ && python3 -m pip install --no-cache-dir sentencepiece pycryptodome tritonclient[all]==2.41.1
+
+RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc \
+ && python3 setup_cuda.py build && python3 setup_cuda.py install --user \
+ && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \
+ && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \
+ && rm -rf /opt/output/PaddleNLP
+
+RUN cd /opt/output/client && pip install -r requirements.txt && pip install .
+
+RUN python3 -m pip install -r /opt/output/Serving/requirements.txt && rm /opt/output/Serving/requirements.txt
+RUN mv Serving/server /usr/local/lib/python3.10/dist-packages/
+RUN mkdir -p /opt/output/Serving/llm_model/model/1 \
+ && mv /opt/output/Serving/config/config.pbtxt /opt/output/Serving/llm_model/model/ \
+ && rm -rf /opt/output/Serving/config/
+RUN echo "from server.triton_server import TritonPythonModel" >>/opt/output/Serving/llm_model/model/1/model.py
+
+RUN cd /opt/output/Serving/ \
+ && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \
+ && rm -rf scripts
+
+ENV http_proxy=""
+ENV https_proxy=""
diff --git a/llm/server/dockerfiles/Dockerfile_serving_cuda123_cudnn9 b/llm/server/dockerfiles/Dockerfile_serving_cuda123_cudnn9
new file mode 100644
index 000000000000..fabb7c1724fc
--- /dev/null
+++ b/llm/server/dockerfiles/Dockerfile_serving_cuda123_cudnn9
@@ -0,0 +1,34 @@
+FROM registry.baidubce.com/paddlepaddle/fastdeploy:llm-base-gcc12.3-cuda12.3-cudnn9-nccl2.15.5
+
+WORKDIR /opt/output/
+COPY ./server/ /opt/output/Serving/
+COPY ./client/ /opt/output/client/
+
+ENV LD_LIBRARY_PATH="/usr/local/cuda-12.3/compat/:$LD_LIBRARY_PATH"
+
+RUN pip config set global.index-url https://pypi.tuna.tsinghua.edu.cn/simple
+RUN python3 -m pip install --pre paddlepaddle-gpu -i https://www.paddlepaddle.org.cn/packages/nightly/cu123/ \
+ && python3 -m pip install paddlenlp==3.0.0b0 \
+ && python3 -m pip install --no-cache-dir sentencepiece pycryptodome tritonclient[all]==2.41.1
+
+RUN git clone https://gitee.com/paddlepaddle/PaddleNLP.git && cd PaddleNLP/csrc \
+ && python3 setup_cuda.py build && python3 setup_cuda.py install --user \
+ && cp -r /opt/output/PaddleNLP/paddlenlp /usr/local/lib/python3.10/dist-packages/ \
+ && cp -r /root/.local/lib/python3.10/site-packages/* /usr/local/lib/python3.10/dist-packages/ \
+ && rm -rf /opt/output/PaddleNLP
+
+RUN cd /opt/output/client && pip install -r requirements.txt && pip install .
+
+RUN python3 -m pip install -r /opt/output/Serving/requirements.txt && rm /opt/output/Serving/requirements.txt
+RUN mv Serving/server /usr/local/lib/python3.10/dist-packages/
+RUN mkdir -p /opt/output/Serving/llm_model/model/1 \
+ && mv /opt/output/Serving/config/config.pbtxt /opt/output/Serving/llm_model/model/ \
+ && rm -rf /opt/output/Serving/config/
+RUN echo "from server.triton_server import TritonPythonModel" >>/opt/output/Serving/llm_model/model/1/model.py
+
+RUN cd /opt/output/Serving/ \
+ && cp scripts/start_server.sh . && cp scripts/stop_server.sh . \
+ && rm -rf scripts
+
+ENV http_proxy=""
+ENV https_proxy=""
diff --git a/llm/server/docs/deploy_usage_tutorial.md b/llm/server/docs/deploy_usage_tutorial.md
new file mode 100644
index 000000000000..7efff3306794
--- /dev/null
+++ b/llm/server/docs/deploy_usage_tutorial.md
@@ -0,0 +1,357 @@
+
+## Table of Contents
+
+- [Deployment Environment Setup](#deployment-environment-setup)
+  - [Basic Environment](#basic-environment)
+  - [Prepare the Deployment Image](#prepare-the-deployment-image)
+  - [Prepare the Model](#prepare-the-model)
+  - [Create the Container](#create-the-container)
+- [Start the Service](#start-the-service)
+  - [Configuration Parameters](#configuration-parameters)
+  - [Start the Service](#start-the-service)
+  - [Service Status Check](#service-status-check)
+- [Service Testing](#service-testing)
+  - [Python Client](#python-client)
+  - [HTTP Calls](#http-calls)
+  - [OpenAI Client](#openai-client)
+  - [Response Examples](#response-examples)
+- [Build Your Own Image from the Dockerfile](#build-your-own-image-from-the-dockerfile)
+- [Model Configuration Parameters](#model-configuration-parameters)
+- [Request Parameters](#request-parameters)
+
+## Deployment Environment Setup
+
+### Basic Environment
+  This serving deployment tool currently only supports deployment on Linux. Before deploying, make sure the system has a working GPU environment.
+
+  - Install Docker
+    Follow [Install Docker Engine](https://docs.docker.com/engine/install/) to install Docker for your Linux platform.
+
+  - Install the NVIDIA Container Toolkit
+    Follow [Installing the NVIDIA Container Toolkit](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/install-guide.html#installing-the-nvidia-container-toolkit) to learn about and install the NVIDIA Container Toolkit.
+
+    After installing the NVIDIA Container Toolkit, follow [Running a Sample Workload with Docker](https://docs.nvidia.com/datacenter/cloud-native/container-toolkit/latest/sample-workload.html#running-a-sample-workload-with-docker) to verify that it works correctly.
+
+### Prepare the Deployment Image
+
+For convenience, we provide a CUDA 12.3 image. You can pull it directly, or use the provided `Dockerfile` to [build a custom image](#build-your-own-image-from-the-dockerfile):
+```
+docker pull registry.baidubce.com/paddlepaddle/fastdeploy:llm-serving-cuda123-cudnn9-v1.2
+```
+
+### Prepare the Model
+
+This deployment tool provides an efficient serving solution for PaddleNLP static-graph models. For exporting static-graph models, see [LLaMA](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/llama.md), [Qwen](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/qwen.md), [Mixtral](https://github.com/PaddlePaddle/PaddleNLP/blob/develop/llm/docs/predict/mixtral.md) ...
+
+Place the exported model in any directory, e.g. `/home/workspace/models_dir`:
+
+```
+cd /home/workspace/models_dir
+
+# The exported model directory is structured as shown below. Static-graph models exported by PaddleNLP
+# are supported as-is; there is no need to change the directory structure.
+# /opt/output/Serving/models
+# ├── config.json              # model configuration file
+# ├── xxxx.model               # vocabulary model file
+# ├── special_tokens_map.json  # vocabulary configuration file
+# ├── tokenizer_config.json    # vocabulary configuration file
+# ├── rank_mapping.csv         # only present for multi-GPU models; absent for single-GPU models (optional, required only for multi-GPU deployment)
+# └── rank_0                   # directory containing the model structure and weight files
+#     ├── model.pdiparams
+#     └── model.pdmodel
+```
+
+### Create the Container
+
+Before creating the container, check the Docker version and GPU environment and make sure Docker supports the `--gpus all` flag.
+
+Mount the model directory into the container. The default mount point is `/models/`; it can be customized via the `MODEL_DIR` environment variable when starting the service.
+```
+docker run --gpus all \
+ --name paddlenlp_serving \
+ --privileged \
+ --cap-add=SYS_PTRACE \
+ --network=host \
+ --shm-size=5G \
+ -v /home/workspace/models_dir:/models/ \
+ -dit registry.baidubce.com/paddlepaddle/fastdeploy:llm-serving-cuda123-cudnn9-v1.2 bash
+
+# Enter the container and check that the GPU environment and model mount are working
+docker exec -it paddlenlp_serving /bin/bash
+nvidia-smi
+ls /models/
+```
+
+## Start the Service
+
+### Configuration Parameters
+
+Configure the following environment variables according to your requirements and hardware:
+
+```
+# Single-/multi-GPU inference configuration. Adjust as needed.
+## For single-GPU inference on GPU 0, set the following environment variables.
+export MP_NUM=1
+export CUDA_VISIBLE_DEVICES=0
+
+## For multi-GPU inference, the exported model must match the number of GPUs (e.g. 2); also set the following environment variables.
+# export MP_NUM=2
+# export CUDA_VISIBLE_DEVICES=0,1
+
+# If your deployment does not need streaming token output, you can enable the switch below.
+# The service will then return all generated tokens of a request at once,
+# which reduces the overhead of sending tokens one by one.
+# Disabled by default.
+# export DISABLE_STREAMING=1
+
+# Configure the service ports. Adjust HTTP_PORT, GRPC_PORT, METRICS_PORT and INFER_QUEUE_PORT yourself (check that the ports are available beforehand).
+export HTTP_PORT="8110"            # HTTP port of the liveness service (currently only used for health checks / liveness probes)
+export GRPC_PORT="8811"            # gRPC port of the model inference service
+export METRICS_PORT="8722"         # port for service monitoring metrics
+export INFER_QUEUE_PORT="8813"     # port used internally by the model service
+export PUSH_MODE_HTTP_PORT="9965"  # HTTP port for service requests; if not set, it defaults to -1 and the service only supports the GRPC protocol
+
+# MAX_SEQ_LEN: the service rejects requests whose input token count exceeds MAX_SEQ_LEN and returns an error
+# MAX_DEC_LEN: the service rejects requests whose max_dec_len/min_dec_len exceeds this value and returns an error
+export MAX_SEQ_LEN=8192
+export MAX_DEC_LEN=1024
+
+export BATCH_SIZE="48"    # maximum batch size, i.e. the maximum number of inputs the model can process concurrently; must not exceed 128
+export BLOCK_BS="5"       # maximum query batch size supported by the cache blocks; reduce this value if an out-of-memory error occurs
+export BLOCK_RATIO="0.75" # generally set to (average input tokens) / (average input + output tokens)
+
+export MAX_CACHED_TASK_NUM="128"   # maximum length of the service cache queue; new requests are rejected once the queue is full; default 128
+# Set the following parameter when the HTTP interface is enabled
+export PUSH_MODE_HTTP_WORKERS="1"  # number of HTTP server processes; effective only when PUSH_MODE_HTTP_PORT is set; at most 8; default 1
+```
+
+For more configuration parameters, see [Model Configuration Parameters](#model-configuration-parameters).
+
+### Start the Service
+
+```
+cd /opt/output/Serving
+bash start_server.sh
+
+# Before restarting the service, stop it first by running bash stop_server.sh in /opt/output/Serving
+```
+
+### Service Status Check
+
+```
+# The port is the HTTP_PORT specified when starting the service above.
+# Make sure the service IP and port are correct before testing.
+
+# liveness endpoint (whether the service can accept requests):
+http://127.0.0.1:8110/v2/health/live
+# readiness endpoint (whether the model is ready for inference):
+http://127.0.0.1:8110/v2/health/ready
+```
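+
+As a quick check, the two endpoints can be queried from Python (a minimal sketch; adjust the IP and HTTP_PORT to your deployment). A 200 status code means the check passed.
+
+```
+import requests
+
+base = "http://127.0.0.1:8110"  # HTTP_PORT configured for the service
+
+# liveness: the service can accept requests
+print(requests.get(f"{base}/v2/health/live").status_code)
+# readiness: the model is loaded and ready for inference
+print(requests.get(f"{base}/v2/health/ready").status_code)
+```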
+
+## Service Testing
+
+### Python Client
+
+```
+from fastdeploy_client.chatbot import ChatBot
+
+hostname = "127.0.0.1"  # hostname where the service is deployed
+port = 8811  # GRPC_PORT configured for the service
+
+chatbot = ChatBot(hostname=hostname, port=port)
+
+# Non-streaming interface
+result = chatbot.generate("hello", topp=0.8, max_dec_len=128, timeout=120)
+print(result)
+
+# Streaming interface
+chatbot = ChatBot(hostname=hostname, port=port)
+stream_result = chatbot.stream_generate("hello", max_dec_len=128, timeout=120)
+for res in stream_result:
+ print(res)
+```
+
+### HTTP Calls
+
+Note: the HTTP call interface uses the PUSH_MODE_HTTP_PORT variable; HTTP_PORT is only used for the health-check endpoints!
+
+```
+import uuid
+import json
+import requests
+
+push_mode_http_port = "9965" # PUSH_MODE_HTTP_PORT configured for the service
+url = f"http://127.0.0.1:{push_mode_http_port}/v1/chat/completions"
+req_id = str(uuid.uuid1())
+data_single = {
+ "text": "Hello, how are you?",
+ "req_id": req_id,
+ "max_dec_len": 64,
+ "stream": True,
+ }
+# tokens are returned one by one
+res = requests.post(url, json=data_single, stream=True)
+for line in res.iter_lines():
+ print(json.loads(line))
+
+# multi-turn conversation
+data_multi = {
+ "messages": [
+ {"role": "user", "content": "Hello, who are you"},
+ {"role": "system", "content": "I'm a helpful AI assistant."},
+ {"role": "user", "content": "List 3 countries and their capitals."},
+ ],
+ "req_id": req_id,
+ "max_dec_len": 64,
+ "stream": True,
+ }
+# tokens are returned one by one
+res = requests.post(url, json=data_multi, stream=True)
+for line in res.iter_lines():
+ print(json.loads(line))
+```
+
+For more request parameters, see [Request Parameters](#request-parameters).
+
+### Response Examples
+
+```
+If stream is True, responses are returned in streaming mode:
+    On success: {'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
+    On error:   {'error_msg': xxx, 'error_code': xxx}, where error_msg is non-empty and error_code is non-zero
+
+If stream is False, the full response is returned at once:
+    On success: {'tokens_all': xxx, ..., 'error_msg': '', 'error_code': 0}
+    On error:   {'error_msg': xxx, 'error_code': xxx}, where error_msg is non-empty and error_code is non-zero
+```
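+
+Building on the HTTP example above, the sketch below shows one way to accumulate a streamed answer and detect errors based on these fields (a minimal sketch; the URL and payload are the same placeholders as above).
+
+```
+import json
+import requests
+
+url = "http://127.0.0.1:9965/v1/chat/completions"  # PUSH_MODE_HTTP_PORT configured for the service
+data = {"text": "Hello, how are you?", "max_dec_len": 64, "stream": True}
+
+answer = ""
+with requests.post(url, json=data, stream=True) as res:
+    for line in res.iter_lines():
+        if not line:
+            continue
+        chunk = json.loads(line)
+        # A non-empty error_msg / non-zero error_code signals a failed request.
+        if chunk.get("error_msg"):
+            print("request failed:", chunk)
+            break
+        answer += chunk.get("token", "")
+        if chunk.get("is_end") == 1:
+            break
+print(answer)
+```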
+
+### OpenAI Client
+
+The OpenAI client is supported; usage is shown below.
+
+Note: using the OpenAI client requires `PUSH_MODE_HTTP_PORT` to be configured!
+
+```
+import openai
+
+push_mode_http_port = "9965" # PUSH_MODE_HTTP_PORT configured for the service
+client = openai.Client(base_url=f"http://127.0.0.1:{push_mode_http_port}/v1/chat/completions", api_key="EMPTY_API_KEY")
+
+# Non-streaming response
+response = client.completions.create(
+ model="default",
+ prompt="Hello, how are you?",
+ max_tokens=50,
+ stream=False,
+)
+
+print(response)
+print("\n")
+
+# Streaming response
+response = client.completions.create(
+ model="default",
+ prompt="Hello, how are you?",
+ max_tokens=100,
+ stream=True,
+)
+
+for chunk in response:
+ if chunk.choices[0] is not None:
+ print(chunk.choices[0].text, end='')
+print("\n")
+
+# Chat completion
+# Non-streaming response
+response = client.chat.completions.create(
+ model="default",
+ messages=[
+ {"role": "user", "content": "Hello, who are you"},
+ {"role": "system", "content": "I'm a helpful AI assistant."},
+ {"role": "user", "content": "List 3 countries and their capitals."},
+ ],
+ temperature=0,
+ max_tokens=64,
+ stream=False,
+)
+
+print(response)
+print("\n")
+
+# Streaming response
+response = client.chat.completions.create(
+ model="default",
+ messages=[
+ {"role": "user", "content": "Hello, who are you"},
+ {"role": "system", "content": "I'm a helpful AI assistant."},
+ {"role": "user", "content": "List 3 countries and their capitals."},
+ ],
+ temperature=0,
+ max_tokens=64,
+ stream=True,
+)
+
+for chunk in response:
+ if chunk.choices[0].delta is not None:
+ print(chunk.choices[0].delta.content, end='')
+print("\n")
+```
+
+## Build Your Own Image from the Dockerfile
+
+To make it easy to build a custom service, we provide a Dockerfile for building your own image.
+```
+git clone https://github.com/PaddlePaddle/PaddleNLP.git
+cd PaddleNLP/llm/server
+
+docker build --network=host -f ./dockerfiles/Dockerfile_serving_cuda123_cudnn9 -t llm-serving-cu123-self .
+```
+After building your own image, you can [create the container](#create-the-container) from it.
+
+## Model Configuration Parameters
+
+| Field | Type | Description | Required | Default | Notes |
+| :---: | :-----: | :---: | :---: | :-----: | :----: |
+| MP_NUM | int | Model parallelism degree | No | 8 | CUDA_VISIBLE_DEVICES must list the corresponding number of GPUs |
+| CUDA_VISIBLE_DEVICES | str | GPU ids to use | No | 0,1,2,3,4,5,6,7 | |
+| HTTP_PORT | int | HTTP port of the liveness service | Yes | None | currently only used for health checks / liveness probes |
+| GRPC_PORT | int | gRPC port of the model inference service | Yes | None | |
+| METRICS_PORT | int | Port for service monitoring metrics | Yes | None | |
+| INFER_QUEUE_PORT | int | Port used internally by the model service | No | 56666 | |
+| PUSH_MODE_HTTP_PORT | int | HTTP port for service requests | No | -1 | if not set, the service only supports the GRPC protocol |
+| DISABLE_STREAMING | int | Whether to disable streaming output | No | 0 | |
+| MAX_SEQ_LEN | int | Maximum input sequence length | No | 8192 | the service rejects requests whose input token count exceeds MAX_SEQ_LEN and returns an error |
+| MAX_DEC_LEN | int | Maximum decoder sequence length | No | 1024 | the service rejects requests whose max_dec_len/min_dec_len exceeds this value and returns an error |
+| BATCH_SIZE | int | Maximum batch size | No | 50 | maximum number of inputs the model can process concurrently; must not exceed 128 |
+| BLOCK_BS | int | Maximum query batch size supported by the cache blocks | No | 50 | reduce this value if an out-of-memory error occurs |
+| BLOCK_RATIO | float | | No | 0.75 | recommended: (average input tokens) / (average input + output tokens) |
+| MAX_CACHED_TASK_NUM | int | Maximum length of the service cache queue | No | 128 | new requests are rejected once the queue is full |
+| PUSH_MODE_HTTP_WORKERS | int | Number of HTTP server processes | No | 1 | effective only when PUSH_MODE_HTTP_PORT is set; increase under high concurrency, up to a recommended maximum of 8 |
+| USE_WARMUP | int | Whether to perform warmup | No | 0 | |
+| USE_HF_TOKENIZER | int | Whether to use the Hugging Face tokenizer | No | 0 | |
+| USE_CACHE_KV_INT8 | int | Whether to use INT8 as the KV cache type | No | 0 | must be set to 1 for c8 quantized models |
+| MODEL_DIR | str | Model file path | No | /models/ | |
+| FD_MODEL_CONFIG_PATH | str | Model config file path | No | ${model_dir}/config.json | |
+| DISTRIBUTED_CONFIG | str | Model distributed config file path | No | ${model_dir}/rank_mapping.csv | |
+
+## Request Parameters
+
+| Field | Type | Description | Required | Default | Notes |
+| :---: | :-----: | :---: | :---: | :-----: | :----: |
+| req_id | str | Request ID used to identify a request. It is recommended to set req_id and keep it unique | No | random id | If two requests with the same req_id are in the inference service at the same time, a duplicate req_id error is returned |
+| text | str | Request text | No | None | either text or messages must be provided |
+| messages | list | Multi-turn conversation messages | No | None | multi-turn conversations are stored as a list |
+| max_dec_len | int | Maximum number of generated tokens. If the number of input tokens plus max_dec_len exceeds the model's max_seq_len, a length-exceeded error is returned | No | max_seq_len minus the number of input tokens | |
+| min_dec_len | int | Minimum number of generated tokens; the minimum is 1 | No | 1 | |
+| topp | float | Sampling randomness; the larger the value, the more random the output. Range: 0~1 | No | 0.7 | |
+| temperature | float | Sampling randomness; the smaller the value, the more random the output. Must be greater than 0 | No | 0.95 | |
+| frequency_score | float | Frequency score | No | 0 | |
+| penalty_score | float | Penalty score | No | 1 | |
+| presence_score | float | Presence score | No | 0 | |
+| stream | bool | Whether to return results in streaming mode | No | False | |
+| return_all_tokens | bool | Whether to return all results at once | No | False | see the notes below for the difference from stream |
+| timeout | int | Request timeout in seconds | No | 300 | |
+| return_usage | bool | Whether to return the input and output token counts | No | False | |
+
+* With PUSH_MODE_HTTP_PORT correctly configured, the service accepts both GRPC and HTTP requests
+  * The stream parameter only takes effect for HTTP requests
+  * The return_all_tokens parameter takes effect for both GRPC and HTTP requests, as shown in the sketch below
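+
+A minimal sketch of a non-streaming HTTP request that uses return_all_tokens (the URL and payload are placeholders; adjust them to your deployment):
+
+```
+import requests
+
+url = "http://127.0.0.1:9965/v1/chat/completions"  # PUSH_MODE_HTTP_PORT configured for the service
+data = {
+    "text": "Hello, how are you?",
+    "max_dec_len": 64,
+    "return_all_tokens": True,  # return the whole result at once
+    "return_usage": True,       # also return input/output token counts
+}
+res = requests.post(url, json=data)
+print(res.json())
+```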
diff --git a/llm/server/requirements-dev.txt b/llm/server/requirements-dev.txt
new file mode 100644
index 000000000000..e1eec92d201a
--- /dev/null
+++ b/llm/server/requirements-dev.txt
@@ -0,0 +1,3 @@
+black[jupyter] == 23.3.0
+isort == 5.11.5
+pre-commit
diff --git a/llm/server/server/config/config.pbtxt b/llm/server/server/config/config.pbtxt
new file mode 100644
index 000000000000..375c41d01331
--- /dev/null
+++ b/llm/server/server/config/config.pbtxt
@@ -0,0 +1,20 @@
+backend: "python"
+max_batch_size: 0
+model_transaction_policy {
+ decoupled: True
+}
+input [
+ {
+ name: "IN"
+ data_type: TYPE_STRING
+ dims: [ 1 ]
+ }
+]
+output [
+ {
+ name: "OUT"
+ data_type: TYPE_STRING
+ dims: [ 1 ]
+ }
+]
+instance_group [{ kind: KIND_CPU }]
diff --git a/llm/server/server/requirements.txt b/llm/server/server/requirements.txt
new file mode 100644
index 000000000000..d7bd2b1ac6e0
--- /dev/null
+++ b/llm/server/server/requirements.txt
@@ -0,0 +1,21 @@
+# model server
+sentencepiece
+pycryptodome
+tritonclient[all]==2.41.1
+opencv-python
+patchify
+transformers
+
+# http server
+fastapi
+httpx
+openai==1.44.1
+asyncio
+uvicorn
+shortuuid
+
+# parameter search
+pynvml
+
+# paddlenlp
+tiktoken
diff --git a/llm/server/server/scripts/start_server.sh b/llm/server/server/scripts/start_server.sh
new file mode 100644
index 000000000000..e7975b3e838f
--- /dev/null
+++ b/llm/server/server/scripts/start_server.sh
@@ -0,0 +1,57 @@
+#!/usr/bin/bash
+
+export GLOG_v=0
+export GLOG_logtostderr=1
+export PYTHONIOENCODING=utf8
+export LC_ALL=C.UTF-8
+
+# PaddlePaddle environment variables
+export FLAGS_allocator_strategy=auto_growth
+export FLAGS_dynamic_static_unified_comm=0
+export FLAGS_use_xqa_optim=1
+export FLAGS_gemm_use_half_precision_compute_type=0
+export NVIDIA_TF32_OVERRIDE=0
+
+# Model hyperparameters
+export MP_NUM=${MP_NUM:-"1"} # Number of GPUs
+export CUDA_VISIBLE_DEVICES=${CUDA_VISIBLE_DEVICES:-"0"} # GPU ids
+export MAX_SEQ_LEN=${MAX_SEQ_LEN:-"8192"}
+export MAX_DEC_LEN=${MAX_DEC_LEN:-"2048"}
+export BATCH_SIZE=${BATCH_SIZE:-"20"}
+export BLOCK_BS=${BLOCK_BS:-"4"}
+export BLOCK_SIZE=${BLOCK_SIZE:-"64"}
+export DTYPE=${DTYPE:-"bfloat16"}
+export USE_CACHE_KV_INT8=${USE_CACHE_KV_INT8:-"0"} # c8 model requires configuration 1
+export BLOCK_RATIO=${BLOCK_RATIO:-"0.75"}
+export ENC_DEC_BLOCK_NUM=${ENC_DEC_BLOCK_NUM:-"4"}
+export FIRST_TOKEN_ID=${FIRST_TOKEN_ID:-"1"}
+export MAX_PREFILL_BATCH=${MAX_PREFILL_BATCH:-"4"}
+export STOP_THRESHOLD=${STOP_THRESHOLD:-"0"}
+export MODEL_DIR=${MODEL_DIR:-"/models"}
+export DISTRIBUTED_CONFIG=${DISTRIBUTED_CONFIG:-"${MODEL_DIR}/rank_mapping.csv"}
+export CONFIG_JSON_FILE=${CONFIG_JSON_FILE:-"config.json"}
+export PUSH_MODE_HTTP_WORKERS=${PUSH_MODE_HTTP_WORKERS:-"4"}
+
+# serving port
+export HTTP_PORT=${HTTP_PORT:-"8110"}
+export GRPC_PORT=${GRPC_PORT:-"8811"}
+export METRICS_PORT=${METRICS_PORT:-"8722"}
+export INFER_QUEUE_PORT=${INFER_QUEUE_PORT:-"8813"}
+export PUSH_MODE_HTTP_PORT=${PUSH_MODE_HTTP_PORT:-"9965"}
+
+mkdir -p log
+rm -rf console.log log/*
+rm -rf /dev/shm/*
+
+echo "start serving ..."
+
+tritonserver --exit-timeout-secs 100 --cuda-memory-pool-byte-size 0:0 --cuda-memory-pool-byte-size 1:0 \
+ --cuda-memory-pool-byte-size 2:0 --cuda-memory-pool-byte-size 3:0 --cuda-memory-pool-byte-size 4:0 \
+ --cuda-memory-pool-byte-size 5:0 --cuda-memory-pool-byte-size 6:0 --cuda-memory-pool-byte-size 7:0 \
+ --pinned-memory-pool-byte-size 0 --model-repository llm_model/ \
+ --allow-http false \
+ --grpc-port=${GRPC_PORT} \
+ --metrics-port=${METRICS_PORT} \
+ --log-file log/server.log --log-info true > log/console.log 2>&1 &
+
+echo "Model service logs: please check ${PWD}/log/server.log and ${PWD}/log/workerlog.0"
diff --git a/llm/server/server/scripts/stop_server.sh b/llm/server/server/scripts/stop_server.sh
new file mode 100644
index 000000000000..89cfa42f3aa8
--- /dev/null
+++ b/llm/server/server/scripts/stop_server.sh
@@ -0,0 +1,68 @@
+#!/bin/bash
+
+pids=($(ps aux | grep -E 'tritonserver' | grep -v grep | awk '{print $2}'))
+
+if [ ${#pids[@]} -eq 0 ]; then
+    echo "Cannot find tritonserver."
+ timeout=1
+else
+ timeout=300
+fi
+
+# kill processor
+for pid in "${pids[@]}"; do
+ echo "killing $pid"
+ kill -2 "$pid"
+done
+
+timeout_interval=$1
+if [ ! "$timeout_interval" == "" ]; then
+ timeout=$timeout_interval
+ echo $timeout
+fi
+
+start_time=$(date +%s)
+
+while : ; do
+ current_time=$(date +%s)
+
+ elapsed_time=$((current_time - start_time))
+
+ if [ $elapsed_time -ge $timeout ]; then
+ echo "forcibly kill all process ..."
+ pids=$(ps auxww | grep -E "tritonserver|triton_python_backend_stub|infer|multiprocessing.resource_tracker|paddle.distributed.launch|task_queue_manager|app.py|spawn_main" | grep -v grep | grep -v start_both | awk '{print $2}');
+ echo $pids;
+ for pid in ${pids[@]}; do
+ kill -9 ${pid}
+ done
+ break
+ fi
+
+ pids=$(ps auxww | grep -E "tritonserver|triton_python_backend_stub|multiprocessing.resource_tracker|paddle.distributed.launch|app.py|spawn_main" | grep -v grep | awk '{print $2}');
+ array=($(echo "$pids" | tr ' ' '\n'))
+
+ if [ ${#array[*]} -ne 0 ]; then
+ echo "cleaning process, please wait ..."
+ sleep 1
+ else
+ echo "clean finished."
+ break
+ fi
+done
+
+manager_pids=$(ps auxww | grep "task_queue_manager" | grep -v grep | awk '{print $2}')
+echo $manager_pids
+for in_pid in ${manager_pids[@]}; do
+ kill -9 ${in_pid}
+done
+echo 'end kill queue manager'
+
+health_checker_pids=$(ps auxww | grep "health.py" | grep -v grep | awk '{print $2}')
+echo $health_checker_pids
+for in_pid in ${health_checker_pids[@]}; do
+ kill -9 ${in_pid}
+done
+echo 'end kill health checker'
+
+echo "all process terminated."
+exit 0
diff --git a/llm/server/server/server/__init__.py b/llm/server/server/server/__init__.py
new file mode 100644
index 000000000000..5ae9b7e8cf18
--- /dev/null
+++ b/llm/server/server/server/__init__.py
@@ -0,0 +1,16 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__version__ = "dev"
+__commit__ = "dev"
diff --git a/llm/server/server/server/checker.py b/llm/server/server/server/checker.py
new file mode 100644
index 000000000000..e9f776799df0
--- /dev/null
+++ b/llm/server/server/server/checker.py
@@ -0,0 +1,124 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+def check_basic_params(req_dict):
+ """
+ checks input requests for basic parameters
+
+ Args:
+ req_dict (dict): request parameters
+
+ Returns:
+ list[str]: if error, return a list of error messages; return an empty list otherwise
+ """
+
+ error_msg = []
+ bools = ("text" in req_dict, "input_ids" in req_dict, "messages" in req_dict)
+ if sum(bools) == 0:
+ error_msg.append("The input parameters should contain either `text`, `input_ids` or `messages`")
+ else:
+ if "text" in req_dict:
+ if not isinstance(req_dict["text"], str):
+ error_msg.append("The `text` in input parameters must be a string")
+ elif req_dict["text"] == "":
+ error_msg.append("The `text` in input parameters cannot be empty")
+ if "system" in req_dict and not isinstance(req_dict["system"], str):
+ error_msg.append("The `system` in input parameters must be a string")
+ if "input_ids" in req_dict and not isinstance(req_dict["input_ids"], list):
+ error_msg.append("The `input_ids` in input parameters must be a list")
+ if "messages" in req_dict:
+ msg_len = len(req_dict["messages"])
+ if not all("content" in item for item in req_dict["messages"]):
+ error_msg.append("The item in messages must include `content`")
+
+ if "req_id" not in req_dict:
+ error_msg.append("The input parameters should contain `req_id`.")
+
+ if "min_dec_len" in req_dict and \
+ (not isinstance(req_dict["min_dec_len"], int) or req_dict["min_dec_len"] < 1):
+ error_msg.append("The `min_dec_len` must be an integer and greater than 0")
+
+ keys = ("max_dec_len", "seq_len", "max_tokens")
+ for key in keys:
+ if key in req_dict and (not isinstance(req_dict[key], int) or req_dict[key] < 1):
+ error_msg.append(f"The `{key}` must be an integer and greater than 0")
+ if "seq_len" in req_dict and "max_dec_len" not in req_dict:
+ req_dict["max_dec_len"] = req_dict["seq_len"]
+ if "max_tokens" in req_dict and "max_dec_len" not in req_dict:
+ req_dict["max_dec_len"] = req_dict["max_tokens"]
+
+ keys = ("topp", "top_p")
+ if sum([key in req_dict for key in keys]) > 1:
+ error_msg.append(f"Only one of {keys} should be set")
+ else:
+ for key in keys:
+ if key in req_dict and not 0 <= req_dict[key] <= 1:
+ error_msg.append(f"The `{key}` must be in [0, 1]")
+ if "top_p" in req_dict and "topp" not in req_dict:
+ req_dict["topp"] = req_dict["top_p"]
+
+ if "temperature" in req_dict and not 0 <= req_dict["temperature"]:
+ error_msg.append(f"The `temperature` must be >= 0")
+
+ if "eos_token_ids" in req_dict:
+ if isinstance(req_dict["eos_token_ids"], int):
+ req_dict["eos_token_ids"] = [req_dict["eos_token_ids"]]
+ elif isinstance(req_dict["eos_token_ids"], tuple):
+ req_dict["eos_token_ids"] = list(req_dict["eos_token_ids"])
+ if not isinstance(req_dict["eos_token_ids"], list):
+            error_msg.append("The `eos_token_ids` must be a list")
+ elif len(req_dict["eos_token_ids"]) > 1:
+ error_msg.append("The length of `eos_token_ids` must be 1 if you set it")
+
+ keys = ("infer_seed", "seed")
+ if sum([key in req_dict for key in keys]) > 1:
+ error_msg.append(f"Only one of {keys} should be set")
+ else:
+ if "seed" in req_dict and "infer_seed" not in req_dict:
+ req_dict["infer_seed"] = req_dict["seed"]
+
+ if "stream" in req_dict and not isinstance(req_dict["stream"], bool):
+ error_msg.append("The `stream` must be a boolean")
+
+ if "response_type" in req_dict and (req_dict["response_type"].lower() not in ("fastdeploy", "openai")):
+ error_msg.append("The `response_type` must be either `fastdeploy` or `openai`.")
+
+ return error_msg
+
+
+def add_default_params(req_dict):
+ """
+ add default params to req_dict
+
+ Args:
+ req_dict (dict): input dict
+
+ Returns:
+ dict: req_dict with default params
+ """
+ assert isinstance(req_dict, dict), "The `req_dict` must be a dict."
+ if "min_dec_len" not in req_dict:
+ req_dict["min_dec_len"] = 1
+ if "topp" not in req_dict:
+ req_dict["topp"] = 0.7
+ if "temperature" not in req_dict:
+ req_dict["temperature"] = 0.95
+ if "penalty_score" not in req_dict:
+ req_dict["penalty_score"] = 1.0
+ if "frequency_score" not in req_dict:
+ req_dict["frequency_score"] = 0.0
+ if "presence_score" not in req_dict:
+ req_dict["presence_score"] = 0.0
+ return req_dict
diff --git a/llm/server/server/server/data/__init__.py b/llm/server/server/server/data/__init__.py
new file mode 100644
index 000000000000..97043fd7ba68
--- /dev/null
+++ b/llm/server/server/server/data/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/llm/server/server/server/data/processor.py b/llm/server/server/server/data/processor.py
new file mode 100644
index 000000000000..423fe6b61408
--- /dev/null
+++ b/llm/server/server/server/data/processor.py
@@ -0,0 +1,336 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+from abc import ABC, abstractmethod
+
+from paddlenlp.transformers import Llama3Tokenizer, LlamaTokenizer
+from paddlenlp.trl.llm_utils import get_eos_token_id
+from server.engine.config import Config
+from server.utils import data_processor_logger
+
+
+class BaseDataProcessor(ABC):
+ """base class for data processor"""
+
+ def __init__(self):
+ """
+ Returns:
+ None
+ """
+ self.tokenizer = self._load_tokenizer()
+ self.tokenizer.bos_token_id = self.tokenizer._convert_token_to_id(self.tokenizer.bos_token)
+ self.tokenizer.cls_token_id = self.tokenizer._convert_token_to_id(self.tokenizer.cls_token)
+ self.tokenizer.sep_token_id = self.tokenizer._convert_token_to_id(self.tokenizer.sep_token)
+ self.tokenizer.eos_token_id = self.tokenizer._convert_token_to_id(self.tokenizer.eos_token)
+ self.tokenizer.mask_token_id = self.tokenizer._convert_token_to_id(self.tokenizer.mask_token)
+        data_processor_logger.info(f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, "
+                                   f"cls_token is {self.tokenizer.cls_token}, {self.tokenizer.cls_token_id}, "
+                                   f"sep_token is {self.tokenizer.sep_token}, {self.tokenizer.sep_token_id}, "
+                                   f"eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id}, "
+                                   f"mask_token is {self.tokenizer.mask_token}, {self.tokenizer.mask_token_id}")
+
+ @abstractmethod
+ def process_request(self, request, **kwargs):
+ """
+ Preprocess the request
+
+ Args:
+ request (Dict): may contain text and messages fields
+ **kwargs: others
+
+ Returns:
+ bool: Whether preprocessing is successful
+ str: error message
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def process_response(self, response_dict):
+ """
+ Preprocess the response
+
+ Args:
+ response_dict (Dict): response for engine, contain ids fields
+
+ Returns:
+ Dict: response contain text fields
+ """
+ raise NotImplementedError
+
+ def text2ids(self, text):
+ """
+ text to token ids
+
+ Args:
+ text (str): text
+
+ Returns:
+ List[int]: token ids list
+ """
+ raise NotImplementedError
+
+ def messages2ids(self, messages):
+ """
+ Convert multi-turn messages into ID sequences.
+
+ Args:
+ messages (List[List[Dict[str, Any]]]): multi-turn messages.
+
+ Returns:
+ List[int]: ID sequences
+ """
+ raise NotImplementedError
+
+ def ids2tokens(self, token_ids, task_id=None):
+ """
+ token ids to strings
+
+ Args:
+ token_ids (List[int]): token ids
+ task_id (str): task id
+
+ Returns:
+ List[str]: strings
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def _load_tokenizer(self):
+ """
+ load tokenizer
+
+ Returns:
+ tokenizer (AutoTokenizer)
+ """
+ raise NotImplementedError
+
+
+class DataProcessor(BaseDataProcessor):
+ def __init__(self):
+ self.config = Config()
+ max_length = self.config.get_model_config().get('max_length', 1024)
+ self.src_length = max_length - self.config.seq_len_limit
+
+ self.decode_status = dict()
+ self.tokenizer = self._load_tokenizer()
+        data_processor_logger.info(f"tokenizer information: bos_token is {self.tokenizer.bos_token}, {self.tokenizer.bos_token_id}, \
+            eos_token is {self.tokenizer.eos_token}, {self.tokenizer.eos_token_id} ")
+
+ def process_request(self, request, max_seq_len=None):
+ """
+ Preprocess the request
+
+ Args:
+ request (Dict): may contain text and messages fields
+
+ Returns:
+ bool: Whether preprocessing is successful
+ str: error message
+ """
+ if "eos_token_ids" not in request or request["eos_token_ids"] == [None]:
+ request["eos_token_ids"] = []
+ request["eos_token_ids"].extend(get_eos_token_id(self.tokenizer, self.config.generation_config))
+
+ if "input_ids" not in request or \
+ (isinstance(request["input_ids"], (list, tuple)) and len(request["input_ids"]) == 0):
+ if "text" in request:
+ request["input_ids"] = self.text2ids(request["text"])
+ elif "messages" in request:
+ if self.tokenizer.chat_template is None:
+ raise ValueError(f"This model does not support chat_template.")
+ request["input_ids"] = self.messages2ids(request["messages"])
+ else:
+ raise ValueError(f"The request should have `input_ids`, `text` or `messages`: {request}.")
+
+ if max_seq_len is not None and len(request["input_ids"]) > max_seq_len:
+ request["input_ids"] = request["input_ids"][:max_seq_len-1]
+ data_processor_logger.info(f"processed request: {request}")
+ return request
+
+ def process_response(self, response_dict, **kwargs):
+ """
+ Preprocess the response
+
+ Args:
+ response_dict (Dict): response for engine, contain ids fields
+
+ Returns:
+ Dict: response contain text fields
+ """
+ is_end = response_dict.get("is_end", 0)
+ req_id = response_dict.get("req_id")
+ if "choices" in response_dict:
+ for i in range(len(response_dict["choices"])):
+ response_dict["token"] = self.ids2tokens(response_dict["choices"][i]["token_ids"], req_id)
+ return response_dict
+
+ token_ids = response_dict.get("token_ids", [])
+ response_dict["token"] = self.ids2tokens(token_ids, response_dict["req_id"])
+ response_dict["usage"] = {"completion_tokens" : response_dict["send_idx"] + 1}
+
+ if is_end:
+ response_dict["tokens_all"] = self.clear_request_status(req_id)
+ return response_dict
+
+ def text2ids(self, text):
+ """
+ text to token ids
+
+ Args:
+ text (str): text
+
+ Returns:
+ List[int]: token ids list
+ """
+ if self.config.use_hf_tokenizer:
+ tokens = self.tokenizer(
+ text,
+ return_tensors="np",
+ padding=True,
+ truncation=True,
+ )
+ else:
+ if self.tokenizer.chat_template is not None:
+ text = [text] if isinstance(text, str) else text
+ text = [self.tokenizer.apply_chat_template(sentence, tokenize=False) for sentence in text]
+
+ tokens = self.tokenizer(
+ text,
+ return_tensors="np",
+ padding=True,
+ truncation=True,
+ max_length=self.src_length,
+ add_special_tokens=self.tokenizer.chat_template is None,
+ )
+ return tokens["input_ids"][0]
+
+ def messages2ids(self, messages):
+ """
+ Convert multi-turn messages into ID sequences.
+
+ Args:
+ messages (List[List[Dict[str, Any]]]): multi-turn messages.
+
+ Returns:
+ List[int]: ID sequences
+ """
+ message_result = self.tokenizer.apply_chat_template(messages, return_tensors="pd")
+ return message_result["input_ids"][0]
+
+ def ids2tokens(self, token_id, task_id):
+ """
+ token ids to strings
+
+ Args:
+ token_ids (List[int]): token ids
+ task_id (str): task id
+
+ Returns:
+ List[str]: strings
+ """
+ if self.config.use_hf_tokenizer:
+ if task_id not in self.decode_status:
+            # history token ids & history token strings & previously decoded string
+ self.decode_status[task_id] = [[], [], ""]
+
+ previous_token_ids = self.decode_status[task_id][0]
+ decode_str = self.tokenizer.batch_decode([previous_token_ids + token_id],
+ skip_special_tokens=True,
+ clean_up_tokenization_spaces=False)
+ if isinstance(decode_str, list) and len(decode_str):
+ new_str = decode_str[0].replace(self.decode_status[task_id][2], "", 1)
+ self.decode_status[task_id][1].append(new_str)
+ self.decode_status[task_id][2] = decode_str[0]
+ else:
+ new_str = ""
+ self.decode_status[task_id][0] += token_id
+ return new_str
+ else:
+ if task_id not in self.decode_status:
+ # prefix offset & read offset & history token ids & history token strings
+ self.decode_status[task_id] = [0, 0, [], []]
+
+ prefix_offset = self.decode_status[task_id][0]
+ read_offset = self.decode_status[task_id][1]
+ previous_token_ids = self.decode_status[task_id][2]
+ decode_str, prefix_offset, read_offset = self.tokenizer.decode_token(
+ previous_token_ids + token_id, prefix_offset, read_offset)
+ self.decode_status[task_id][0] = prefix_offset
+ self.decode_status[task_id][1] = read_offset
+ self.decode_status[task_id][2] += token_id
+ self.decode_status[task_id][3].append(decode_str)
+ return decode_str
+
+ def _load_tokenizer(self):
+ """
+ load tokenizer
+
+ Returns:
+ tokenizer (AutoTokenizer)
+ """
+ if self.config.use_hf_tokenizer:
+ from transformers import AutoTokenizer
+ return AutoTokenizer.from_pretrained(self.config.model_dir, use_fast=False, vocab_file=os.path.join(self.config.model_dir, "sentencepiece.bpe.model"))
+ else:
+ from paddlenlp.transformers import AutoTokenizer
+ return AutoTokenizer.from_pretrained(self.config.model_dir)
+
+ def clear_request_status(self, task_id):
+ """
+ clear request status
+
+ Args:
+ task_id (str): task id
+
+ Returns:
+ results_all (str): all token strings
+ """
+ results_all = ""
+ if task_id in self.decode_status:
+ if self.config.use_hf_tokenizer:
+ results_all = self.decode_status[task_id][2]
+ else:
+ results_all = "".join(self.decode_status[task_id][3])
+ del self.decode_status[task_id]
+ return results_all
+
+ def get_eos_tokens_lens(self):
+ """
+ get eos_token_id lens
+
+ Returns:
+ int: eos_token_id lens
+ """
+ return len(get_eos_token_id(self.tokenizer, self.config.generation_config))
+
+ def get_eos_tokens(self):
+ """
+ get all eos_token_id
+
+ Returns:
+ List[int]: eos_token_id list
+ """
+ return get_eos_token_id(self.tokenizer, self.config.generation_config)
+
+ def get_pad_id(self):
+ """
+        get pad_token_id; if pad_token_id is not set, fall back to eos_token
+
+ Returns:
+ int: pad_token_id
+ """
+ if isinstance(self.tokenizer, (LlamaTokenizer, Llama3Tokenizer)) and not self.tokenizer.pad_token_id:
+ return self.tokenizer.eos_token
+ return self.tokenizer.pad_token_id
diff --git a/llm/server/server/server/engine/__init__.py b/llm/server/server/server/engine/__init__.py
new file mode 100644
index 000000000000..fd05a9208165
--- /dev/null
+++ b/llm/server/server/server/engine/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/llm/server/server/server/engine/config.py b/llm/server/server/server/engine/config.py
new file mode 100644
index 000000000000..6f0e1964e21f
--- /dev/null
+++ b/llm/server/server/server/engine/config.py
@@ -0,0 +1,236 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import os
+import sys
+from datetime import datetime
+
+from paddlenlp.generation import GenerationConfig
+from server.utils import model_server_logger
+
+
+class Config:
+ """
+ initial configuration
+ """
+
+ def __init__(self):
+ self.read_from_env()
+
+ def read_from_env(self):
+ """
+ get the configuration from environment
+ """
+ env = os.environ
+ self.model_dir = env.get(
+ "MODEL_DIR", "/opt/output/Serving/models")
+ if not self.model_dir:
+ raise Exception("The parameter MODEL_DIR is None.")
+ self.mp_num = int(env.get("MP_NUM", 8))
+ self.config_json_file = env.get("CONFIG_JSON_FILE", "config.json")
+ self.model_config_path = os.path.join(self.model_dir, self.config_json_file)
+ if env.get("FD_MODEL_CONFIG_PATH", None):
+ self.model_config_path = env.get("FD_MODEL_CONFIG_PATH")
+
+ # distributed config
+ self.distributed_config_path = os.path.join(self.model_dir, "rank_mapping.csv")
+ if os.getenv("DISTRIBUTED_CONFIG", None):
+ self.distributed_config_path = os.getenv("DISTRIBUTED_CONFIG")
+
+ # device config
+ self.device = env.get("DEVICE", "GPU")
+ self.device_ids = ",".join([str(i) for i in range(self.mp_num)])
+ if self.device == "GPU":
+ self.device_ids = os.getenv("CUDA_VISIBLE_DEVICES",
+ self.device_ids)
+ else:
+ raise Exception(f"unsupported device type: {self.device}")
+
+ # Triton config
+ self.max_prefill_batch = int(os.getenv("MAX_PREFILL_BATCH", 1))
+ if self.max_prefill_batch <= 0:
+ raise Exception(f"MAX_PREFILL_BATCH ({self.max_prefill_batch}) must be greater than 0")
+ self.disable_streaming = int(os.getenv("DISABLE_STREAMING", 0))
+
+ # max cached task num
+ self.max_cached_task_num = int(os.getenv("MAX_CACHED_TASK_NUM", "128"))
+ # if PUSH_MODE_HTTP_PORT is not configured, only GRPC service is enabled
+ self.push_mode_http_port = int(os.getenv("PUSH_MODE_HTTP_PORT", "-1"))
+ if self.push_mode_http_port > 0:
+ grpc_port = os.getenv("GRPC_PORT", None)
+ if grpc_port is None:
+ raise Exception("GRPC_PORT cannot be None, while PUSH_MODE_HTTP_PORT>0")
+ self.grpc_port = int(grpc_port)
+
+ # http worker num
+ self.push_mode_http_workers = int(os.getenv("PUSH_MODE_HTTP_WORKERS", "1"))
+ if self.push_mode_http_workers < 1:
+ raise Exception(f"PUSH_MODE_HTTP_WORKERS ({self.push_mode_http_workers}) must be positive")
+
+        # Paddle commit id
+ import paddle
+ self.paddle_commit_id = paddle.version.commit
+
+        # interval (in seconds) at which the health probe checks whether the engine loop is still alive
+ self.check_health_interval = int(os.getenv("CHECK_HEALTH_INTERVAL", 10))
+
+ # model config
+ self.dtype = env.get("DTYPE", "bfloat16")
+ self.block_size = int(env.get("BLOCK_SIZE", 64))
+ self.use_cache_kv_int8 = int(os.getenv("USE_CACHE_KV_INT8", 0))
+ self.use_cache_kv_int4 = int(os.getenv("USE_CACHE_KV_INT4", 0))
+
+ # infer config
+ self.max_batch_size = int(env.get("BATCH_SIZE", 50))
+ self.max_seq_len = int(env.get("MAX_SEQ_LEN", 8192))
+ self.max_dec_len = int(env.get("MAX_DEC_LEN", 1024))
+ self.enc_dec_block_num = int(os.getenv("ENC_DEC_BLOCK_NUM", 2))
+ self.block_bs = float(env.get("BLOCK_BS", 50))
+ self.block_ratio = float(os.getenv("BLOCK_RATIO", 0.75))
+ self.bad_tokens = str(env.get("BAD_TOKENS", "-1"))
+ self.first_token_id = int(os.getenv("FIRST_TOKEN_ID", 1))
+
+ # infer queue port
+ self.infer_port = int(os.getenv("INFER_QUEUE_PORT", 56666))
+
+ # whether to use custom health checker
+ self.use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))
+
+ # Check the legality of requests
+ self.seq_len_limit = int(env.get("MAX_SEQ_LEN", 8192))
+ self.dec_len_limit = int(env.get("MAX_DEC_LEN", 1024))
+
+ # warmup
+ self.use_warmup = int(os.getenv("USE_WARMUP", 0)) == 1
+
+ # uuid
+ self.shm_uuid = os.getenv("SHM_UUID", '')
+
+ # use huggingface tokenizer
+ self.use_hf_tokenizer = int(os.getenv("USE_HF_TOKENIZER", 0)) == 1
+
+ # Generation config
+ try:
+ self.generation_config = GenerationConfig.from_pretrained(self.model_dir)
+        except Exception:
+            model_server_logger.warning(
+                "Cannot find generation config, so the generation_config fields in the model config will not be used"
+            )
+ self.generation_config = None
+
+ self.read_from_config()
+ self.postprocess()
+ self.check()
+
+ def postprocess(self):
+ """
+ calculate some parameters
+ """
+ if self.block_ratio >= 1.0:
+ self.enc_dec_block_num = (self.max_dec_len + self.block_size - 1) // self.block_size
+        self.max_query_block_num = (self.max_dec_len + self.max_seq_len +
+                                    self.block_size - 1) // self.block_size
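+        # dec_token_num: decoding slots reserved per request;
+        # total_block_num: cache blocks created by the infer process; max_block_num: the block_ratio share handed to the ResourceManager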
+ self.dec_token_num = self.enc_dec_block_num * self.block_size
+ self.total_block_num = int(self.block_bs * self.max_query_block_num)
+ self.max_block_num = int(self.total_block_num * self.block_ratio)
+ model_server_logger.info(f"max_block_num:{self.max_block_num}")
+
+ def check(self):
+ """
+ check the legality of config
+ """
+ assert self.max_batch_size <= 256, (
+ "The parameter `max_batch_size` is not allowed to exceed 256, "
+ "but now it's {}.".format(self.max_batch_size)
+ )
+        assert self.seq_len_limit <= self.max_seq_len, (
+            f"The seq_len_limit should not be greater than the model's max_seq_len, "
+            f"which means the exported MAX_SEQ_LEN should be no more than "
+            f"{self.max_seq_len}, but now it's {self.seq_len_limit}."
+        )
+        assert self.dec_len_limit <= self.max_seq_len, (
+            f"The dec_len_limit should not be greater than the model's max_seq_len, "
+            f"which means the exported MAX_DEC_LEN should be no more than "
+            f"{self.max_seq_len}, but now it's {self.dec_len_limit}."
+        )
+
+ def print(self, file=None):
+ """
+ print all config
+
+ Args:
+ file (str): the path of file to save config
+ """
+ model_server_logger.info(
+ "=================== Configuration Information ===============")
+ for k, v in self.__dict__.items():
+ if k == "generation_config" and v is not None:
+ for gck, gcv in v.to_dict().items():
+ model_server_logger.info("{:<20}:{:<6}{}".format(gck, "", gcv))
+ else:
+ model_server_logger.info("{:<20}:{:<6}{}".format(k, "", v))
+ model_server_logger.info(
+ "=============================================================")
+        if file is not None:
+            now_time = datetime.now()
+            with open(file, "a") as f:
+                f.write(f"{now_time} configuration information as below,\n")
+                for k, v in self.__dict__.items():
+                    f.write("{:<20}:{:<6}{}\n".format(k, "", v))
+
+ def get_model_config(self):
+ """
+ load config file
+
+ Returns:
+ dict: the config file
+ """
+ model_config_json = json.load(open(self.model_config_path, 'r', encoding='utf-8'))
+ return model_config_json
+
+ def read_from_config(self):
+ """
+ reset model config from json file
+ """
+ from server.utils import get_logger
+ logger = get_logger("model_server", "infer_config.log")
+ config = self.get_model_config()
+
+ def reset_value(self, value_name, key, config):
+ if key in config:
+ value = config[key]
+ setattr(self, value_name, value)
+ logger.info(f"Reset parameter {value_name} = {value} from configuration.")
+
+ reset_value(self, "block_size", "infer_model_block_size", config)
+ reset_value(self, "max_seq_len", "infer_model_max_seq_len", config)
+
+ assert self.seq_len_limit <= self.max_seq_len, f"The loading model requires len(input_ids) <= {self.max_seq_len}, but now the setting MAX_SEQ_LEN={self.seq_len_limit}."
+ assert self.dec_len_limit <= self.max_seq_len, f"The loading model requires MAX_DEC_LEN <= {self.max_seq_len}, but now the setting MAX_DEC_LEN={self.dec_len_limit}."
+
+ def get_unique_name(self, name):
+ """
+ get unique name
+
+ Args:
+            name (str): the base name to which the shm uuid is appended
+ """
+ return name + f"_{self.shm_uuid}"
+
+ def __str__(self) -> str:
+ return json.dumps(self.__dict__, indent=4)
diff --git a/llm/server/server/server/engine/engine.py b/llm/server/server/server/engine/engine.py
new file mode 100644
index 000000000000..932404d9c094
--- /dev/null
+++ b/llm/server/server/server/engine/engine.py
@@ -0,0 +1,401 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import multiprocessing
+import os
+import signal
+import subprocess
+import time
+import uuid
+import weakref
+from datetime import datetime
+from multiprocessing import shared_memory
+
+import numpy as np
+from server.engine.resource_manager import ResourceManager
+from server.engine.task_queue_manager import (TaskQueueManager,
+ launch_queue_service)
+from server.engine.token_processor import TokenProcessor, WarmUpTokenProcessor
+from server.utils import model_server_logger
+
+
+class Engine(object):
+ """
+ Engine Class
+ """
+ def __init__(self, cfg, token_processor):
+ self.cfg = cfg
+ self.resource_manager = ResourceManager(self.cfg)
+ self.token_processor = token_processor
+ self.token_processor.set_resource_manager(self.resource_manager)
+ self.is_started = False
+
+ self._init_engine_flags()
+ self._finalizer = weakref.finalize(self, self._exit_sub_services)
+
+ def start(self):
+ """
+ initialize engine and start sub services
+ """
+        assert not self.is_started, "The engine is already started."
+ start_time = time.time()
+ self.queue_service = self._start_tasks_queue_service()
+ self.tasks_queue = TaskQueueManager(mp_num=self.cfg.mp_num, port=self.cfg.infer_port)
+
+ self.token_processor.tasks_queue = self.tasks_queue
+ self.infer_proc = self._start_infer_service()
+ model_server_logger.info("Waitting infer processes ready...")
+ while not self._infer_processes_ready():
+ time.sleep(1)
+ self.is_started = True
+
+ # start warmup
+ if self.cfg.use_warmup:
+ model_server_logger.info("Start warmup")
+ self._set_warmup_token_processor()
+ self.warmup()
+ self._del_warmup_token_processor()
+ model_server_logger.info("Warmup finish")
+
+ # start TokenProcessor thread
+ self.token_processor.run()
+ model_server_logger.info("Infer processes are launched with {} seconds.".format(time.time() - start_time))
+
+ def warmup(self):
+ """
+        construct test tasks to warm up the engine and avoid out-of-memory problems in the infer process
+ """
+ # get eos_token_id
+ from server.data.processor import DataProcessor
+ eos_token_ids = DataProcessor().get_eos_tokens()
+
+ # construct test tasks
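+        # two groups of dummy tasks: decode-heavy (tiny prompt, long generation) and prefill-heavy (max-length prompt, single output token)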
+ res_task = []
+ for j in range(2 * self.cfg.max_batch_size):
+ data = {
+ "input_ids": [5],
+ "req_id": j,
+ "max_dec_len": self.cfg.dec_len_limit,
+ "min_dec_len": int(self.cfg.dec_len_limit * 0.5) + 1,
+ "eos_token_ids": eos_token_ids
+ }
+ res_task.append(data)
+ for j in range(2 * self.cfg.max_prefill_batch):
+ data = {
+ "input_ids": [5] * self.cfg.seq_len_limit,
+ "req_id": j + 2 * self.cfg.max_batch_size,
+ "max_dec_len": 1,
+ "min_dec_len": 1,
+ "eos_token_ids": eos_token_ids
+ }
+ res_task.append(data)
+
+ for x in res_task:
+ while self.available_batch() == 0 or not self.insert_tasks([x]):
+ time.sleep(0.0002)
+
+ self.token_processor._is_blocking = False
+ # wait for all tasks finished
+ while not self.all_tasks_finished():
+ time.sleep(1)
+
+ def insert_tasks(self, tasks):
+ """
+ insert tasks to the engine
+
+ Args:
+ tasks: list of tasks
+
+ Returns:
+ return: True if success, False otherwise
+ """
+ if not isinstance(tasks, list):
+ tasks = [tasks]
+
+ for item in tasks:
+ item["schedule_start_time"] = datetime.now()
+
+ available_batch = np.sum(self.resource_manager.stop_flags)
+ if len(tasks) > available_batch:
+ model_server_logger.error("Inserting batch:{} exceeds the available batch:{}.".format(
+ len(tasks), available_batch))
+ model_server_logger.error("The exceeded part will be ignored!")
+ tasks = tasks[:available_batch]
+
+ for i in range(len(tasks)):
+ req_id = tasks[i]["req_id"]
+ input_token_num = len(tasks[i]["input_ids"])
+ if input_token_num >= self.cfg.max_seq_len - 1:
+ model_server_logger.warning(f"{req_id}: Input length:{input_token_num}, exceed the limits.")
+ tasks[i]["input_ids"] = tasks[i]["input_ids"][:self.cfg.max_seq_len - 1]
+ if "seq_len" in tasks[i] and "max_dec_len" not in tasks[i]:
+ tasks[i]["max_dec_len"] = tasks[i]["seq_len"]
+
+ # max_dec_len + input_token_num > MAX_SEQ_LEN
+ if input_token_num + tasks[i]["max_dec_len"] > self.cfg.max_seq_len:
+ tasks[i]["max_dec_len"] = self.cfg.max_seq_len - input_token_num
+ model_server_logger.warning("Force max_dec_len to be {} for req_id={}.".format(
+ tasks[i]["max_dec_len"], tasks[i]["req_id"]))
+
+ # min_dec_len + input_token_num > MAX_SEQ_LEN
+ if input_token_num + tasks[i]["min_dec_len"] > self.cfg.max_seq_len:
+ tasks[i]["min_dec_len"] = self.cfg.max_seq_len - input_token_num
+ model_server_logger.warning("Force min_dec_len to be {} for req_id={}.".format(
+ tasks[i]["min_dec_len"], tasks[i]["req_id"]))
+
+ tasks = self.resource_manager.allocate_resources_for_new_tasks(tasks)
+ if not tasks:
+ return False
+
+ self.token_processor.number_of_tasks += len(tasks)
+ for i in range(len(tasks)):
+ self.token_processor.number_of_input_tokens += len(tasks[i]["input_ids"])
+
+ req_ids = [t["req_id"] for t in tasks]
+ model_server_logger.info(f"Tasks are sent to engine, req_ids={req_ids}")
+ self.tasks_queue.put((tasks, self.resource_manager.real_bsz))
+ return True
+
+ def task_is_finished(self, index):
+ """
+ judge if the task is finished
+
+ Args:
+ index: task index
+
+ Returns:
+ return: True if finished, False otherwise
+ """
+ assert index < len(self.resource_manager.stop_flags)
+ return self.resource_manager.stop_flags[index]
+
+ def is_queue_empty(self):
+ """
+ judge if the queue is empty
+
+ Returns:
+ return: True if empty, False otherwise
+ """
+ return self.tasks_queue.empty()
+
+ def is_resource_sufficient(self, input_token_num):
+ """
+ judge if the resource is sufficient
+
+ Args:
+ input_token_num: input token number
+
+ Returns:
+ return: True if sufficient, False otherwise
+ """
+ return self.resource_manager.is_resource_sufficient(input_token_num)
+
+ def all_tasks_finished(self):
+ """
+ judge if all tasks are finished
+
+ Returns:
+ return: True if all finished, False otherwise
+ """
+ return np.sum(self.resource_manager.stop_flags) == len(self.resource_manager.stop_flags)
+
+ def available_batch(self):
+ """
+ available batch size of the engine
+
+ Returns:
+ return: available batch size
+ """
+ return self.resource_manager.available_batch()
+
+ def available_block_num(self):
+ """
+ available block number of the engine
+
+ Returns:
+ return: available block number
+ """
+ return self.resource_manager.availabel_block_num()
+
+ def _set_warmup_token_processor(self):
+ """
+ set token_processor for warmup
+ """
+ self.token_processor_backup = self.token_processor
+ self.token_processor = WarmUpTokenProcessor(self.cfg)
+ self.token_processor.set_resource_manager(self.resource_manager)
+ self.token_processor.tasks_queue = self.tasks_queue
+
+ # start TokenProcessor thread
+ self.token_processor.run()
+
+ def _del_warmup_token_processor(self):
+ """
+ delete token_processor for warmup
+ """
+ self.token_processor.stop()
+ del self.token_processor
+
+ # reset token_processor
+ self.token_processor = self.token_processor_backup
+ del self.token_processor_backup
+
+ def _infer_processes_ready(self):
+ """
+ judge if all infer processes are ready
+
+ Returns:
+ return: True if all ready, False otherwise
+ """
+ if np.sum(self.flag_ready_array) == self.cfg.mp_num:
+ return True
+ return False
+
+ def _clear_engine_flags(self):
+ """
+ clear engine flags
+ """
+ try:
+ self.shm_flag_ready.close()
+ self.shm_flag_ready.unlink()
+ self.shm_flag_has_block_step.close()
+ self.shm_flag_has_block_step.unlink()
+ except:
+ pass
+
+ def _init_engine_flags(self):
+ """
+ Initialize shared memory to indicate engine status
+ """
+ flag_array = np.zeros([self.cfg.mp_num], dtype=np.int32)
+ try:
+ tmp = shared_memory.SharedMemory(
+ create=False, size=flag_array.nbytes, name=self.cfg.get_unique_name("shm_flag_infer_ready")
+ )
+ tmp.close()
+ tmp.unlink()
+ except:
+ pass
+ self.shm_flag_ready = shared_memory.SharedMemory(
+ create=True, size=flag_array.nbytes, name=self.cfg.get_unique_name("shm_flag_infer_ready")
+ )
+ self.flag_ready_array = np.ndarray(
+ flag_array.shape, dtype=flag_array.dtype, buffer=self.shm_flag_ready.buf
+ )
+ self.flag_ready_array[:] = 0
+
+ # broadcast flag for engine
+ broadcast_flag_array = np.zeros([1], dtype=np.int32)
+ try:
+ tmp = shared_memory.SharedMemory(
+ create=False,
+ size=broadcast_flag_array.nbytes,
+ name=self.cfg.get_unique_name("shm_pd_infer_flag_broadcast"),
+ )
+ tmp.close()
+ tmp.unlink()
+ except:
+ pass
+ self.shm_flag_broadcast = shared_memory.SharedMemory(
+ create=True, size=broadcast_flag_array.nbytes, name=self.cfg.get_unique_name("shm_pd_infer_flag_broadcast")
+ )
+ self.flag_broadcast_array = np.ndarray(
+ broadcast_flag_array.shape,
+ dtype=broadcast_flag_array.dtype,
+ buffer=self.shm_flag_broadcast.buf,
+ )
+ self.flag_broadcast_array[0] = 0
+
+ has_block_step_flag_array = np.zeros([1], dtype=np.int32)
+ try:
+ tmp = shared_memory.SharedMemory(
+ create=False,
+ size=has_block_step_flag_array.nbytes,
+ name=self.cfg.get_unique_name("shm_flag_has_block_step"))
+ tmp.close()
+ tmp.unlink()
+ except:
+ pass
+ self.shm_flag_has_block_step = shared_memory.SharedMemory(
+ create=True,
+ size=has_block_step_flag_array.nbytes,
+ name=self.cfg.get_unique_name("shm_flag_has_block_step"))
+ self.flag_has_block_step_array = np.ndarray(
+ has_block_step_flag_array.shape,
+ dtype=has_block_step_flag_array.dtype,
+ buffer=self.shm_flag_has_block_step.buf)
+ self.flag_has_block_step_array[:] = 0
+
+ def _exit_sub_services(self):
+ """
+ exit sub services
+ """
+ if hasattr(self, "queue_service") and self.queue_service is not None:
+ self.queue_service.terminate()
+ self.queue_service.join()
+ if hasattr(self, "infer_proc") and self.infer_proc is not None:
+ os.killpg(self.infer_proc.pid, signal.SIGTERM)
+
+ def _start_tasks_queue_service(self):
+ """
+ start tasks queue service
+
+ Returns:
+ p: process handle
+ """
+ p = multiprocessing.Process(target=launch_queue_service, args=(self.cfg.infer_port, self.cfg.mp_num))
+ p.start()
+ time.sleep(0.3)
+ if p.is_alive():
+ model_server_logger.info("start tasks queue service successfully")
+ else:
+ error_msg = "Failed to start tasks queue service, please check " \
+ "the log/task_queue_manager.log for details"
+            model_server_logger.error(error_msg)
+ raise Exception(error_msg)
+ return p
+
+ def _start_gpu_infer_service(self):
+ """
+ start gpu infer service
+
+ Returns:
+ p: process handle
+ """
+ current_file_path = os.path.abspath(__file__)
+ current_dir_path = os.path.split(current_file_path)[0]
+ pd_cmd = "python3 -m paddle.distributed.launch "
+ py_script = os.path.join(current_dir_path, "infer.py")
+
+ arguments = (f" --devices {self.cfg.device_ids} {py_script} --model_dir {self.cfg.model_dir}"
+ f" --max_batch_size {self.cfg.max_batch_size} --max_seq_len {self.cfg.max_seq_len}"
+ f" --max_dec_len {self.cfg.max_dec_len}"
+ f" --max_block_num {self.cfg.total_block_num} --block_size {self.cfg.block_size}"
+ f" --use_cache_kv_int8 {self.cfg.use_cache_kv_int8}"
+ f" --enc_dec_block_num {self.cfg.enc_dec_block_num}"
+ f" --block_ratio {self.cfg.block_ratio} --dtype {self.cfg.dtype}")
+ pd_cmd = pd_cmd + arguments + " >log/launch_infer.log 2>&1"
+ model_server_logger.info("Launch infer service command: {}".format(pd_cmd))
+ p = subprocess.Popen(
+ pd_cmd,
+ shell=True,
+ preexec_fn=os.setsid,
+ )
+ return p
+
+ def _start_infer_service(self):
+ """
+ start infer service
+ """
+ return self._start_gpu_infer_service()
diff --git a/llm/server/server/server/engine/infer.py b/llm/server/server/server/engine/infer.py
new file mode 100644
index 000000000000..63e87e425058
--- /dev/null
+++ b/llm/server/server/server/engine/infer.py
@@ -0,0 +1,591 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import copy
+import json
+import os
+import sys
+import time
+from concurrent.futures import ThreadPoolExecutor
+from multiprocessing import shared_memory
+
+import numpy as np
+import paddle
+import paddle.distributed as dist
+import paddle.distributed.fleet as fleet
+from paddlenlp.trl.llm_utils import get_rotary_position_embedding
+from paddlenlp_ops import step_paddle
+from server.data.processor import DataProcessor
+from server.engine.config import Config
+from server.utils import get_logger
+from task_queue_manager import TaskQueueManager
+
+File_Path = os.path.realpath(sys.argv[0])
+Dir_Path = os.path.dirname(File_Path)
+logger = get_logger("infer_server", "infer.log")
+
+
+class ModelRunner:
+ def __init__(self, args):
+ self.args = args
+
+        # 2**63 - 2 (upper bound for infer_seed)
+ self.MAX_INFER_SEED = 9223372036854775806
+
+ self.config = Config()
+ self.model_cfg = self.config.get_model_config()
+ self.format_print_configuration()
+
+ self.args.num_layers = self.get_value(self.model_cfg, ["num_hidden_layers", "num_layers"])
+ self.args.num_attention_heads = self.get_value(self.model_cfg, ["num_attention_heads", "n_head"])
+ self.args.hidden_size = self.model_cfg["hidden_size"]
+
+ self.nranks = dist.get_world_size()
+ self.init_dist_env()
+ self.rank = fleet.worker_index()
+
+ self.load_model_init_val()
+
+ self.share_inputs = {}
+ self.cache_kvs = {}
+ self.init_inputs()
+
+ self.infer_queue = TaskQueueManager(rank=self.rank, mp_num=self.nranks, port=self.config.infer_port)
+
+ model_rank_path = os.path.join(self.args.model_dir, f"rank_{self.rank}")
+ if not os.path.exists(model_rank_path):
+ model_rank_path = self.args.model_dir
+
+ self.infer_engine = InferenceEngine(model_dir=model_rank_path,
+ share_inputs=self.share_inputs,
+ cache_kvs=self.cache_kvs,
+ config=self.config,
+ mp_degree=self.nranks
+ )
+
+ def read_model_config(self):
+ """
+ load model config file from json file
+
+ Returns:
+ model_config_json: dict, model config file
+ """
+ model_config_json = json.load(open(self.config_file, 'r', encoding='utf-8'))
+ return model_config_json
+
+ def get_value(self, cfg, names):
+ """
+ get value from config file by key names
+ """
+ if not isinstance(names, list):
+ names = [names]
+        for name in names:
+            if name in cfg:
+                return cfg[name]
+ raise Exception(
+ "Cannot find any one of key in {} in configuration file.".format(
+ names))
+
+ def format_print_configuration(self):
+ """
+ print model config
+ """
+ logger.info("=============== Model Information ==============")
+ for k, v in self.model_cfg.items():
+ logger.info("{:<20}:{:<6}{}".format(k, "", v))
+ logger.info("=============== Service Configuration ===============")
+ for k, v in vars(self.args).items():
+ logger.info("{:<20}:{:<6}{}".format(k, "", v))
+ logger.info("=====================================================\n")
+
+ def load_model_init_val(self):
+ """
+ initialize model config from config file
+ """
+ self.top_p = self.model_cfg.get("top_p", 0.0)
+ self.temperature = self.model_cfg.get("temperature", 1.0)
+ self.rope_theta = self.model_cfg.get('rope_theta', 10000.0)
+ self.rope_scaling = self.model_cfg.get('rope_scaling', None)
+ self.penalty_score = self.model_cfg.get('penalty_score', 1.0)
+ self.frequency_score = self.model_cfg.get('frequency_score', 0.0)
+ self.presence_score = self.model_cfg.get('presence_score', 0.0)
+ self.min_length = self.model_cfg.get('min_length', 1)
+ self.max_length = self.model_cfg.get('max_length', 1024)
+
+ data_processor = DataProcessor()
+ # reserve an eos token for request
+ self.eos_tokens_lens = data_processor.get_eos_tokens_lens() + 1
+ self.pad_token_id = data_processor.get_pad_id()
+
+ def init_dist_env(self, seed=20):
+ """
+ init distributed env
+ """
+ strategy = fleet.DistributedStrategy()
+
+ strategy.hybrid_configs = {
+ "dp_degree": 1,
+ "mp_degree": self.nranks,
+ "pp_degree": 1,
+ "sharding_degree": 1,
+ }
+
+ # Set control in tensor parallel
+ strategy.tensor_parallel_configs = {"tensor_init_seed": seed}
+ fleet.init(is_collective=True, strategy=strategy)
+
+ def init_inputs(self):
+ # init all inputs
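+        # KV heads per rank: prefer num_key_value_heads (GQA/MQA models), otherwise fall back to the full attention head count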
+ if "num_key_value_heads" in self.model_cfg and \
+ self.model_cfg["num_key_value_heads"] is not None and \
+ int(self.model_cfg["num_key_value_heads"]) > 0:
+ kv_num_head = int(self.model_cfg["num_key_value_heads"]) // self.nranks
+ else:
+ kv_num_head = self.args.num_attention_heads // self.nranks
+
+ for i in range(self.args.num_layers):
+ if not self.args.use_cache_kv_int8:
+ cache_type = self.args.dtype
+ else:
+ cache_type = "uint8"
+
+ self.cache_kvs["key_caches_{}".format(i)] = paddle.full(shape=[
+ self.args.max_block_num, kv_num_head,
+ self.args.block_size, self.args.hidden_size // self.args.num_attention_heads
+ ], fill_value=0, dtype=cache_type)
+ self.cache_kvs["value_caches_{}".format(i)] = paddle.full(shape=[
+ self.args.max_block_num, kv_num_head,
+ self.args.block_size, self.args.hidden_size // self.args.num_attention_heads
+ ], fill_value=0, dtype=cache_type)
+
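+        # per-request block table: enough blocks for a full-length prompt plus the blocks reserved for decoding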
+ pre_max_block_num = (self.args.max_seq_len + self.args.block_size - 1) // self.args.block_size + self.args.enc_dec_block_num
+ self.share_inputs["block_tables"] = paddle.full(
+ shape=[self.args.max_batch_size, pre_max_block_num], fill_value=-1, dtype="int32")
+
+ self.share_inputs['pre_ids'] = paddle.to_tensor(
+ np.full((self.args.max_batch_size, self.args.max_dec_len), -1, dtype='int64'))
+
+ tmp_position_ids = paddle.arange(self.args.max_seq_len).reshape((1, -1))
+ self.share_inputs['rope_emb'] = get_rotary_position_embedding(tmp_position_ids,
+ self.args.hidden_size // self.args.num_attention_heads,
+ self.rope_theta, self.rope_scaling)
+ self.share_inputs['input_ids'] = paddle.full(
+ shape=[self.args.max_batch_size, self.args.max_seq_len],
+ fill_value=self.pad_token_id, dtype='int64')
+ self.share_inputs['top_p'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=self.top_p, dtype="float32")
+ self.share_inputs['temperature'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=self.temperature, dtype="float32")
+ self.share_inputs['eos_token_id'] = paddle.to_tensor(
+ np.zeros((self.eos_tokens_lens, 1)).reshape(-1, 1).astype("int64"))
+ self.share_inputs['penalty_score'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=self.penalty_score, dtype="float32")
+ self.share_inputs['frequency_score'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=self.frequency_score, dtype="float32")
+ self.share_inputs['presence_score'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=self.presence_score, dtype="float32")
+ self.share_inputs['seq_lens_this_time'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+ self.share_inputs['seq_lens_encoder'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+ self.share_inputs['step_seq_lens_encoder'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+ self.share_inputs['seq_lens_decoder'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int32")
+ self.share_inputs['step_idx'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int64")
+ self.share_inputs['min_length'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=self.min_length, dtype="int64")
+ self.share_inputs['max_length'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=self.max_length, dtype="int64")
+ self.share_inputs['not_need_stop'] = paddle.full(
+ shape=[1], fill_value=False, dtype="bool")
+ self.share_inputs['stop_flags'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=True, dtype="bool")
+ self.share_inputs['stop_nums'] = paddle.full(
+ shape=[1], fill_value=self.args.max_batch_size, dtype="int64")
+ self.share_inputs['bad_tokens'] = paddle.full(
+ shape=[1], fill_value=-1, dtype="int64")
+ self.share_inputs['next_tokens'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=-1, dtype="int64")
+ self.share_inputs['is_block_step'] = paddle.full(
+ shape=[self.args.max_batch_size], fill_value=False, dtype="bool")
+ self.share_inputs['encoder_block_lens'] = paddle.full(
+ shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
+ self.share_inputs['step_block_list'] = paddle.full(
+ shape=[self.args.max_batch_size], fill_value=-1, dtype="int32")
+ self.share_inputs['step_lens'] = paddle.full(shape=[1], fill_value=0, dtype="int32")
+ self.share_inputs['recover_block_list'] = paddle.full(
+ shape=[self.args.max_batch_size], fill_value=-1, dtype="int32")
+ self.share_inputs['recover_lens'] = paddle.full(
+ shape=[1], fill_value=0, dtype="int32")
+ self.share_inputs['need_block_list'] = paddle.full(
+ shape=[self.args.max_batch_size], fill_value=-1, dtype="int32")
+ self.share_inputs['need_block_len'] = paddle.full(
+ shape=[1], fill_value=0, dtype="int32")
+ self.share_inputs['used_list_len'] = paddle.full(
+ shape=[self.args.max_batch_size], fill_value=0, dtype="int32")
+ self.share_inputs['infer_seed'] = paddle.full(
+ shape=[self.args.max_batch_size, 1], fill_value=0, dtype="int64")
+ free_list = list(range(int(self.args.max_block_num * self.args.block_ratio)))
+ self.free_list_len = len(free_list)
+ self.share_inputs['free_list'] = paddle.to_tensor(free_list, dtype="int32")
+ self.share_inputs['free_list_len'] = paddle.full(
+ shape=[1], fill_value=self.free_list_len, dtype="int32")
+
+ def dy_input_preprocess(self, tasks):
+ """
+ dynamic insertion
+ """
+ for i in range(len(tasks)):
+ task = tasks[i]
+ idx = task['idx']
+ length = len(task['input_ids'])
+ self.share_inputs['input_ids'][idx:idx + 1, :length] = np.array(task['input_ids'])
+ if len(task['eos_token_ids']) < self.eos_tokens_lens:
+ task['eos_token_ids'].append(task['eos_token_ids'][0])
+ self.share_inputs['eos_token_id'][:] = np.array(task['eos_token_ids'], dtype="int64").reshape(-1, 1)
+ self.share_inputs['pre_ids'][idx:idx + 1] = -1
+ self.share_inputs['top_p'][idx:idx + 1] = task.get('topp', 0.7)
+ self.share_inputs['temperature'][idx:idx + 1] = task.get('temperature', 0.95)
+ self.share_inputs['penalty_score'][idx:idx + 1] = task.get('penalty_score', 1.0)
+ self.share_inputs['frequency_score'][idx:idx + 1] = task.get('frequency_score', 0.0)
+ self.share_inputs['presence_score'][idx:idx + 1] = task.get('presence_score', 0.0)
+ self.share_inputs['seq_lens_this_time'][idx:idx + 1] = length
+ self.share_inputs['step_seq_lens_encoder'][idx:idx + 1] = length
+ self.share_inputs['seq_lens_encoder'][idx:idx + 1] = length
+ self.share_inputs['seq_lens_decoder'][idx:idx + 1] = 0
+ self.share_inputs['step_idx'][idx:idx + 1] = 0
+ self.share_inputs['min_length'][idx:idx + 1] = task.get('min_dec_len', 1)
+ if "max_dec_len" in task:
+ max_dec_len = task['max_dec_len']
+ elif "seq_len" in task:
+ max_dec_len = task['seq_len']
+ else:
+ max_dec_len = self.args.max_dec_len
+ self.share_inputs['max_length'][idx:idx + 1] = max_dec_len
+ self.share_inputs['stop_flags'][idx:idx + 1] = False
+
+ if "infer_seed" in task:
+ self.share_inputs['infer_seed'][idx:idx + 1] = task['infer_seed']
+
+ encoder_block_num = len(task['block_tables'])
+ self.share_inputs['encoder_block_lens'][idx:idx + 1] = encoder_block_num
+ self.share_inputs["block_tables"][idx:idx + 1, :] = -1
+ self.share_inputs["block_tables"][idx:idx + 1, :encoder_block_num] = np.array(
+ task['block_tables'], dtype="int32")
+
+ def step_cuda(self, seq_lens_this_time):
+ """
+ step cuda
+ """
+ step_paddle(self.share_inputs['stop_flags'], seq_lens_this_time,
+ self.share_inputs['step_seq_lens_encoder'],
+ self.share_inputs['seq_lens_encoder'],
+ self.share_inputs['seq_lens_decoder'], self.share_inputs["block_tables"],
+ self.share_inputs['encoder_block_lens'],
+ self.share_inputs["is_block_step"], self.share_inputs['step_block_list'],
+ self.share_inputs['step_lens'], self.share_inputs['recover_block_list'],
+ self.share_inputs['recover_lens'], self.share_inputs['need_block_list'],
+ self.share_inputs['need_block_len'], self.share_inputs['used_list_len'],
+ self.share_inputs['free_list'], self.share_inputs['free_list_len'],
+ self.share_inputs['input_ids'], self.share_inputs['pre_ids'],
+ self.share_inputs['step_idx'], self.share_inputs['next_tokens'],
+ self.args.block_size, self.args.enc_dec_block_num, self.args.first_token_id)
+
+ def initialize_engine_ready_check_flag(self):
+ """
+ initialize engine ready flag in shared memory
+
+ Returns:
+ shm_engine_ready_check_flag: engine ready flag
+ engine_ready_check_flag_array: engine ready flag array
+ """
+ engine_ready_check_flag = np.zeros([1], dtype=np.int32)
+ shm_engine_ready_check_flag = shared_memory.SharedMemory(
+ name=self.config.get_unique_name("engine_ready_check_flag"))
+ engine_ready_check_flag_array = np.ndarray(engine_ready_check_flag.shape,
+ dtype=engine_ready_check_flag.dtype,
+ buffer=shm_engine_ready_check_flag.buf)
+ return shm_engine_ready_check_flag, engine_ready_check_flag_array
+
+ def initialize_engine_live_flag(self):
+ """
+ initialize infer live flag in shared memory
+
+ Returns:
+ infer_live_flag_shm: infer live flag
+ """
+ infer_live_flag_shm = shared_memory.SharedMemory(create=True,
+ size=1,
+ name=self.config.get_unique_name("shm_flag_infer_{}_live".format(self.rank)))
+ return infer_live_flag_shm
+
+ def initialize_engine_healthy_recorded_time_flag(self):
+ """
+ initialize engine healthy recorded time flag in shared memory
+
+ Returns:
+ shm_engine_healthy_recorded_time: engine healthy recorded time flag
+ """
+ engine_healthy_recorded_time = np.zeros([1], dtype=float)
+ shm_engine_healthy_recorded_time = shared_memory.SharedMemory(
+ name=self.config.get_unique_name("engine_healthy_recorded_time"))
+ engine_healthy_recorded_time_array = np.ndarray(engine_healthy_recorded_time.shape,
+ dtype=engine_healthy_recorded_time.dtype,
+ buffer=shm_engine_healthy_recorded_time.buf)
+ return shm_engine_healthy_recorded_time, engine_healthy_recorded_time_array
+
+ def run(self):
+ """
+ run infer
+ """
+ flag_array = np.zeros([1], dtype=np.int32)
+ shm_flag_broadcast = shared_memory.SharedMemory(
+ name=self.config.get_unique_name("shm_pd_infer_flag_broadcast"))
+ flag_broadcast_array = np.ndarray(flag_array.shape,
+ dtype=flag_array.dtype,
+ buffer=shm_flag_broadcast.buf)
+
+ flag_array = np.zeros([self.nranks], dtype=np.int32)
+ shm_flag_ready = shared_memory.SharedMemory(name=self.config.get_unique_name("shm_flag_infer_ready"))
+ flag_ready_array = np.ndarray(flag_array.shape,
+ dtype=flag_array.dtype,
+ buffer=shm_flag_ready.buf)
+ flag_ready_array[self.rank] = 1
+
+ flag_array = np.zeros([1], dtype=np.int32)
+ shm_flag_has_block_step = shared_memory.SharedMemory(name=self.config.get_unique_name("shm_flag_has_block_step"))
+ flag_has_block_step_array = np.ndarray(flag_array.shape,
+ dtype=flag_array.dtype,
+ buffer=shm_flag_has_block_step.buf)
+
+ use_custom_health_checker = self.config.use_custom_health_checker
+ if use_custom_health_checker:
+ shm_engine_ready_check_flag_array, engine_ready_check_flag_array = self.initialize_engine_ready_check_flag()
+ engine_ready_check_flag_array[0] = 1
+ shm_engine_healthy_recorded_time_array, engine_healthy_recorded_time_array = self.initialize_engine_healthy_recorded_time_flag()
+ engine_healthy_recorded_time_array[0] = time.time()
+ infer_live_flag_shm = self.initialize_engine_live_flag()
+ infer_seed_increment = paddle.full(shape=[self.args.max_batch_size, 1],
+ fill_value=4,
+ dtype="int64")
+ thread_executor = ThreadPoolExecutor(max_workers=1)
+ seq_lens_this_time = None
+ real_bsz = None
+
+ while True:
+ if use_custom_health_checker:
+ engine_healthy_recorded_time_array[0] = time.time()
+
+ if self.rank == 0:
+ if not self.infer_queue.empty():
+ flag_broadcast_array[0] = 1
+
+ if self.nranks > 1:
+ paddle.distributed.barrier()
+
+ if flag_broadcast_array[0] == 1:
+ logger.info(f'rank: {self.rank} start to get')
+ if seq_lens_this_time is not None:
+ self.share_inputs["seq_lens_this_time"][:real_bsz] = seq_lens_this_time
+
+ tasks, read_finish = self.infer_queue.get()
+ if read_finish:
+ flag_broadcast_array[0] = 0
+
+ req_dicts = []
+ for req_dict, bsz in tasks:
+ real_bsz = int(bsz)
+ req_dicts.extend(req_dict)
+ logger.info(
+ f'rank: {self.rank}, real_bsz: {real_bsz}, query_num: {len(req_dicts)}'
+ )
+
+ self.dy_input_preprocess(req_dicts)
+ seq_lens_this_time = copy.deepcopy(
+ self.share_inputs['seq_lens_this_time'][:real_bsz])
+ self.infer_engine.seq_lens_handle.share_external_data(
+ seq_lens_this_time)
+ self.share_inputs['not_need_stop'][0] = True
+
+ if not self.share_inputs['not_need_stop']:
+ if self.nranks > 1:
+ paddle.distributed.barrier()
+
+ time.sleep(0.001)
+ continue
+
+ self.infer_engine.predictor.run()
+ self.share_inputs['infer_seed'].add_(infer_seed_increment)
+ self.share_inputs['infer_seed'][:] %= self.MAX_INFER_SEED
+ if self.free_list_len > 0:
+ self.step_cuda(seq_lens_this_time)
+
+
+class InferenceEngine(object):
+ """
+ Model Parallel Inference Engine
+
+ Args:
+ model_dir (string): root directory of inference model
+ mp_degree (int): model parallel size
+ """
+ def __init__(self, model_dir, share_inputs, cache_kvs, config, mp_degree=1):
+ self.config = config
+ self.model_dir = model_dir
+ self.mp_degree = mp_degree
+
+ self.share_inputs = share_inputs
+ self.cache_kvs = cache_kvs
+
+ if mp_degree == 1:
+ self.nranks = 1
+ self.rank = 0
+ else:
+ self.nranks = fleet.worker_num()
+ self.rank = fleet.worker_index()
+
+ self._init_predictor()
+ self.share_data()
+
+ def _init_predictor(self):
+ """
+ predictor init
+ """
+ device_id = self.rank % 8
+        self.model_file = os.path.join(self.model_dir, "model.pdmodel")
+        self.param_file = os.path.join(self.model_dir, "model.pdiparams")
+ config = paddle.inference.Config(self.model_file, self.param_file)
+
+ config.switch_ir_optim(False)
+ config.enable_use_gpu(100, device_id)
+
+ # distributed config
+ if self.mp_degree > 1:
+ trainer_endpoints = fleet.worker_endpoints()
+ current_endpoint = trainer_endpoints[self.rank]
+ dist_config = config.dist_config()
+ dist_config.set_ranks(self.nranks, self.rank)
+ dist_config.set_endpoints(trainer_endpoints, current_endpoint)
+ dist_config.enable_dist_model(True)
+            if self.config.distributed_config_path:
+                dist_config.set_comm_init_config(self.config.distributed_config_path)
+            else:
+                raise Exception("Please set the DISTRIBUTED_CONFIG env variable to a rank mapping file.")
+
+ config.set_dist_config(dist_config)
+ self.predictor = paddle.inference.create_predictor(config)
+ self.input_names = self.predictor.get_input_names()
+ self.seq_lens_handle = self.predictor.get_input_handle('seq_lens_this_time')
+
+ def share_data(self):
+ """
+ share data
+ """
+ for name in self.input_names:
+ if "caches" in name:
+ input_tensor = self.predictor.get_input_handle(name)
+ input_tensor.share_external_data(self.cache_kvs[name])
+ continue
+ if "seq_lens_this_time" in name:
+ continue
+ input_tensor = self.predictor.get_input_handle(name)
+ input_tensor.share_external_data(self.share_inputs[name])
+
+ def predict(self, real_bsz):
+ """
+ predict
+ """
+ seq_lens_this_time = copy.deepcopy(
+ self.share_inputs['seq_lens_this_time'][:real_bsz])
+ self.seq_lens_handle.share_external_data(seq_lens_this_time)
+ self.share_inputs['not_need_stop'][0] = True
+ while self.share_inputs['not_need_stop']:
+ self.predictor.run()
+ self.share_inputs["seq_lens_this_time"][:real_bsz] = seq_lens_this_time
+
+
+def parse_args():
+ """
+ parse args from command line
+ """
+ parser = argparse.ArgumentParser("Deploy LLM Inference")
+ parser.add_argument('-m',
+ '--model_dir',
+ type=str,
+ default='./output',
+ help='model dir')
+ parser.add_argument('-mp',
+ '--mp_degree',
+ type=int,
+ default=1,
+ help='mp degree')
+ parser.add_argument('-mbs',
+ '--max_batch_size',
+ type=int,
+ default=34,
+ help='max batch size')
+ parser.add_argument('--max_block_num', type=int, default=2000)
+ parser.add_argument("--block_size", type=int, default=128)
+ parser.add_argument('--max_seq_len',
+ type=int,
+ default=3072,
+ help='max_seq_len')
+ parser.add_argument('--max_dec_len',
+ type=int,
+ default=1024,
+ help='max_dec_len')
+ parser.add_argument('--use_cache_kv_int8',
+ type=int,
+ default=0,
+ help='use cache kv int8')
+ parser.add_argument('--dtype',
+ type=str,
+ default="bfloat16",
+ help='input dtype')
+ parser.add_argument('--enc_dec_block_num',
+ type=int,
+ default=1,
+ help="encoder's decoder num")
+ parser.add_argument('--block_ratio',
+ type=float,
+ default=0.7,
+ help="block ratio")
+ parser.add_argument('--first_token_id',
+ type=int,
+ default=1,
+ help="first token id")
+ args = parser.parse_args()
+ return args
+
+
+def main():
+ """
+ start model runner
+ """
+ args = parse_args()
+ model_runner = ModelRunner(args)
+ model_runner.run()
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm/server/server/server/engine/resource_manager.py b/llm/server/server/server/engine/resource_manager.py
new file mode 100644
index 000000000000..148e610b287b
--- /dev/null
+++ b/llm/server/server/server/engine/resource_manager.py
@@ -0,0 +1,241 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import copy
+import os
+import random
+import threading
+import time
+
+import numpy as np
+from server.utils import model_server_logger
+
+
+class ResourceManager(object):
+ """
+ record and allocate resources for the engine
+ """
+ def __init__(self, cfg):
+ self.cfg = cfg
+ self.stop_flags = [True] * cfg.max_batch_size
+ self.free_list = list(range(cfg.max_block_num - 1, -1, -1))
+ self.tasks_list = [None] * self.cfg.max_batch_size
+ # current batch status of the engine
+ self.real_bsz = 0
+ model_server_logger.info(f"{self.info()}")
+
+ def get_required_block_number(self, input_token_num):
+ """
+        calculate the number of blocks required for a request
+
+ Args:
+ input_token_num (int): input token number
+
+ Returns:
+ int: block number
+ """
+ block_num = (input_token_num + self.cfg.block_size - 1 + self.cfg.dec_token_num) // self.cfg.block_size
+ return block_num
+
+ def get_encoder_block_number(self, input_token_num):
+ """
+ get the number of blocks for the encoder
+
+ Args:
+ input_token_num (int): input token number
+
+ Returns:
+ int: encoder block number
+ """
+ enc_block_num = (input_token_num + self.cfg.block_size - 1) // self.cfg.block_size
+ return enc_block_num
+
+ def get_decoder_block_number(self):
+ """
+ get the number of blocks for the decoder
+
+ Returns:
+ int: decoder block number
+ """
+ return (self.cfg.dec_token_num + self.cfg.block_size - 1) // self.cfg.block_size
+
+ def total_block_number(self):
+ """
+        the number of pre-allocated blocks at service startup
+
+ Returns:
+ int: total block number
+ """
+ return self.cfg.max_block_num
+
+ def _get_block_tables(self, input_token_num, required_type="all"):
+ """
+ allocate memory resources
+
+ Args:
+ input_token_num (int): input token number
+ required_type (str): required type
+
+ Returns:
+ list: block list
+ """
+ if required_type == "all":
+ block_num = self.get_required_block_number(input_token_num)
+ elif required_type == "encoder":
+ block_num = self.get_encoder_block_number(input_token_num)
+ elif required_type == "decoder":
+ block_num = self.get_decoder_block_number()
+ else:
+ raise ValueError('unknown required type')
+ block_num = min(block_num, self.cfg.max_query_block_num)
+ block_list = list()
+ if block_num > len(self.free_list):
+ model_server_logger.error("block_num:{0} > free_list len:{1}".format(block_num, len(self.free_list)))
+ return block_list
+ for _ in range(block_num):
+ used_block_id = self.free_list.pop()
+ block_list.append(used_block_id)
+ model_server_logger.info(f"dispatch {len(block_list)} blocks.")
+ return block_list
+
+ def _recycle_block_tables(self, block_tables):
+ """
+ Recycling memory resource blocks
+
+ Args:
+ block_tables (list): block list
+ """
+ ori_number = len(self.free_list)
+ self.free_list.extend(block_tables)
+ cur_number = len(self.free_list)
+ model_server_logger.info(f"recycle {cur_number - ori_number} blocks.")
+
+ def available_batch(self):
+ """
+ available batch size for engine
+
+ Returns:
+ int: available batch size
+ """
+ return np.sum(self.stop_flags)
+
+ def availabel_block_num(self):
+ """
+        available block number for the engine
+
+ Returns:
+ int: available block size
+ """
+ return len(self.free_list)
+
+ def is_resource_sufficient(self, input_token_num):
+ """
+ check current available resources meet the new requirements
+
+ Args:
+ input_token_num (int): input token number
+
+ Returns:
+ bool: whether current available resources meet the new requirements
+ """
+ if self.available_batch() < 1:
+ return False
+ block_num = self.get_required_block_number(input_token_num)
+ if block_num > self.availabel_block_num():
+ return False
+ return True
+
+ def allocate_resources_for_new_tasks(self, tasks):
+ """
+ allocate resources for new tasks
+
+ Args:
+ tasks (list): task list
+
+ Returns:
+ list: processed task list
+ """
+
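+        # scan batch slots from left to right and place each task into the first free slot (stop_flags[slot] is True)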
+ allocated_position = 0
+ processing_task_index = 0
+ processed_tasks = list()
+ while allocated_position < self.cfg.max_batch_size:
+ if processing_task_index >= len(tasks):
+ break
+
+ if len(tasks[processing_task_index]["input_ids"]) > self.cfg.max_seq_len:
+ model_server_logger.error("req_id: {0} input_ids len:{1} > {2}".format(
+ tasks[
+ processing_task_index]["req_id"], len(tasks[
+ processing_task_index]["input_ids"]), self.cfg.max_seq_len
+ ))
+ processing_task_index += 1
+ continue
+
+ can_insert = False
+ while allocated_position + 1 <= self.cfg.max_batch_size:
+ if sum(self.stop_flags[allocated_position : allocated_position + 1]) == 1:
+ can_insert = True
+ break
+ allocated_position += 1
+ if can_insert:
+ if self.stop_flags[allocated_position]:
+ task = copy.deepcopy(tasks[processing_task_index])
+
+ if not isinstance(task["eos_token_ids"], list):
+ task["eos_token_ids"] = [task["eos_token_ids"]]
+
+ if "infer_seed" in task and task["infer_seed"]:
+ task["infer_seed"] = int(task["infer_seed"])
+ else:
+ task["infer_seed"] = random.randint(0, 9223372036854775807)
+ task["idx"] = allocated_position
+ task["block_tables"] = self._get_block_tables(len(task["input_ids"]))
+ if not task["block_tables"]:
+ model_server_logger.error("req_id: {0} block_tables is empty".format(task["req_id"]))
+ continue
+
+ processed_tasks.append(task)
+ self.stop_flags[allocated_position] = False
+ task["inference_start_time"] = time.time()
+ task["inference_time_cost"] = -1.0
+ task["tokens_all_num"] = int(0)
+ self.tasks_list[allocated_position] = task
+ model_server_logger.info(f"allocate req_id: {task['req_id']}, "
+ f"allocated_position:{allocated_position}, input_ids_length: {len(task['input_ids'])}")
+ allocated_position += 1
+ processing_task_index += 1
+
+        # record the real batch size currently being used by the engine
+ for i in range(self.cfg.max_batch_size - 1, -1, -1):
+ if not self.stop_flags[i]:
+ self.real_bsz = i + 1
+ break
+
+ model_server_logger.info("in num:{0} new task num:{1} real_bsz is:{2}".format(
+ len(tasks), len(processed_tasks), self.real_bsz))
+ model_server_logger.info(f"{self.info()}")
+ return processed_tasks
+
+ def info(self):
+ """
+ get resource manager info
+
+ Returns:
+ str: resource manager info
+ """
+ info = f"ResourceManager info, " \
+ f"total_block_number: {self.total_block_number()}, total_batch_number: {len(self.stop_flags)}, " \
+ f"availabel_block_num: {self.availabel_block_num()}, available_batch: {self.available_batch()}"
+ return info
diff --git a/llm/server/server/server/engine/task_queue_manager.py b/llm/server/server/server/engine/task_queue_manager.py
new file mode 100644
index 000000000000..475365d47fba
--- /dev/null
+++ b/llm/server/server/server/engine/task_queue_manager.py
@@ -0,0 +1,165 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import threading
+import time
+from multiprocessing.managers import (AcquirerProxy, BaseManager, ListProxy,
+ Value, ValueProxy)
+from queue import Queue
+
+from server.utils import get_logger
+
+logger = get_logger("infer_server", "task_queue_manager.log")
+
+
+class QueueManager(BaseManager):
+ """
+ base class for queue manager
+ """
+
+ pass
+
+
+class TaskQueueManager(object):
+ """
+ task queue manager
+ """
+
+ def __init__(self, rank=0, mp_num=8, port=56666):
+ """
+        connect to the task queue service and register the shared proxy objects
+ """
+ self.max_get_num = int(os.getenv("ENGINE_MAX_NEED_NUM", 0))
+ QueueManager.register('get_list')
+ QueueManager.register('get_value')
+ QueueManager.register('get_lock')
+ QueueManager.register('get_barrier1')
+ QueueManager.register('get_barrier2')
+ QueueManager.register('get_queue')
+
+ self.client_manager = QueueManager(address=('127.0.0.1', port),
+ authkey=b'infer_queue'
+ )
+ self.client_manager.connect()
+ self.list = self.client_manager.get_list()
+ self.value = self.client_manager.get_value()
+ self.lock = self.client_manager.get_lock()
+ self.barrier1 = self.client_manager.get_barrier1()
+ self.barrier2 = self.client_manager.get_barrier2()
+ self.queue = self.client_manager.get_queue()
+ self.mp_num = mp_num
+ self.rank = rank
+ self.position = 1 << rank
+ self.total_num = (1 << self.mp_num) - 1
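+        # bitmask protocol: each rank owns bit (1 << rank) in the shared value;
+        # a rank sets its bit after reading the current task list, and once all
+        # mp_num bits are set every rank has consumed the list, so it is cleared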
+ logger.info(f"init task queue manager success, rank: {rank}")
+
+ def empty(self):
+ """
+        check whether the shared task list is empty
+
+ Returns:
+ bool: True if the queue is empty, otherwise False
+ """
+ try:
+ return len(self.list) == 0
+ except Exception as e:
+ logger.error(f"empty function meets error: {e}")
+ raise e
+
+ def put(self, item):
+ """
+ put item to queue
+
+ Args:
+ item (any): the item to put into queue
+ """
+ self.lock.acquire()
+ if 0 < self.value.get() < self.total_num:
+ self.lock.release()
+ while 0 < self.value.get() < self.total_num:
+ time.sleep(0.001)
+ logger.info("put item to queue wait finish")
+ self.lock.acquire()
+ if self.max_get_num <= 0 and self.value.get() == self.total_num:
+ self.list[:] = []
+ self.value.set(0)
+ self.list.append(item)
+ self.lock.release()
+ logger.info("put item to queue success")
+
+ def get(self):
+ """
+        get items from the shared task list for this rank
+
+        Returns:
+            list: the items read by this rank
+            bool: True if all ranks have read the current list, otherwise False
+ """
+ input_list = []
+ read_finish = False
+ self.lock.acquire()
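+        # read only if this rank's bit is not yet set and there is pending data;
+        # after the last rank has read, drop the consumed items and reset the mask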
+ if self.value.get() & self.position == 0 and len(self.list) > 0:
+ if self.max_get_num > 0:
+ input_list.extend(self.list[: self.max_get_num])
+ else:
+ input_list.extend(self.list[:])
+ set_value = self.value.get() | self.position
+ logger.info("rank: {0} set_value: {1}".format(self.rank, set_value))
+ if set_value >= self.total_num:
+ if self.max_get_num > 0:
+ for i in range(self.max_get_num):
+ self.list.pop(0)
+ else:
+ self.list[:] = []
+ set_value = 0
+ read_finish = True
+ self.value.set(set_value)
+ self.lock.release()
+ return input_list, read_finish
+
+
+def launch_queue_service(port, num_workers):
+ """
+ Start the process communication queue service
+
+ Args:
+ port (int): the port to listen
+        num_workers (int): the number of infer processes
+ """
+ try:
+ logger.info(f"start launch queue service, port:{port}")
+ value = Value("i", 0)
+ QueueManager.register("get_value", callable=lambda: value, proxytype=ValueProxy)
+ List = list()
+ QueueManager.register("get_list", callable=lambda: List, proxytype=ListProxy)
+ lock = threading.Lock()
+ QueueManager.register('get_lock',
+ callable=lambda: lock,
+ proxytype=AcquirerProxy)
+ barrier1 = threading.Barrier(num_workers)
+ QueueManager.register('get_barrier1', callable=lambda: barrier1)
+ barrier2 = threading.Barrier(num_workers)
+ QueueManager.register('get_barrier2', callable=lambda: barrier2)
+ q = Queue()
+ QueueManager.register("get_queue", callable=lambda: q)
+ m = QueueManager(address=('127.0.0.1', port), authkey=b'infer_queue')
+ s = m.get_server()
+ logger.info("launch queue service success")
+ s.serve_forever()
+ logger.info("finish queue service")
+ except Exception as e:
+ logger.error(f"launch queue service failed, error_msg: {e}")
+ raise e
diff --git a/llm/server/server/server/engine/token_processor.py b/llm/server/server/server/engine/token_processor.py
new file mode 100644
index 000000000000..507a3d43bdf9
--- /dev/null
+++ b/llm/server/server/server/engine/token_processor.py
@@ -0,0 +1,246 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import threading
+import time
+import traceback
+from collections import Counter
+from datetime import datetime
+
+import numpy as np
+from paddlenlp_ops import get_output
+from server.utils import datetime_diff, model_server_logger, monitor_logger
+
+
+class TokenProcessor(object):
+ """
+ get Token/Score from Paddle inference engine
+ """
+ def __init__(self, cfg):
+ import paddle
+ paddle.device.set_device("cpu")
+ self.cfg = cfg
+ self.resource_manager = None
+ # record all tokens for each request
+ self.all_tokens = [[] for _ in range(self.cfg.max_batch_size)]
+
+ self.tokens_counter = Counter()
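+        # output buffer filled by get_output():
+        #   [0, 0]          status flag (-2 means no new tokens yet)
+        #   [1, 0]          number of running requests in this step
+        #   [2 : 2 + batch] one sampled token id per running request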
+ self.output_tokens = paddle.full(shape=[self.cfg.max_batch_size + 2, 1], fill_value=2, dtype="int64")
+ self.worker = None
+
+ self.record_time_interval = int(os.getenv("RECORD_TIME_INTERVAL", "600"))
+ assert self.record_time_interval < 3600, "The RECORD_TIME_INTERVAL cannot exceed 3600."
+ self.statics_start_time = time.time()
+ self.number_of_tasks = 0
+ self.number_of_input_tokens = 0
+ self.number_of_output_tokens = 0
+
+ def set_resource_manager(self, resource_manager):
+ """
+ set ResourceManager
+
+ Args:
+ resource_manager (ResourceManager)
+ """
+ assert self.resource_manager is None, "The resource manager is not None, cannot set again."
+ self.resource_manager = resource_manager
+
+ def run(self):
+ """
+ start thread to get tokens
+ """
+ assert self.resource_manager is not None, "The resource manager is None, cannot run."
+ if self.worker is not None:
+ raise Exception("Worker is already running!")
+
+ self.worker = threading.Thread(target=self.process_sampling_results, args=())
+ self.worker.daemon = True
+ self.worker.start()
+
+ def process_sampling_results(self):
+ """
+ read tokens from paddle inference engine and process
+ """
+ while True:
+ try:
+ rank_id = 0
+ is_blocking = True
+ get_output(self.output_tokens, rank_id, is_blocking)
+
+ if self.output_tokens[0, 0] == -2:
+ continue
+ self._process_batch_output()
+ except Exception as e:
+ model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
+
+ def postprocess(self, batch_result, exist_finished_task=False):
+ """
+        default post-processing: append each request's results to a local file
+
+ Args:
+ batch_result (list): batch results
+ exist_finished_task (bool): whether there is a finished task
+ """
+ result_dir = "./generate_token_results"
+ if not os.path.exists(result_dir):
+ os.makedirs(result_dir)
+ for result in batch_result:
+ result_file = os.path.join(result_dir, result["req_id"])
+ with open(result_file, "a") as f:
+ f.write("{}\n".format(result))
+
+ def _get_single_result(self, i, task_id, token_id, task):
+ """
+        build the result payload for a single generated token of one request
+
+ Args:
+ i (int): batch index
+ task_id (str): task id
+ token_id (int): token id
+ task (dict): task information
+
+ Returns:
+ dict: result
+ """
+ inference_time_cost = time.time() - task["inference_start_time"]
+ task["inference_time_cost"] = inference_time_cost
+ task["tokens_all_num"] = len(self.all_tokens[i])
+ task["inference_current_step_time"] = datetime.now()
+ result = {
+ "req_id": task_id,
+ "is_end": 0,
+ "token_ids": [token_id],
+ "send_idx": self.tokens_counter[task_id],
+ "inference_time_cost": inference_time_cost,
+ "infer_seed": task["infer_seed"],
+ "return_all_tokens": task.get("return_all_tokens", False),
+ }
+
+ # get benchmark msg
+ if task.get("benchmark"):
+ keys = ["preprocess_start_time", "preprocess_end_time", "schedule_start_time",
+ "inference_start_time", "inference_current_step_time"]
+ for key in keys:
+ if key in task:
+ result[key] = str(task[key])
+
+ # fill some extra information
+ if token_id in task["eos_token_ids"]:
+ result["is_end"] = 1
+ result["token_ids"] = []
+ result["tokens_all_num"] = len(self.all_tokens[i]) + 1
+ result["tokens_all_ids"] = self.all_tokens[i]
+
+ info_dict = {}
+ info_dict["req_id"] = task["req_id"]
+ info_dict["input_token_num"] = len(task["input_ids"])
+ info_dict["output_token_num"] = len(self.all_tokens[i])
+ if hasattr(task, "preprocess_start_time") and hasattr(task, "preprocess_end_time"):
+ info_dict["preprocess_cost_time"] = datetime_diff(task["preprocess_start_time"],
+ task["preprocess_end_time"])
+ if hasattr(task, "preprocess_end_time") and hasattr(task, "schedule_start_time"):
+ info_dict["cache_waiting_cost_time"] = datetime_diff(task["preprocess_end_time"],
+ task["schedule_start_time"])
+ info_dict["inference_time_cost"] = task["inference_time_cost"]
+ info_dict["version"] = "4.6"
+ info_dict["timestamp"] = time.time()
+ monitor_logger.info(f"{info_dict}")
+
+ return result
+
+ def _recycle_resources(self, task_id, index, task):
+ """
+ recycle resources
+ """
+ self.resource_manager.stop_flags[index] = True
+ self.resource_manager.tasks_list[index] = None
+ self.resource_manager._recycle_block_tables(task["block_tables"])
+ if task_id in self.tokens_counter:
+ del self.tokens_counter[task_id]
+ self.all_tokens[index] = list()
+
+ def _process_batch_output(self):
+ """
+ batch post-processing function
+ """
+ tokens = self.output_tokens.numpy()
+ batch = self.output_tokens[1, 0]
+ tokens = tokens[2:batch + 2]
+
+ batch_result = list()
+ exist_finished_task = False
+ for i in range(batch):
+ if self.resource_manager.stop_flags[i]:
+ continue
+
+ token_id = int(tokens[i, 0])
+ if token_id < 0:
+ continue
+
+ task = self.resource_manager.tasks_list[i]
+
+ task_id = task["req_id"]
+ result = self._get_single_result(i, task_id, token_id, task)
+
+ self.tokens_counter[task_id] += 1
+ if token_id not in task["eos_token_ids"]:
+ self.all_tokens[i].append(token_id)
+
+ self.number_of_output_tokens += 1
+ if token_id in task["eos_token_ids"]:
+ self._recycle_resources(task_id, i, task)
+ model_server_logger.info("req_id: {0} finished".format(task_id))
+ model_server_logger.info(f"{self.resource_manager.info()}")
+ exist_finished_task = True
+ batch_result.append(result)
+
+ self.postprocess(batch_result, exist_finished_task)
+
+
+class WarmUpTokenProcessor(TokenProcessor):
+ """
+ Warmup Processor
+ """
+ def __init__(self, cfg):
+ super().__init__(cfg)
+ self._is_running = True
+ self._is_blocking = True
+
+ def postprocess(self, batch_result, exist_finished_task=False):
+ pass
+
+ def process_sampling_results(self):
+ """
+ get output from model and process it
+ """
+ while self._is_running:
+ try:
+ rank_id = 0
+ get_output(self.output_tokens, rank_id, self._is_blocking)
+
+ if self.output_tokens[0, 0] == -2:
+ continue
+ self._process_batch_output()
+ except Exception as e:
+ model_server_logger.info("while get input_data error: {0} {1}".format(e, str(traceback.format_exc())))
+
+ def stop(self):
+ """
+ stop warm up thread
+ """
+ self._is_running = False
+ self.worker.join()
+ model_server_logger.info("warm up thread stop")
+ del self.worker
diff --git a/llm/server/server/server/http_server/__init__.py b/llm/server/server/server/http_server/__init__.py
new file mode 100644
index 000000000000..fd05a9208165
--- /dev/null
+++ b/llm/server/server/server/http_server/__init__.py
@@ -0,0 +1,13 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
diff --git a/llm/server/server/server/http_server/adapter_openai.py b/llm/server/server/server/http_server/adapter_openai.py
new file mode 100644
index 000000000000..fe347640358e
--- /dev/null
+++ b/llm/server/server/server/http_server/adapter_openai.py
@@ -0,0 +1,103 @@
+import time
+import json
+import queue
+
+import numpy as np
+from typing import Dict
+from datetime import datetime
+from functools import partial
+
+import tritonclient.grpc as grpcclient
+from tritonclient import utils as triton_utils
+from openai.types.completion_usage import CompletionUsage
+from openai.types.completion_choice import CompletionChoice
+from openai.types.completion import Completion
+from openai.types.chat.chat_completion_chunk import (
+ ChoiceDelta,
+ ChatCompletionChunk,
+ Choice as ChatCompletionChoice
+)
+
+from server.http_server.api import Req, chat_completion_generator
+from server.utils import http_server_logger
+
+
+def format_openai_message_completions(req: Req, result: Dict) -> str:
+ choice_data = CompletionChoice(
+ index=0,
+ text=result['token'],
+ finish_reason=result.get("finish_reason", "stop"),
+ )
+ chunk = Completion(
+ id=req.req_id,
+ choices=[choice_data],
+ model=req.model,
+ created=int(time.time()),
+ object="text_completion",
+ usage=CompletionUsage(
+ completion_tokens=result["usage"]["completion_tokens"],
+ prompt_tokens=result["usage"]["prompt_tokens"],
+ total_tokens=result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"],
+ ),
+ )
+ return chunk.model_dump_json(exclude_unset=True)
+
+
+def format_openai_message_chat_completions(req: Req, result: Dict) -> str:
+ choice_data = ChatCompletionChoice(
+ index=0,
+ delta=ChoiceDelta(
+ content=result['token'],
+ role="assistant",
+ ),
+ finish_reason=result.get("finish_reason", "stop"),
+ )
+ chunk = ChatCompletionChunk(
+ id=req.req_id,
+ choices=[choice_data],
+ model=req.model,
+ created=int(time.time()),
+ object="chat.completion.chunk",
+ usage=CompletionUsage(
+ completion_tokens=result["usage"]["completion_tokens"],
+ prompt_tokens=result["usage"]["prompt_tokens"],
+ total_tokens=result["usage"]["prompt_tokens"] + result["usage"]["completion_tokens"],
+ ),
+ )
+ return chunk.model_dump_json(exclude_unset=True)
+
+
+def openai_chat_commpletion_generator(infer_grpc_url: str, req: Req, chat_interface: bool) -> Dict:
+
+ def _openai_format_resp(resp_dict):
+ return f"data: {resp_dict}\n\n"
+
+    for resp in chat_completion_generator(infer_grpc_url, req, yield_json=False):
+        if chat_interface:
+            yield _openai_format_resp(format_openai_message_chat_completions(req, resp))
+        else:
+            yield _openai_format_resp(format_openai_message_completions(req, resp))
+
+        # per the OpenAI streaming protocol, the terminating [DONE] event is sent last
+        if resp.get("is_end") == 1:
+            yield _openai_format_resp("[DONE]")
+
+
+def openai_chat_completion_result(infer_grpc_url: str, req: Req, chat_interface: bool):
+ result = ""
+ error_resp = None
+ for resp in chat_completion_generator(infer_grpc_url, req, yield_json=False):
+ if resp.get("error_msg") or resp.get("error_code"):
+ error_resp = resp
+ error_resp["result"] = ""
+ else:
+ result += resp.get("token")
+ usage = resp.get("usage", None)
+
+ if error_resp:
+ return error_resp
+ response = {'token': result, 'error_msg': '', 'error_code': 0, 'usage': usage}
+
+ if chat_interface:
+ return format_openai_message_chat_completions(req, response)
+ else:
+ return format_openai_message_completions(req, response)
diff --git a/llm/server/server/server/http_server/api.py b/llm/server/server/server/http_server/api.py
new file mode 100644
index 000000000000..df9c066284f4
--- /dev/null
+++ b/llm/server/server/server/http_server/api.py
@@ -0,0 +1,187 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import json
+import queue
+import time
+import uuid
+import shortuuid
+from datetime import datetime
+from functools import partial
+from typing import Dict, List, Optional
+
+import numpy as np
+import tritonclient.grpc as grpcclient
+from pydantic import BaseModel, Field
+from tritonclient import utils as triton_utils
+
+
+class Req(BaseModel):
+ req_id: str = Field(default_factory=lambda: str(uuid.uuid4()))
+ input_ids: Optional[List[int]] = None
+ text: Optional[str] = None
+ messages: Optional[List] = None
+ max_dec_len: Optional[int] = None
+ seq_len: Optional[int] = None
+ min_dec_len: Optional[int] = None
+ temperature: Optional[float] = None
+ topp: Optional[float] = None
+ penalty_score: Optional[float] = None
+ frequency_score: Optional[float] = None
+ presence_score: Optional[float] = None
+ system: Optional[str] = None
+ return_all_tokens: Optional[bool] = None
+ eos_token_ids: Optional[List[int]] = None
+ benchmark: bool = False
+ return_usage: Optional[bool] = False
+ stream: bool = False
+ timeout: int = 300
+    model: Optional[str] = None
+
+ def to_dict_for_infer(self):
+ """
+ Convert the request parameters into a dictionary
+
+ Returns:
+ dict: request parameters in dict format
+ """
+ req_dict = {}
+ for key, value in self.dict().items():
+ if value is not None:
+ req_dict[key] = value
+ return req_dict
+
+ def load_openai_request(self, request_dict: dict):
+ """
+ Convert openai request to Req
+ official OpenAI API documentation: https://platform.openai.com/docs/api-reference/completions/create
+ """
+ convert_dict = {
+ "text": "prompt",
+ "frequency_score": "frequency_penalty",
+ "max_dec_len": "max_tokens",
+ "stream": "stream",
+ "return_all_tokens": "best_of",
+ "temperature": "temperature",
+ "topp": "top_p",
+ "presence_score": "presence_penalty",
+ "eos_token_ids": "stop",
+ "req_id": "id",
+ "model": "model",
+ "messages": "messages",
+ }
+
+ self.__setattr__("req_id", f"chatcmpl-{shortuuid.random()}")
+ for key, value in convert_dict.items():
+ if request_dict.get(value, None):
+ self.__setattr__(key, request_dict.get(value))
+
+
+def chat_completion_generator(infer_grpc_url: str, req: Req, yield_json: bool) -> Dict:
+ """
+ Chat completion generator based on Triton inference service.
+
+ Args:
+        infer_grpc_url (str): Triton gRPC URL
+        req (Req): request parameters
+ yield_json (bool): Whether to return the result in json format
+
+ Returns:
+ dict: chat completion result.
+ Normal, return {'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
+ Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
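+
+    Example (illustrative only; the URL is a placeholder for the Triton gRPC endpoint):
+        req = Req(text="hello", stream=True)
+        for chunk in chat_completion_generator("localhost:8001", req, yield_json=True):
+            print(chunk)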
+ """
+ class _TritonOutputData:
+ def __init__(self):
+ self._completed_requests = queue.Queue()
+
+ def _triton_callback(output_data, result, error):
+ """Triton callback function"""
+ if error:
+ output_data._completed_requests.put(error)
+ else:
+ output_data._completed_requests.put(result)
+
+ def _format_resp(resp_dict):
+ if yield_json:
+ return json.dumps(resp_dict, ensure_ascii=False) + "\n"
+ else:
+ return resp_dict
+
+ timeout = req.timeout
+ req_id = req.req_id
+ req_dict = req.to_dict_for_infer()
+ http_received_time = datetime.now()
+
+ inputs = [grpcclient.InferInput("IN", [1], triton_utils.np_to_triton_dtype(np.object_))]
+ inputs[0].set_data_from_numpy(np.array([json.dumps([req_dict])], dtype=np.object_))
+ outputs = [grpcclient.InferRequestedOutput("OUT")]
+ output_data = _TritonOutputData()
+
+ with grpcclient.InferenceServerClient(url=infer_grpc_url, verbose=False) as triton_client:
+ triton_client.start_stream(callback=partial(_triton_callback, output_data))
+
+ triton_client.async_stream_infer(model_name="model",
+ inputs=inputs,
+ request_id=req_dict['req_id'],
+ outputs=outputs)
+ while True:
+ output_item = output_data._completed_requests.get(timeout=timeout)
+            if isinstance(output_item, triton_utils.InferenceServerException):
+ error_msg = f"status is {output_item.status()}, msg is {output_item.message()}"
+ yield _format_resp({"error_msg": error_msg, "error_code": 500})
+ break
+ else:
+ result = json.loads(output_item.as_numpy("OUT")[0])
+ result = result[0] if isinstance(result, list) else result
+ result["error_msg"] = result.get("error_msg", "")
+ result["error_code"] = result.get("error_code", 0)
+ if req.benchmark:
+ result["http_received_time"] = str(http_received_time)
+ yield _format_resp(result)
+ if (result.get("error_msg") or result.get("error_code")) or result.get("is_end") == 1:
+ break
+
+ triton_client.stop_stream()
+ triton_client.close()
+
+def chat_completion_result(infer_grpc_url: str, req: Req) -> Dict:
+ """
+ Chat completion result with not streaming mode
+
+ Args:
+ infer_grpc_url (str): Triton gRPC URL
+ req (Req): request parameters
+
+ Returns:
+ dict: chat completion result.
+        Normal, return {'result': xxx, ..., 'error_msg': '', 'error_code': 0}
+ Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
+ """
+ result = ""
+ error_resp = None
+ for resp in chat_completion_generator(infer_grpc_url, req, yield_json=False):
+ if resp.get("error_msg") or resp.get("error_code"):
+ error_resp = resp
+ error_resp["result"] = ""
+ else:
+ result += resp.get("token")
+ usage = resp.get("usage", None)
+
+ if error_resp:
+ return error_resp
+ response = {'result': result, 'error_msg': '', 'error_code': 0}
+ if req.return_usage:
+ response["usage"] = usage
+ return response
diff --git a/llm/server/server/server/http_server/app.py b/llm/server/server/server/http_server/app.py
new file mode 100644
index 000000000000..19351108c066
--- /dev/null
+++ b/llm/server/server/server/http_server/app.py
@@ -0,0 +1,131 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+
+import uvicorn
+from typing import Dict
+from fastapi import FastAPI, Request
+from fastapi.responses import StreamingResponse
+from server.http_server.api import (Req, chat_completion_generator,
+ chat_completion_result)
+from server.http_server.adapter_openai import (
+ openai_chat_commpletion_generator, openai_chat_completion_result
+)
+from server.utils import http_server_logger
+
+http_server_logger.info(f"create fastapi app...")
+app = FastAPI()
+
+
+@app.post("/v1/chat/completions")
+def create_chat_completion(req: Req):
+ """
+ HTTP Server for chat completion
+ Return:
+ In Stream:
+ Normal, return {'token': xxx, 'is_end': xxx, 'send_idx': xxx, ..., 'error_msg': '', 'error_code': 0}
+ Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
+ Not In Stream:
+            Normal, return {'result': xxx, ..., 'error_msg': '', 'error_code': 0}
+ Others, return {'error_msg': xxx, 'error_code': xxx}, error_msg not None, error_code != 0
+ """
+ try:
+ http_server_logger.info(f"receive request: {req.req_id}")
+ grpc_port = int(os.getenv("GRPC_PORT", 0))
+ if grpc_port == 0:
+ return {"error_msg": f"GRPC_PORT ({grpc_port}) for infer service is invalid",
+ "error_code": 400}
+ grpc_url = f"localhost:{grpc_port}"
+
+ if req.stream:
+ generator = chat_completion_generator(infer_grpc_url=grpc_url, req=req, yield_json=True)
+ resp = StreamingResponse(generator, media_type="text/event-stream")
+ else:
+ resp = chat_completion_result(infer_grpc_url=grpc_url, req=req)
+ except Exception as e:
+ resp = {'error_msg': str(e), 'error_code': 501}
+ finally:
+ http_server_logger.info(f"finish request: {req.req_id}")
+ return resp
+
+
+@app.post("/v1/chat/completions/completions")
+def openai_v1_completions(request: Dict):
+ return create_openai_completion(request, chat_interface=False)
+
+
+@app.post("/v1/chat/completions/chat/completions")
+def openai_v1_chat_completions(request: Dict):
+ return create_openai_completion(request, chat_interface=True)
+
+
+def create_openai_completion(request: Dict, chat_interface: bool):
+ try:
+ req = Req()
+ req.load_openai_request(request)
+ except Exception as e:
+ return {"error_msg": "request body is not a valid json format", "error_code": 400, "result": ''}
+
+ try:
+ http_server_logger.info(f"receive request: {req.req_id}")
+
+ grpc_port = int(os.getenv("GRPC_PORT", 0))
+ if grpc_port == 0:
+ return {"error_msg": f"GRPC_PORT ({grpc_port}) for infer service is invalid",
+ "error_code": 400}
+ grpc_url = f"localhost:{grpc_port}"
+
+ if req.stream:
+ generator = openai_chat_commpletion_generator(
+ infer_grpc_url=grpc_url,
+ req=req,
+ chat_interface=chat_interface,
+ )
+ resp = StreamingResponse(generator, media_type="text/event-stream")
+ else:
+ resp = openai_chat_completion_result(infer_grpc_url=grpc_url, req=req, chat_interface=chat_interface)
+ except Exception as e:
+ resp = {'error_msg': str(e), 'error_code': 501}
+ finally:
+ http_server_logger.info(f"finish request: {req.req_id}")
+ return resp
+
+
+def launch_http_server(port: int, workers: int) -> None:
+ """
+ launch http server
+ """
+ http_server_logger.info(f"launch http server with port: {port}, workers: {workers}")
+ try:
+ uvicorn.run(app="server.http_server.app:app",
+ host='0.0.0.0',
+ port=port,
+ workers=workers,
+ log_level="error")
+ except Exception as e:
+ http_server_logger.error(f"launch http server error, {e}")
+
+
+def main():
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--port", default=9904, type=int, help="port to the http server")
+ parser.add_argument("--workers", default=1, type=int, help="set the number of workers for the http service")
+ args = parser.parse_args()
+ launch_http_server(port=args.port, workers=args.workers)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/llm/server/server/server/triton_server.py b/llm/server/server/server/triton_server.py
new file mode 100644
index 000000000000..601a1b017907
--- /dev/null
+++ b/llm/server/server/server/triton_server.py
@@ -0,0 +1,466 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import codecs
+import json
+import multiprocessing
+import os
+import queue
+import subprocess
+import sys
+import threading
+import time
+import traceback
+from collections import Counter, deque
+from datetime import datetime
+
+import numpy as np
+from server.checker import add_default_params, check_basic_params
+from server.engine import engine
+from server.engine.config import Config
+from server.utils import error_logger, model_server_logger
+
+import server
+
+try:
+ import triton_python_backend_utils as pb_utils
+except Exception:
+ model_server_logger.warning(
+ "TritonPythonModel is only available under triton inference server framework."
+ )
+
+if sys.stdout.encoding is None:
+ enc = os.environ["LANG"].split(".")[1]
+ sys.stdout = codecs.getwriter(enc)(sys.stdout)
+
+
+class TritonConfig(Config):
+ """
+ Triton Inference Server config
+ """
+ def __init__(self, base_config):
+ super().__init__()
+ for k, v in base_config.__dict__.items():
+ setattr(self, k, v)
+
+
+class TritonTokenProcessor(engine.TokenProcessor):
+ """
+    token processor that pushes generated tokens back to Triton clients
+ """
+ def __init__(self, cfg, triton_server):
+ super().__init__(cfg)
+ self.triton_server = triton_server
+ self.cached_generated_tokens = queue.Queue()
+ self.token_buffer = dict()
+ self.score_buffer = dict()
+
+ self.push_mode_sender_thread = threading.Thread(target=self._push_mode_sender_thread, args=())
+ self.push_mode_sender_thread.daemon = True
+ self.push_mode_sender_thread.start()
+
+ def _push_mode_sender_thread(self):
+ """
+ push mode sender thread
+ """
+ while True:
+ try:
+ batch_result = self.cached_generated_tokens.get()
+ for result in batch_result:
+ req_id = result["req_id"]
+ is_end = result.get("is_end", 0)
+ return_all_tokens = result.get("return_all_tokens", False)
+ if is_end == 0 and (return_all_tokens or self.cfg.disable_streaming):
+ continue
+ if return_all_tokens and "topk_tokens" in result:
+ del result["topk_tokens"]
+ result = self.triton_server.data_processor.process_response(result)
+ if "usage" in result:
+ result["usage"]["prompt_tokens"] = self.triton_server.task_info[req_id]["prompt_tokens"]
+ model_server_logger.debug(f"Send result to client under push mode: {result}")
+ with self.triton_server.thread_lock:
+ _send_result([result], self.triton_server.response_sender[req_id], is_end)
+ if is_end == 1:
+ del self.triton_server.response_sender[req_id]
+ del self.triton_server.task_info[req_id]
+ self.triton_server._update_metrics()
+ except Exception as e:
+ model_server_logger.error("Unexcepted error happend: {}, {}".format(e, str(traceback.format_exc())))
+
+ def postprocess(self, batch_result, exist_finished_task=False):
+ """
+ single postprocess for triton
+ """
+ try:
+ self.cached_generated_tokens.put(batch_result)
+ except Exception as e:
+ model_server_logger.info(
+ "Unexcepted problem happend while process output token: {}, {}"
+ .format(e, str(traceback.format_exc())))
+
+
+class TritonServer(object):
+ """
+ Triton Server
+ """
+
+ def initialize(self, args):
+ """
+ Triton initialization
+ """
+ # start health checker
+ use_custom_health_checker = int(os.getenv("USE_CUSTOM_HEALTH_CHECKER", 1))
+ # if set USE_CUSTOM_HEALTH_CHECKER=1, use custom health checker, need set --allow-http=false
+ # else use tritonserver's health checker, need set --http-port=${HTTP_PORT}
+ if use_custom_health_checker:
+ http_port = os.getenv("HTTP_PORT")
+ if http_port is None:
+ raise Exception("HTTP_PORT must be set")
+ from server.triton_server_helper import start_health_checker
+ multiprocessing.Process(target=start_health_checker, args=(int(http_port), )).start()
+ time.sleep(1)
+
+ model_config = json.loads(args["model_config"])
+ using_decoupled = pb_utils.using_decoupled_model_transaction_policy(
+ model_config)
+ if not using_decoupled:
+ raise pb_utils.TritonModelException(
+ """the model `{}` can generate any number of responses per request,
+ enable decoupled transaction policy in model configuration to
+ serve this model""".format(args["model_name"]))
+
+        # add metrics; query METRICS_PORT to fetch server metrics
+ self.metric_family = pb_utils.MetricFamily(
+ name="inference_server_metrics",
+ description="Metrics for monitoring inference server status",
+            kind=pb_utils.MetricFamily.GAUGE,
+ )
+ self.metrics = {
+ "batch_size":
+ self.metric_family.Metric(labels={"batch_size": "batch_size"}),
+ "block_num":
+ self.metric_family.Metric(labels={"block_num": "block_num"}),
+ "max_batch_size":
+ self.metric_family.Metric(
+ labels={"max_batch_size": "max_batch_size"}),
+ "max_block_num":
+ self.metric_family.Metric(
+ labels={"max_block_num": "max_block_num"}),
+ "available_resource":
+ self.metric_family.Metric(
+ labels={"available_resource": "available_resource"}),
+ }
+
+ # response_sender thread lock
+ self.thread_lock = threading.Lock()
+
+ base_config = Config()
+ self.cfg = TritonConfig(base_config)
+ self.cfg.print(file="log/deploy_init.info")
+
+ # init engine
+ self.token_processor = TritonTokenProcessor(self.cfg, self)
+ self.engine = engine.Engine(self.cfg, self.token_processor)
+ model_server_logger.info("Creat engine...")
+ self.engine.start()
+ model_server_logger.info("Create engine success")
+
+ self._initialize_push_mode()
+ model_server_logger.info("Init triton server success")
+
+
+ def execute(self, requests):
+ """
+ Triton service main function,
+ handling requests received by the Triton framework
+ """
+ if len(requests) != 1:
+ raise pb_utils.TritonModelException(
+ "Only support batch=1, but now it's {}.".format(len(requests)))
+ request = requests[0]
+ current_response_sender = request.get_response_sender()
+ request_tensor = pb_utils.get_input_tensor_by_name(request, "IN")
+ tasks = json.loads(request_tensor.as_numpy()[0])
+
+ model_server_logger.info(f"receive task: {tasks}")
+ self._process_task_push_mode(tasks, current_response_sender)
+ self._update_metrics()
+
+ def finalize(self):
+ """
+ Triton service exit function
+ """
+ model_server_logger.info("Triton service will be terminated...")
+ wait_time = 300
+ while not self.engine.all_tasks_finished():
+ if wait_time <= 0:
+ model_server_logger.warning(f"Ignore the unfinished tasks, force to stop.")
+ break
+ model_server_logger.info(f"There's unfinished tasks, wait {wait_time}...")
+ wait_time -= 5
+ time.sleep(5)
+ model_server_logger.info("Terminate the engine now.")
+ self.enable_insert_task_push_mode = False
+ time.sleep(1)
+ del self.engine
+ if hasattr(self, "http_process"):
+ self.http_process.kill()
+ model_server_logger.info("Triton service is terminated!")
+
+ def _initialize_push_mode(self):
+ from server.data.processor import DataProcessor
+ self.data_processor = DataProcessor()
+ model_server_logger.info("create data processor success")
+
+ if self.cfg.push_mode_http_port < 0:
+ model_server_logger.info("HTTP server for push mode is disabled.")
+ else:
+ model_server_logger.info("launch http server...")
+
+ current_dir_path = os.path.split(os.path.abspath(__file__))[0]
+ http_py_file = "app.py"
+ http_py_path = os.path.join(current_dir_path, "http_server", http_py_file)
+ http_cmd = f"python3 {http_py_path} --port={self.cfg.push_mode_http_port} " \
+ f"--workers={self.cfg.push_mode_http_workers} >log/launch_http.log 2>&1"
+
+ model_server_logger.info(f"Launch HTTP server for push mode, command:{http_cmd}")
+ self.http_process = subprocess.Popen(http_cmd, shell=True, preexec_fn=os.setsid)
+ time.sleep(3)
+ exit_code = self.http_process.poll()
+ if exit_code is None:
+ http_url = f"http://127.0.0.1:{self.cfg.push_mode_http_port}/v1/chat/completions"
+ model_server_logger.info(f"Launch HTTP server for push mode success, http_url:{http_url}")
+ else:
+ error_msg = "\n Launch HTTP service for push mode failed in 3 seconds. " \
+ "Please check log/launch_http.log file \n"
+ model_server_logger.error(error_msg)
+ model_server_logger.info("init push server success")
+
+ self.response_sender = dict()
+ self.task_info = dict()
+ self.cached_task_deque = deque()
+ self.enable_insert_task_push_mode = True
+ self.insert_task_to_engine_thread = threading.Thread(
+ target=self._insert_task_push_mode, args=())
+ self.insert_task_to_engine_thread.daemon = True
+ self.insert_task_to_engine_thread.start()
+
+ def _process_task_push_mode(self, tasks, current_response_sender):
+ """
+ check request and insert into cached_task_deque
+
+ Args:
+ tasks (list): list of request
+ current_response_sender: response sender for current request
+ """
+ try:
+ tik = time.time()
+ req_id = tasks[0]["req_id"]
+ cached_task_num = len(self.cached_task_deque)
+ if cached_task_num >= self.cfg.max_cached_task_num:
+ error_msg = f"cached task num ({cached_task_num}) exceeds " \
+ f"the limit ({self.cfg.max_cached_task_num})"
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ if not tasks or len(tasks) != 1 or not tasks[0]:
+ error_msg = f"request data should not be empty and query " \
+ f"num {len(tasks)} should be 1"
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ task = tasks[0]
+ task["preprocess_start_time"] = datetime.now()
+
+ error_msg = check_basic_params(task)
+ if error_msg != []:
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ task_id = task["req_id"]
+ with self.thread_lock:
+ if task_id in self.response_sender:
+ error_msg = f"The req_id {task_id} already exists in the current batch, " \
+ f"the current request will be ignored."
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ task = add_default_params(task)
+
+ if int(task.get("enable_text_truncate", 1)):
+ real_seq_len = self.cfg.max_seq_len - task.get("max_dec_len", 800)
+ task = self.data_processor.process_request(task, max_seq_len=real_seq_len)
+ else:
+ task = self.data_processor.process_request(task)
+
+ input_ids_len = len(task["input_ids"])
+ if "max_dec_len" not in task:
+ task["max_dec_len"] = min(self.cfg.max_seq_len - input_ids_len, self.cfg.dec_len_limit)
+ min_dec_len = task["min_dec_len"]
+ if input_ids_len + min_dec_len >= self.cfg.max_seq_len:
+ error_msg = f"Input text is too long, input_ids_len ({input_ids_len}) " \
+ f"+ min_dec_len ({min_dec_len}) >= max_seq_len "
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ if input_ids_len > self.cfg.seq_len_limit:
+ error_msg = f"Length of input token({input_ids_len}) exceeds the limit MAX_SEQ_LEN({self.cfg.seq_len_limit})."
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+ if task["max_dec_len"] > self.cfg.dec_len_limit:
+ error_msg = f"The parameter max_dec_len({task['max_dec_len']}) exceeds the limit MAX_DEC_LEN({self.cfg.dec_len_limit})."
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ required_block_num = self.engine.resource_manager.get_required_block_number(input_ids_len)
+ if required_block_num > self.engine.resource_manager.total_block_number():
+ error_msg = f"The input task required resources is exceed the limit, task={task}."
+ _send_error(error_msg, current_response_sender, req_id=req_id)
+ return
+
+ with self.thread_lock:
+ self.response_sender[task_id] = current_response_sender
+ self.task_info[task_id] = {"prompt_tokens": input_ids_len}
+
+ task["preprocess_end_time"] = datetime.now()
+ self.cached_task_deque.appendleft(task)
+ tok = time.time()
+ model_server_logger.info(f"cache task with req_id ({task_id}), "
+ f"cost time: {tok-tik}s, cached_task_num: {len(self.cached_task_deque)}.")
+ model_server_logger.debug(f"cache task: {task}")
+ except Exception as e:
+ error_msg = "Unexcepted promblem happend while insert new task to server task queue: {}, {}".format(
+ e, str(traceback.format_exc()))
+ _send_error(error_msg, current_response_sender)
+
+ def _insert_task_push_mode(self):
+ """
+        thread loop that monitors cached_task_deque and inserts tasks
+        into the engine whenever it has free resources
+ """
+ try:
+ while self.enable_insert_task_push_mode:
+ if not hasattr(self, "engine") or self.engine is None:
+ time.sleep(0.1)
+ continue
+ if self.engine.available_batch() == 0:
+ time.sleep(0.001)
+ continue
+ if len(self.cached_task_deque) == 0:
+ time.sleep(0.001)
+ continue
+ if not self.engine.is_queue_empty():
+ time.sleep(0.001)
+ continue
+
+ i_bs = 0
+ for _ in range(self.cfg.max_prefill_batch):
+ if len(self.cached_task_deque) == 0:
+ break
+ if self.engine.available_batch() == 0:
+ break
+ while i_bs < self.cfg.max_batch_size:
+ if self.engine.task_is_finished(i_bs):
+ break
+ i_bs += 1
+ if i_bs >= self.cfg.max_batch_size:
+ break
+ input_token_num = len(self.cached_task_deque[-1]["input_ids"])
+ if not self.engine.is_resource_sufficient(input_token_num):
+ break
+ task = self.cached_task_deque.pop()
+ try:
+ self.engine.insert_tasks([task])
+ except Exception as e:
+ err_msg = "Error happend while insert task to engine: {}, {}.".format(
+ e, str(traceback.format_exc()))
+ with self.thread_lock:
+ _send_result({"error_msg": err_msg},
+ self.response_sender[task["req_id"]], 1)
+ del self.response_sender[task["req_id"]]
+ model_server_logger.info("finish insert_task_push_mode thread")
+ except Exception as e:
+ model_server_logger.error("insert_task_push_mode thread exit "
+ f"unexpectedly, {e}. {str(traceback.format_exc())}")
+
+ def _update_metrics(self):
+ """
+ update metrics
+ """
+ block_num = self.engine.available_block_num()
+ batch_size = self.engine.available_batch()
+ self.metrics["block_num"].set(block_num)
+ self.metrics["max_batch_size"].set(self.cfg.max_batch_size)
+ self.metrics["batch_size"].set(self.cfg.max_batch_size - batch_size)
+ self.metrics["max_block_num"].set(self.cfg.max_block_num)
+ self.metrics["available_resource"].set(block_num * 1.0 /
+ self.cfg.max_block_num)
+
+ def _get_current_server_info(self):
+ """
+ get server info
+ """
+ available_batch_size = min(self.cfg.max_prefill_batch,
+ self.engine.available_batch())
+ available_block_num = self.engine.available_block_num()
+ server_info = {
+ "block_size": int(self.cfg.block_size),
+ "block_num": int(available_block_num),
+ "dec_token_num": int(self.cfg.dec_token_num),
+ "available_resource":
+ 1.0 * available_block_num / self.cfg.max_block_num,
+ "max_batch_size": int(available_batch_size),
+ }
+ return server_info
+
+
+def _send_result(result_dict, sender, end_flag=0):
+ """
+ Send inference result
+
+ Args:
+ result_dict (dict): result of inference
+ sender (grpc.aio.ServerReaderWriter): gRPC ServerReaderWriter object.
+ end_flag (int, optional): flag of end. Defaults to 0.
+ """
+ response = None
+ if result_dict:
+ result_dict = json.dumps(result_dict)
+ end_output = pb_utils.Tensor("OUT",
+ np.array([result_dict], dtype=np.object_))
+ response = pb_utils.InferenceResponse(output_tensors=[end_output])
+ if response is None and end_flag == 0:
+ return
+ sender.send(response, flags=end_flag)
+
+def _send_error(error_msg, sender, error_code=200, req_id=None):
+ """
+ Send error inference result
+
+ Args:
+ error_msg (str): error message
+ sender (grpc.aio.ServerReaderWriter): gRPC ServerReaderWriter object.
+ error_code (int, optional): error code. Defaults to 200.
+ req_id (str, optional): request id. Defaults to None
+ """
+ if not isinstance(error_msg, str):
+ error_msg = str(error_msg)
+ error_info = {"req_id": req_id, "error_msg": error_msg, "error_code": error_code, "version": "4.6", "timestamp": time.time()}
+ error_logger.info(f"{error_info}")
+ model_server_logger.error(error_msg)
+ _send_result(error_info, sender, 1)
+
+
+TritonPythonModel = TritonServer
diff --git a/llm/server/server/server/triton_server_helper.py b/llm/server/server/server/triton_server_helper.py
new file mode 100644
index 000000000000..b299cd4204f8
--- /dev/null
+++ b/llm/server/server/server/triton_server_helper.py
@@ -0,0 +1,157 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import argparse
+import os
+import queue
+import socket
+import subprocess
+import time
+from collections import defaultdict
+from multiprocessing import shared_memory
+
+import numpy as np
+import uvicorn
+from fastapi import FastAPI, HTTPException, Request
+from fastapi.responses import JSONResponse, Response
+from server.engine.config import Config
+from server.utils import get_logger
+
+app = FastAPI()
+env_config = Config()
+logger = get_logger("health_checker", "health_checker.log")
+
+
+@app.get("/v2/health/ready")
+def check_health():
+ """
+ health check interface
+ """
+ status, error_info = check()
+ if status is True:
+ logger.info("check_health: OK")
+ return Response()
+ else:
+ logger.info("check_health: Bad")
+ return JSONResponse(
+ status_code=500,
+ content=error_info)
+
+
+@app.get("/v2/health/live")
+def check_live():
+ """
+ health check interface
+ """
+ status, error_info = check()
+ if status is True:
+ logger.info("check_health: OK")
+ return Response()
+ else:
+ logger.info("check_health: Bad")
+ return JSONResponse(
+ status_code=500,
+ content=error_info)
+
+
+def check_infer_engine_process():
+ """
+ check if infer process is alive
+
+ return:
+ status: bool, True if process is alive else False
+ """
+ mp_num = int(env_config.mp_num)
+ for i in range(mp_num):
+ try:
+ infer_live_flag_shm = shared_memory.SharedMemory(name=env_config.get_unique_name("shm_flag_infer_{}_live".format(i)))
+ except Exception as e:
+ return False
+ return True
+
+
+def check():
+ """
+    state detection for the inference service
+
+    return:
+        status: bool, True if the service is healthy else False
+        error_info: dict, error code and message when unhealthy
+ """
+ error_info = {}
+ grpc_port = os.getenv("GRPC_PORT")
+
+ # 1. check server is ready
+ if grpc_port is not None:
+ sock = socket.socket()
+ try:
+ sock.connect(('localhost', int(grpc_port)))
+ except Exception:
+ error_info["error_code"] = 1
+ error_info["error_msg"] = "server is not ready"
+ logger.info("server is not ready")
+ return False, error_info
+ finally:
+ sock.close()
+
+ # 2.check engine is ready
+ is_engine_live = check_infer_engine_process()
+ if is_engine_live is False:
+ error_info["error_code"] = 2
+ error_info["error_msg"] = "infer engine is down"
+ logger.info("infer engine is down")
+ return False, error_info
+
+ engine_ready_checker = np.ndarray(engine_ready_check_flag.shape, dtype=engine_ready_check_flag.dtype,
+ buffer=shm_engine_ready_check_flag.buf)
+ if engine_ready_checker[0] == 0:
+ error_info["error_code"] = 2
+ error_info["error_msg"] = "infer engine is down"
+ logger.info("infer engine is down")
+ return False, error_info
+
+    # 3. check whether the engine hangs
+ engine_hang_checker = np.ndarray(engine_healthy_recorded_time.shape, dtype=engine_healthy_recorded_time.dtype,
+ buffer=shm_engine_healthy_recorded_time.buf)
+ elapsed_time = time.time() - engine_hang_checker[0]
+ logger.info("engine_checker elapsed time: {}".format(elapsed_time))
+    if (engine_hang_checker[0]) and (elapsed_time > time_interval_threshold):
+ error_info["error_code"] = 3
+ error_info["error_msg"] = "infer engine hangs"
+ logger.info("infer engine hangs")
+ return False, error_info
+
+ return True, error_info
+
+
+def start_health_checker(http_port):
+ import sys
+ sys.stdout = open("log/health_http.log", 'w')
+ sys.stderr = sys.stdout
+ uvicorn.run(app=app, host='0.0.0.0', port=http_port, workers=1, log_level="info")
+
+
+# if the infer engine has not updated its heartbeat for longer than the configured
+# check interval, consider it hung or dead
+time_interval_threshold = env_config.check_health_interval
+engine_healthy_recorded_time = np.zeros([1], dtype=float)
+
+shm_engine_healthy_recorded_time = shared_memory.SharedMemory(
+ create=True,
+ size=engine_healthy_recorded_time.nbytes,
+ name=env_config.get_unique_name("engine_healthy_recorded_time"))
+
+engine_ready_check_flag = np.zeros([1], dtype=np.int32)
+shm_engine_ready_check_flag = shared_memory.SharedMemory(
+ create=True,
+ size=engine_ready_check_flag.nbytes,
+ name=env_config.get_unique_name("engine_ready_check_flag"))
diff --git a/llm/server/server/server/utils.py b/llm/server/server/server/utils.py
new file mode 100644
index 000000000000..bb80f6b0a472
--- /dev/null
+++ b/llm/server/server/server/utils.py
@@ -0,0 +1,195 @@
+# Copyright (c) 2024 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import codecs
+import logging
+import os
+import pickle
+import re
+import subprocess
+import time
+from datetime import datetime
+from enum import Enum
+from logging.handlers import BaseRotatingHandler
+from pathlib import Path
+
+
+class DailyRotatingFileHandler(BaseRotatingHandler):
+ """
+ like `logging.TimedRotatingFileHandler`, but this class support multi-process
+ """
+
+ def __init__(
+ self,
+ filename,
+ backupCount=0,
+ encoding="utf-8",
+ delay=False,
+ utc=False,
+ **kwargs
+ ):
+ self.backup_count = backupCount
+ self.utc = utc
+ self.suffix = "%Y-%m-%d"
+ self.base_log_path = Path(filename)
+ self.base_filename = self.base_log_path.name
+ self.current_filename = self._compute_fn()
+ self.current_log_path = self.base_log_path.with_name(self.current_filename)
+ BaseRotatingHandler.__init__(self, filename, "a", encoding, delay)
+
+ def shouldRollover(self, record):
+ """
+        check whether the log file should be rolled over
+ """
+ if self.current_filename != self._compute_fn():
+ return True
+ return False
+
+ def doRollover(self):
+ """
+        roll over to a new log file
+ """
+ if self.stream:
+ self.stream.close()
+ self.stream = None
+
+ self.current_filename = self._compute_fn()
+ self.current_log_path = self.base_log_path.with_name(self.current_filename)
+
+ if not self.delay:
+ self.stream = self._open()
+
+ self.delete_expired_files()
+
+ def _compute_fn(self):
+ """
+        calculate the log file name corresponding to the current time
+ """
+ return self.base_filename + "." + time.strftime(self.suffix, time.localtime())
+
+ def _open(self):
+ """
+ open new log file
+ """
+ if self.encoding is None:
+ stream = open(str(self.current_log_path), self.mode)
+ else:
+ stream = codecs.open(str(self.current_log_path), self.mode, self.encoding)
+
+ if self.base_log_path.exists():
+ try:
+ if (
+ not self.base_log_path.is_symlink()
+ or os.readlink(self.base_log_path) != self.current_filename
+ ):
+ os.remove(self.base_log_path)
+ except OSError:
+ pass
+
+ try:
+ os.symlink(self.current_filename, str(self.base_log_path))
+ except OSError:
+ pass
+ return stream
+
+ def delete_expired_files(self):
+ """
+ delete expired log files
+ """
+ if self.backup_count <= 0:
+ return
+
+ file_names = os.listdir(str(self.base_log_path.parent))
+ result = []
+ prefix = self.base_filename + "."
+ plen = len(prefix)
+ for file_name in file_names:
+ if file_name[:plen] == prefix:
+ suffix = file_name[plen:]
+ if re.match(r"^\d{4}-\d{2}-\d{2}(\.\w+)?$", suffix):
+ result.append(file_name)
+ if len(result) < self.backup_count:
+ result = []
+ else:
+ result.sort()
+ result = result[: len(result) - self.backup_count]
+
+ for file_name in result:
+ os.remove(str(self.base_log_path.with_name(file_name)))
+
+
+def get_logger(name, file_name, without_formater=False):
+ """
+ get logger
+ """
+ log_dir = os.getenv("FD_LOG_DIR", default="log")
+ is_debug = int(os.getenv("FD_DEBUG", default=0))
+ logger = logging.getLogger(name)
+ if is_debug:
+ logger.setLevel(level=logging.DEBUG)
+ else:
+ logger.setLevel(level=logging.INFO)
+
+ LOG_FILE = "{0}/{1}".format(log_dir, file_name)
+ backup_count = int(os.getenv("FD_LOG_BACKUP_COUNT", 7))
+ handler = DailyRotatingFileHandler(LOG_FILE, backupCount=backup_count)
+
+ formatter = logging.Formatter(
+ "%(levelname)-8s %(asctime)s %(process)-5s %(filename)s[line:%(lineno)d] %(message)s"
+ )
+ if not without_formater:
+ handler.setFormatter(formatter)
+ logger.addHandler(handler)
+    logger.propagate = False
+ return logger
+
+
+def str_to_datetime(date_string):
+ """
+ string to datetime class object
+ """
+ if "." in date_string:
+ return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S.%f")
+ else:
+ return datetime.strptime(date_string, "%Y-%m-%d %H:%M:%S")
+
+
+def datetime_diff(datetime_start, datetime_end):
+ """
+    calculate the absolute time difference between two datetimes, in seconds
+
+ Args:
+ datetime_start (Union[str, datetime.datetime]): start time
+ datetime_end (Union[str, datetime.datetime]): end time
+
+ Returns:
+ float: date time difference(s)
+ """
+ if isinstance(datetime_start, str):
+ datetime_start = str_to_datetime(datetime_start)
+ if isinstance(datetime_end, str):
+ datetime_end = str_to_datetime(datetime_end)
+ if datetime_end > datetime_start:
+ cost = datetime_end - datetime_start
+ else:
+ cost = datetime_start - datetime_end
+ return cost.total_seconds()
+
+
+model_server_logger = get_logger("model_server", "infer_server.log")
+http_server_logger = get_logger("http_server", "http_server.log")
+data_processor_logger = get_logger("data_processor", "data_processor.log")
+monitor_logger = get_logger("monitor_logger", "monitor_logger.log", True)
+error_logger = get_logger("error_logger", "error_logger.log", True)