diff --git a/xtuner/chat/__init__.py b/xtuner/chat/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/xtuner/chat/backend/__init__.py b/xtuner/chat/backend/__init__.py new file mode 100644 index 000000000..54351fa29 --- /dev/null +++ b/xtuner/chat/backend/__init__.py @@ -0,0 +1,5 @@ +from .encoder import VisionEncoderForDeploy +from .huggingface import HFBackend +from .lmdeploy import LMDeployBackend + +__all__ = ['VisionEncoderForDeploy', 'HFBackend', 'LMDeployBackend'] diff --git a/xtuner/chat/backend/base.py b/xtuner/chat/backend/base.py new file mode 100644 index 000000000..0a0fd4bbe --- /dev/null +++ b/xtuner/chat/backend/base.py @@ -0,0 +1,26 @@ +from abc import abstractmethod + +from xtuner.types import HybridChatTemplate + + +class BaseBackend(): + + @property + def chat_template(self) -> HybridChatTemplate: + pass + + @abstractmethod + def create_streamer(self, iterable=False): + pass + + @abstractmethod + def chat(self, messages, streamer=None, generation_config=None): + pass + + # @abstractmethod + # def response_with_function_call(self, response: str): + # pass + + # @abstractmethod + # def response_with_code_interpreter(self, response: str): + # pass diff --git a/xtuner/chat/backend/encoder.py b/xtuner/chat/backend/encoder.py new file mode 100644 index 000000000..af05b78df --- /dev/null +++ b/xtuner/chat/backend/encoder.py @@ -0,0 +1,308 @@ +import base64 +import os +from io import BytesIO +from typing import List, Literal, Optional, Union + +import requests +import torch +from peft import PeftModel +from PIL import Image +from torch import nn +from transformers import AutoModel, CLIPImageProcessor, CLIPVisionModel + +from xtuner.dataset.utils import expand2square + + +def load_image_from_base64(image: Union[bytes, str]) -> Image.Image: + """load image from base64 format.""" + return Image.open(BytesIO(base64.b64decode(image))) + + +def load_image(image_url: str) -> Image.Image: + """load image from url, local path or openai GPT4V.""" + + headers = { + 'User-Agent': + 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 ' + '(KHTML, like Gecko) Chrome/58.0.3029.110 Safari/537.3' + } + if image_url.startswith('http'): + response = requests.get(image_url, headers=headers) + response.raise_for_status() + + # Open the image using PIL + img = Image.open(BytesIO(response.content)) + elif image_url.startswith('data:image'): + img = load_image_from_base64(image_url.split(',')[1]) + else: + img = Image.open(image_url) + + return img + + +ModelHub = Literal['huggingface', 'modelscope'] + + +class VisionEncoderForDeploy(nn.Module): + + def __init__(self, + model_name_or_path: str, + projector_name_or_path: str, + adapter_name_or_path: str = None, + select_layer: int = -2, + hub: ModelHub = 'huggingface', + device='cuda'): + + super().__init__() + + # model_path = self._parse_model_path(xtuner_model_name_or_path, hub) + # visual_encoder_path = self._parse_visual_encoder_path( + # model_path, visual_encoder_name_or_path, hub + # ) + # projector_path = self._parse_projector_path(model_path) + + # # parse visual encoder adapter path. 
+ # vis_enc_adapter_path = self._parse_vis_enc_adapter_path(model_path) + + self.select_layer = select_layer + self.image_processor = CLIPImageProcessor.from_pretrained( + model_name_or_path) + print(f'Load Image Processor From {model_name_or_path}') + + visual_encoder = CLIPVisionModel.from_pretrained( + model_name_or_path, torch_dtype=torch.float16) + print(f'Load Visual Encoder From {model_name_or_path}') + + # when path is None, means without visual encoder adapter + if adapter_name_or_path: + self.visual_encoder = PeftModel.from_pretrained( + visual_encoder, adapter_name_or_path) + print(f'Load Visual Encoder Adapter From {adapter_name_or_path}') + else: + self.visual_encoder = visual_encoder + + self.projector = AutoModel.from_pretrained( + projector_name_or_path, + torch_dtype=torch.float16, + trust_remote_code=True) + print(f'Load Projector from {projector_name_or_path}') + + self.dtype = torch.float16 + self.device = device + self.to(self.device) + self.to(self.dtype) + + def process_img(self, image: Image.Image) -> List[torch.Tensor]: + """Preprocess the input image, including expanding to square and + normalization. + + Args: + image (Image.Image): The input image need to be preprocessed. + + Returns: + torch.Tensor: The preprocessed image tensor. + """ + + if isinstance(image, str): + image = load_image(image) + + if not isinstance(image, Image.Image): + raise TypeError(f"Don't support {type(image).__name__}, " + 'the image type must be `PIL.Image`.') + + processor = self.image_processor + image_mean = processor.image_mean + + background_color = tuple(int(x * 255) for x in image_mean) + squared_img = expand2square(image, background_color) + + processed = processor.preprocess(squared_img, return_tensors='pt') + img_tensor = processed['pixel_values'][0] # shape: 3, h, w + + # before this line, `img_tensor` is on cpu. + img_tensor = img_tensor.to(self.device).to(self.dtype) + return img_tensor + + @torch.no_grad() + def forward(self, images: List[Union[str, + Image.Image]]) -> List[torch.Tensor]: + """Obtain the corresponding embeddings based on the images. + + Args: + images (List[Image.Image]): The input images. The data layout + for each image is (c, h, w). + + Returns: + List[torch.Tensor]: The list of extracted features from images. + The data layout for each tensor should be (tokens, dims). + """ + + num_imgs = len(images) + + img_tensors = [self.process_img(img) for img in images] + + # Determine if all image sizes are consistent. + # TODO (pppppM): Confirm when the image size will be inconsistent + shape_consistant = all(x.shape == img_tensors[0].shape + for x in img_tensors) + + from transformers.modeling_outputs import BaseModelOutputWithPooling + + if shape_consistant: + # Batch inference when all image sizes are consistent. + # img_tensors[0] shape: (3, h, w) + # tensor shape: (num_imgs, 3, h, w) + tensor = torch.stack(img_tensors, dim=0) + + enc_out = self.visual_encoder(tensor, output_hidden_states=True) + enc_out: BaseModelOutputWithPooling + + # feat shape: (num_imgs, tokens, dims) + feat = self.projector(enc_out.hidden_states[self.select_layer][:, + 1:]) + + # Split along the batch dimension + # The feature of each image corresponds to a tensor. 
+ # len(features): num_imgs, features[0] shape:(1, tokens, dims) + features = torch.chunk(feat, num_imgs, dim=0) + + # per image feature's layout should be (tokens, dims) + features = [x.flatten(0, 1) for x in features] + + else: + features = [] + for tensor in img_tensors: + tensor: torch.Tensor + # The visual encoder requires a data layout of (bs, c, h, w). + # tensor shape: (3, h, w) batch_tensor shape: (1, 3, h, w) + batch_tensor = tensor.unsqueeze(0) + enc_out = self.visual_encoder( + batch_tensor, output_hidden_states=True) + enc_out: BaseModelOutputWithPooling + # feat shape: (1, tokens, dims) + feat = self.projector( + enc_out.hidden_states[self.select_layer][:, 1:]) + features.append(feat) + + return features + + def _parse_model_path(self, name_or_path: str, hub: ModelHub) -> str: + """Parse and get the directory path of the model. It supports load + model from local directory or download from the hub. + + Args: + name_or_path (str): The directory path or name of the model. + hub (str): The hub to download models from. + + Returns: + str: The local directory path of the model. + + Raises: + NotImplementedError: If the input hub is not supported currently. + """ + + if os.path.isdir(name_or_path): + model_path = name_or_path + else: + if hub == 'huggingface': + from huggingface_hub import snapshot_download + model_path = snapshot_download(repo_id=name_or_path) + elif hub == 'modelscope': + from modelscope import snapshot_download + model_path = snapshot_download(model_id=name_or_path) + else: + raise NotImplementedError( + 'Only supports downloading models from `Huggingface` or ' + '`Modelscope`.') + + return model_path + + def _parse_visual_encoder_path(self, model_path: str, + visual_encoder_name_or_path: str, + hub: ModelHub) -> str: + """Parse and get the directory path of the visual encoder. It supports + load visual encoder from local directory, download from the hub, or + find it in the XTuner model directory. + + Args: + model_path (str): The directory path of the model. + visual_encoder_name_or_path (Optional[str]): The directory path or + name of the visual encoder. + hub (str): The hub to download models from. + + Returns: + str: The local directory path of the visual encoder. + + Raises: + NotImplementedError: If the input hub is not supported currently. + """ + + if 'visual_encoder' in os.listdir(model_path): + assert visual_encoder_name_or_path is None + visual_encoder_path = os.path.join(model_path, 'visual_encoder') + elif os.path.isdir(visual_encoder_name_or_path): + visual_encoder_path = visual_encoder_name_or_path + else: + if hub == 'huggingface': + from huggingface_hub import snapshot_download + visual_encoder_path = snapshot_download( + repo_id=visual_encoder_name_or_path) + elif hub == 'modelscope': + from modelscope import snapshot_download + visual_encoder_path = snapshot_download( + model_id=visual_encoder_name_or_path) + else: + raise NotImplementedError( + 'Only supports downloading models from `Huggingface` or ' + '`Modelscope`.') + + return visual_encoder_path + + def _parse_projector_path(self, model_path: str) -> Optional[str]: + """Parse the path of the `projector` model according to the model path. + + Args: + model_path (str): The path to the model directory. + + Raises: + ValueError: If the 'projector' directory is not found in the + `model_path`. + + Returns: + Optional[str]: The full path of 'projector' directory if exists, + else raises ValueError. 
+ """ + if 'projector' in os.listdir(model_path): + projector_path = os.path.join(model_path, 'projector') + else: + # Raises exception if 'projector' directory/folder not found + raise ValueError('Projector directory not found in given path') + return projector_path + + def _parse_vis_enc_adapter_path(self, model_path: str) -> Optional[str]: + """Parses the model path and returns the path to + 'visual_encoder_adapter' directory. + + Args: + model_path (str): The path to the model directory. + + Returns: + Optional[str]: The full path of 'visual_encoder_adapter' directory if exists, + else returns None. + """ + if 'visual_encoder_adapter' in os.listdir(model_path): + adapter_path = os.path.join(model_path, 'visual_encoder_adapter') + else: + # Returns None if 'visual_encoder_adapter' directory/folder not found + adapter_path = None + return adapter_path + + +if __name__ == '__main__': + img = load_image('llava.jpeg') + model = VisionEncoderForDeploy('xtuner/llava-internlm-7b', + 'openai/clip-vit-large-patch14-336') + + model.cuda() + model.eval() + outputs = model([img]) diff --git a/xtuner/chat/backend/huggingface.py b/xtuner/chat/backend/huggingface.py new file mode 100644 index 000000000..51e742327 --- /dev/null +++ b/xtuner/chat/backend/huggingface.py @@ -0,0 +1,224 @@ +from typing import Optional + +import torch +from peft import PeftModel +from transformers import (AutoModelForCausalLM, AutoTokenizer, + BitsAndBytesConfig) +from transformers import GenerationConfig as HFGenerationConfig +from transformers import PreTrainedModel, PreTrainedTokenizer + +from xtuner.chat.streamer import HFTextIteratorStreamer, HFTextStreamer +from xtuner.model.utils import LoadWoInit +from xtuner.tools.utils import get_stop_criteria +from xtuner.types import HybridChatMessages, HybridChatTemplate, SampleParams +from .base import BaseBackend + + +class _HFBackend(BaseBackend): + + def __init__( + self, + chat_template: HybridChatTemplate, + llm: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + ) -> None: + super().__init__() + + self.llm = llm + self.llm.cuda() + self.tokenizer = tokenizer + + self._chat_template = chat_template + + @property + def chat_template(self) -> HybridChatTemplate: + return self._chat_template + + @property + def eos_token_id(self): + if self.tokenizer.pad_token_id: + return self.tokenizer.eos_token_id + else: + return self.tokenizer.eos_token_id + + @property + def pad_token_id(self): + return self.tokenizer.pad_token_id + + def build_llm_and_tokenizer(self, + model_name_or_path, + adapter=None, + bits=None): + + if bits is None: + quantization_config = None + load_in_8bit = False + elif bits == 4: + quantization_config = BitsAndBytesConfig( + load_in_4bit=True, + load_in_8bit=False, + llm_int8_threshold=6.0, + llm_int8_has_fp16_weight=False, + bnb_4bit_compute_dtype=torch.float16, + bnb_4bit_use_double_quant=True, + bnb_4bit_quant_type='nf4') + load_in_8bit = False + elif bits == 8: + quantization_config = None + load_in_8bit = True + + tokenizer = AutoTokenizer.from_pretrained( + model_name_or_path, + trust_remote_code=True, + encode_special_tokens=True) + + with LoadWoInit(): + model = AutoModelForCausalLM.from_pretrained( + model_name_or_path, + device_map='auto', + load_in_8bit=load_in_8bit, + quantization_config=quantization_config, + trust_remote_code=True, + torch_dtype=torch.float16) + + if adapter is not None: + model = PeftModel.from_pretrained(model, adapter) + + model.eval() + return model, tokenizer + + def response_with_code_interpreter(self, response: str): 
+ return False + + def response_with_function_call(self, response: str): + return False + + def create_streamer(self, chat_template=None, iterable=False): + if iterable: + return HFTextIteratorStreamer( + self.tokenizer, skip_prompt=True, chat_template=chat_template) + else: + return HFTextStreamer( + self.tokenizer, skip_prompt=True, chat_template=chat_template) + + def parse_sample_params(self, params: SampleParams) -> HFGenerationConfig: + + if params is None: + params = SampleParams() + + hf_gen_config = HFGenerationConfig( + max_new_tokens=params.max_new_tokens, + do_sample=params.temperature > 0, + temperature=params.temperature, + top_k=params.top_k, + top_p=params.top_p, + repetition_penalty=params.repetition_penalty, + seed=params.seed, + eos_token_id=self.eos_token_id, + pad_token_id=self.pad_token_id) + + stop_words = params.stop_words + stop_words.extend(self.chat_template.stop_words) + + return hf_gen_config, stop_words + + def chat(self, + messages: HybridChatMessages, + streamer=None, + sample_params: Optional[SampleParams] = None): + + prompt = messages.apply_chat_template(self.chat_template) + ids = self.tokenizer.encode(prompt, return_tensors='pt') + + hf_gen_config, stop_words = self.parse_sample_params(sample_params) + + stop_criteria = get_stop_criteria( + tokenizer=self.tokenizer, stop_words=stop_words) + + generate_output = self.llm.generate( + inputs=ids.cuda(), + streamer=streamer, + generation_config=hf_gen_config, + stopping_criteria=stop_criteria) + + output = self.tokenizer.decode( + generate_output[0][len(ids[0]):], skip_special_tokens=True) + + for word in stop_words: + output = output.rstrip(word) + + return output + + +class HFBackend(_HFBackend): + + def __init__( + self, + chat_template: HybridChatTemplate, + llm: PreTrainedModel, + tokenizer: PreTrainedTokenizer, + vision_tower: Optional[torch.nn.Module] = None, + ) -> None: + super().__init__(chat_template, llm, tokenizer) + + if vision_tower: + self.vision_tower = vision_tower + self.vision_tower.cuda() + self.vision_tower.eval() + else: + self.vision_tower = None + + def chat(self, + messages: HybridChatMessages, + streamer=None, + sample_params=None): + + img_urls = messages.collect_img_urls() + + if self.vision_tower is None or len(img_urls) == 0: + return super().chat(messages, streamer, sample_params) + + prompt = messages.apply_chat_template(self.chat_template) + + img_features = self.vision_tower(img_urls) + + # prompt, img_ranges = _insert_img_pad_tokens( + # prompt, self.chat_template.image_token, img_features, + # self.tokenizer.pad_token + # ) + + chunks = prompt.split(self.chat_template.image_token) + assert len(chunks) - 1 == len(img_urls) + chunk_embeddings = [] + for i in range(len(chunks)): + + chunk_ids = self.tokenizer.encode(chunks[i], return_tensors='pt') + chunk_ids = chunk_ids.to(self.llm.device) + chunk_emb = self.llm.get_input_embeddings()(chunk_ids) + chunk_embeddings.append(chunk_emb) + + if i < len(chunks) - 1: + chunk_embeddings.append(img_features[i].unsqueeze(0)) + + embeddings = torch.cat(chunk_embeddings, dim=1) + + hf_gen_config, stop_words = self.parse_sample_params(sample_params) + + stop_criteria = get_stop_criteria( + tokenizer=self.tokenizer, stop_words=stop_words) + + generate_output = self.llm.generate( + input_ids=None, + inputs_embeds=embeddings, + streamer=streamer, + generation_config=hf_gen_config, + bos_token_id=self.tokenizer.bos_token_id, + stopping_criteria=stop_criteria) + + output = self.tokenizer.decode( + generate_output[0], skip_special_tokens=True) 
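+        # Strip any trailing stop-word text the model appended to the decoded output.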
+ + for word in stop_words: + output = output.rstrip(word) + + return output diff --git a/xtuner/chat/backend/lmdeploy/__init__.py b/xtuner/chat/backend/lmdeploy/__init__.py new file mode 100644 index 000000000..139c066fb --- /dev/null +++ b/xtuner/chat/backend/lmdeploy/__init__.py @@ -0,0 +1,3 @@ +from .backend import LMDeployBackend + +__all__ = ['LMDeployBackend'] diff --git a/xtuner/chat/backend/lmdeploy/_encoder.py b/xtuner/chat/backend/lmdeploy/_encoder.py new file mode 100644 index 000000000..3466eb30f --- /dev/null +++ b/xtuner/chat/backend/lmdeploy/_encoder.py @@ -0,0 +1,122 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import asyncio +import queue +import time +from threading import Thread +from typing import List, Union + +import torch +from lmdeploy.utils import get_logger +from PIL.Image import Image + +logger = get_logger('lmdeploy') + + +class Record: + """Batching manager.""" + + def __init__(self): + self.number = [] + self.waiting = [] + self.done = [] + self.res_que = [] + self.total = 0 + + def enqueue(self, images: List[Image], que: Union[queue.Queue, + asyncio.Queue]): + """add ith request to manager.""" + self.number.append(len(images)) + self.waiting.extend(images) + self.res_que.append(que) + self.total += len(images) + self.log('received', len(images)) + + def dequeue(self, max_batch_size): + """try to dequeue max batch size images.""" + inputs = self.waiting[:max_batch_size] + self.waiting = self.waiting[max_batch_size:] + self.total -= len(inputs) + self.log('process', len(inputs)) + return inputs + + def nofify(self): + """set result if request i is finished.""" + if len(self.number) == 0 or self.number[0] > len(self.done): + return False + num_images = self.number.pop(0) + outputs = self.done[:num_images] + self.done = self.done[num_images:] + que = self.res_que.pop(0) + if isinstance(que, queue.Queue): + que.put(outputs) + else: + que._loop.call_soon_threadsafe(que.put_nowait, outputs) + self.log('done', num_images) + return True + + def log(self, task: str, num: int): + logger.info(f'ImageEncoder {task} {num} images, ' + f'left {self.total} images.') + + +class _AsyncEncoderWrapper: + """Image encoder.""" + + def __init__(self, model, max_batch_size: int = 16): + self.model = model + self.max_batch_size = max_batch_size + self.loop = asyncio.new_event_loop() + self.work_thread = self._start_work_thread() + torch.cuda.empty_cache() + + def _start_work_thread(self): + """internal thread.""" + + def _work_thread(): + asyncio.set_event_loop(self.loop) + self.que = asyncio.Queue() + self.loop.run_until_complete(self._forward_loop()) + + thread = Thread(target=_work_thread, daemon=True) + thread.start() + return thread + + async def _forward_loop(self): + """working loop to process images.""" + logger.info('start ImageEncoder._forward_loop') + record = Record() + while True: + while record.total == 0 or (self.que.qsize() and + record.total < self.max_batch_size): + item = await self.que.get() + record.enqueue(item[0], item[1]) + inputs = record.dequeue(self.max_batch_size) + outputs = self.forward(inputs) + record.done.extend(outputs) + while record.nofify(): + pass + + def forward(self, inputs: List[Image]): + """Model forward.""" + time_start = time.perf_counter() + outputs = self.model.forward(inputs) + time_end = time.perf_counter() + logger.info(f'ImageEncoder forward {len(inputs)} images, ' + f'cost {time_end - time_start:.3f}s') + return outputs + + def infer(self, inputs: List[Image]): + """infer.""" + outputs = queue.Queue() + item = (inputs, 
outputs) + self.loop.call_soon_threadsafe(self.que.put_nowait, item) + results = outputs.get() + return results + + async def async_infer(self, inputs: List[Image]): + """async infer.""" + outputs = asyncio.Queue() + item = (inputs, outputs) + self.loop.call_soon_threadsafe(self.que.put_nowait, item) + results = await outputs.get() + return results diff --git a/xtuner/chat/backend/lmdeploy/_engine.py b/xtuner/chat/backend/lmdeploy/_engine.py new file mode 100644 index 000000000..d81d30c6c --- /dev/null +++ b/xtuner/chat/backend/lmdeploy/_engine.py @@ -0,0 +1,87 @@ +# Copyright (c) OpenMMLab. All rights reserved. +import numpy as np +from lmdeploy.serve.async_engine import AsyncEngine +from lmdeploy.vl.constants import IMAGE_DUMMY_TOKEN_INDEX + +from xtuner.types import HybridChatMessages, HybridChatTemplate + + +class _MMAsyncEngine(AsyncEngine): + """Visual Language Async inference engine.""" + + def __init__(self, + chat_template: HybridChatTemplate, + *args, + encoder=None, + **kwargs) -> None: + super().__init__(*args, **kwargs) + assert self.model_name == 'base' + self.encoder = encoder + self.chat_template = chat_template + + async def _get_prompt_input(self, prompt: HybridChatMessages, + do_preprocess: bool, sequence_start: bool): + """get input_ids, embeddings and offsets.""" + + decorated = prompt.apply_chat_template(self.chat_template) + segs = decorated.split(self.chat_template.image_token) + + results = {} + input_ids = [] + if len(segs) > 1: + assert self.encoder is not None + img_urls = prompt.collect_img_urls() + features = await self.encoder.async_infer(img_urls) + features = [x.cpu().numpy() for x in features] + input_ids = [] + begins = [] + ends = [] + for i, seg in enumerate(segs): + if i > 0: + image_dim = features[i - 1].shape[0] + begins.append(len(input_ids)) + ends.append(begins[-1] + image_dim) + input_ids.extend([IMAGE_DUMMY_TOKEN_INDEX] * image_dim) + seg_ids = self.tokenizer.encode( + seg, add_bos=((i == 0) and sequence_start)) + input_ids.extend(seg_ids) + ranges = np.stack([begins, ends], axis=1).tolist() + results['input_embeddings'] = features + results['input_embedding_ranges'] = ranges + else: + input_ids = self.tokenizer.encode( + decorated, add_bos=sequence_start) + + results['input_ids'] = input_ids + results['prompt'] = decorated + return results + + # def batch_infer(self, prompts: Union[VLPromptType, List[Dict], + # List[VLPromptType], List[List[Dict]]], + # **kwargs): + # """Inference a batch of prompts.""" + # # prompts = self._convert_prompts(prompts) + # return super().batch_infer(prompts, **kwargs) + + # def stream_infer(self, prompts: Union[VLPromptType, List[Dict], + # List[VLPromptType], + # List[List[Dict]]], **kwargs): + # """Inference a batch of prompts with stream mode.""" + # # prompts = self._convert_prompts(prompts) + # return super().stream_infer(prompts, **kwargs) + + # def __call__(self, prompts, **kwargs): + # """Inference a batch of prompts.""" + # # prompts = self._convert_prompts(prompts) + # return super().__call__(prompts, **kwargs) + + # def chat(self, prompts: VLPromptType, **kwargs): + # """chat.""" + # # _prompts = self._convert_prompts(prompts) + # sess = super().chat(_prompts, **kwargs) + + # # recover prompts & history + # sess._prompt = prompts + # last_round = sess.history[-1] + # sess.history[-1] = (prompts, last_round[-1]) + # return sess diff --git a/xtuner/chat/backend/lmdeploy/backend.py b/xtuner/chat/backend/lmdeploy/backend.py new file mode 100644 index 000000000..1df25fe81 --- /dev/null +++ 
b/xtuner/chat/backend/lmdeploy/backend.py @@ -0,0 +1,107 @@ +import asyncio +import os +from typing import List, Optional, Union + +from lmdeploy.utils import get_logger + +from xtuner.types import HybridChatMessages, HybridChatTemplate, SampleParams +from ...streamer import LMDeployTextIteratorStreamer, LMDeployTextStreamer +from ..base import BaseBackend +from ._encoder import _AsyncEncoderWrapper +from ._engine import _MMAsyncEngine + +os.environ['TM_LOG_LEVEL'] = 'ERROR' +logger = get_logger('lmdeploy') +logger.setLevel('ERROR') + +_StreamerType = Union[LMDeployTextStreamer, LMDeployTextIteratorStreamer] + + +class LMDeployBackend(BaseBackend): + + def __init__(self, + chat_template, + llm_name_or_path, + vision_encoder=None) -> None: + super().__init__() + + if vision_encoder: + encoder = _AsyncEncoderWrapper(vision_encoder) + else: + encoder = None + + self._engine = _MMAsyncEngine( + chat_template, + encoder=encoder, + model_path=llm_name_or_path, + model_name='base') + + self._chat_template = chat_template + + @property + def chat_template(self) -> HybridChatTemplate: + return self._chat_template + + def create_streamer(self, iterable=False): + + if iterable: + return LMDeployTextIteratorStreamer() + else: + return LMDeployTextStreamer() + + def parse_sample_params(self, + params: SampleParams) -> 'LMGenerationConfig': + + if params is None: + params = SampleParams() + + stop_words = params.stop_words + stop_words.extend(self.chat_template.stop_words) + + from lmdeploy.messages import GenerationConfig as LMDGenerationConfig + lmd_gen_config = LMDGenerationConfig( + max_new_tokens=params.max_new_tokens, + temperature=params.temperature, + top_k=params.top_k, + top_p=params.top_p, + repetition_penalty=params.repetition_penalty, + random_seed=params.seed, + stop_words=stop_words) + + return lmd_gen_config + + def chat(self, + messages: HybridChatMessages, + streamer: Optional[_StreamerType] = None, + sample_params: Optional[SampleParams] = None): + + lmd_gen_config = self.parse_sample_params(sample_params) + self.session_id += 1 + import random + + generator = self._engine.generate( + messages, random.randint(1, 100000), gen_config=lmd_gen_config) + + async def get_response(): + out = '' + async for res in generator: + out += res.response + if streamer: + streamer.put(res.response) + if streamer: + streamer.end() + return out + + loop = asyncio.new_event_loop() + response = loop.run_until_complete(get_response()) + return response + + def batch_infer(self, + messages: List[HybridChatMessages], + sample_params: Optional[SampleParams] = None): + + lmd_gen_config = self.parse_sample_params(sample_params) + + results = self._engine.batch_infer(messages, gen_config=lmd_gen_config) + + return [r.text for r in results] diff --git a/xtuner/chat/conversation.py b/xtuner/chat/conversation.py new file mode 100644 index 000000000..a26616221 --- /dev/null +++ b/xtuner/chat/conversation.py @@ -0,0 +1,147 @@ +from xtuner.chat.backend import HFBackend +from xtuner.types.chat import (ChatMsg, HybridChatMessages, ImageContentItem, + TextContentItem) + + +class Conversation(): + + def __init__(self, + backend: HFBackend, + name=None, + system=None, + functions=None, + code_interpreter=None) -> None: + + self.name = name + self.backend = backend + self.system = system + self.functions = functions + self.code_interpreter = code_interpreter + self._messages = HybridChatMessages() + + if system: + msg = ChatMsg(role='system', content=system) + self._messages.messages.append(msg) + + @property + def 
messages(self):
+        return self._messages
+
+    def add_message(self, role, content):
+        if role == 'system':
+            assert isinstance(content, str)
+            msg = ChatMsg(role='system', content=content)
+            self._messages.messages.append(msg)
+        elif role == 'user':
+            self._add_user(content)
+        elif role == 'assistant':
+            assert isinstance(content, str)
+            msg = ChatMsg(role='assistant', content=content)
+            self._messages.messages.append(msg)
+
+    def _add_user(self, content):
+
+        if isinstance(content, str):
+            msg = ChatMsg(role='user', content=content)
+            self._messages.messages.append(msg)
+        elif isinstance(content, list):
+            _content = []
+            for item in content:
+                if isinstance(item, (ImageContentItem, TextContentItem)):
+                    _content.append(item)
+                    continue
+
+                assert isinstance(item, dict)
+                assert 'type' in item
+                assert item['type'] in ('image_url', 'text')
+                if item['type'] == 'image_url':
+                    _item = ImageContentItem(image_url=item['image_url'])
+                    _content.append(_item)
+                elif item['type'] == 'text':
+                    _item = TextContentItem(text=item['text'])
+                    _content.append(_item)
+                else:
+                    raise NotImplementedError
+
+            msg = ChatMsg(role='user', content=_content)
+            self._messages.messages.append(msg)
+        else:
+            raise TypeError
+
+    def chat(self, content, sample_params=None, streamer=None):
+
+        self.add_message(role='user', content=content)
+        response = self.backend.chat(
+            self.messages, streamer=streamer, sample_params=sample_params)
+        self.add_message(role='assistant', content=response)
+        return response
+
+    def regenerate(self):
+
+        assert self._messages.messages[-1].role == 'assistant'
+        self._messages.messages.pop()
+        return self.backend.chat(self.messages)
+
+    def create_streamer(self, iterable=False):
+        return self.backend.create_streamer(iterable=iterable)
+
+
+if __name__ == '__main__':
+
+    from xtuner.types import HybridChatTemplate
+    chat_template = HybridChatTemplate(
+        system='<|im_start|>system\n{system}<|im_end|>\n',
+        user='<|im_start|>user\n{user}<|im_end|>\n<|im_start|>assistant\n',
+        assistant='{assistant}<|im_end|>\n',
+        stop_words=['<|im_end|>'],
+        image_token='<image>',
+        function_call=
+        '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n',  # noqa: E501, E251
+        function_result=
+        '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n',  # noqa: E501, E251
+        functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n'
+    )
+
+    from transformers import AutoModelForCausalLM, AutoTokenizer
+
+    from xtuner.chat.backend import HFBackend, VisionEncoderForDeploy
+
+    llm = AutoModelForCausalLM.from_pretrained(
+        '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f',
+        trust_remote_code=True)
+    tokenizer = AutoTokenizer.from_pretrained(
+        '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f',
+        trust_remote_code=True)
+    vision_tower = VisionEncoderForDeploy(
+        model_name_or_path='openai/clip-vit-large-patch14-336',
+        adapter_name_or_path=
+        '/mnt/petrelfs/share_data/linzhihao/model/models--xtuner--llava-internlm2-7b/snapshots/f363b45ce4787bd0a21d43ed724a70ee40ce69b2/visual_encoder_adapter',
+        projector_name_or_path=
+        '/mnt/petrelfs/share_data/linzhihao/model/models--xtuner--llava-internlm2-7b/snapshots/f363b45ce4787bd0a21d43ed724a70ee40ce69b2/projector'
+    )
+
+    llm.cuda()
+
+    backend = HFBackend(
+        chat_template,
+        llm,
+        tokenizer,
+        vision_tower,
+    )
+
+    conv = Conversation(backend)
+    print(conv.chat('who are you?'))
+
+    from xtuner.chat.backend import 
LMDeployBackend + backend = LMDeployBackend( + chat_template, + '/mnt/petrelfs/share_data/linzhihao/model/models--internlm--internlm2-chat-7b/snapshots/2292b86b21cb856642782cebed0a453997453b1f', + vision_tower) + conv = Conversation(backend) + print(conv.chat('who are you?')) + + content = [ + TextContentItem(text='Please describe this image'), + ImageContentItem(image_url='llava.jpeg') + ] + + print(conv.chat(content)) diff --git a/xtuner/chat/streamer/__init__.py b/xtuner/chat/streamer/__init__.py new file mode 100644 index 000000000..7f83155fc --- /dev/null +++ b/xtuner/chat/streamer/__init__.py @@ -0,0 +1,7 @@ +from .huggingface import HFTextIteratorStreamer, HFTextStreamer +from .lmdeploy import LMDeployTextIteratorStreamer, LMDeployTextStreamer + +__all__ = [ + 'HFTextIteratorStreamer', 'HFTextStreamer', 'LMDeployTextIteratorStreamer', + 'LMDeployTextStreamer' +] diff --git a/xtuner/chat/streamer/huggingface.py b/xtuner/chat/streamer/huggingface.py new file mode 100644 index 000000000..91b0f29aa --- /dev/null +++ b/xtuner/chat/streamer/huggingface.py @@ -0,0 +1,37 @@ +from transformers import TextIteratorStreamer, TextStreamer +from transformers.models.auto import AutoTokenizer + + +class HFTextIteratorStreamer(TextIteratorStreamer): + + def __init__(self, + tokenizer: AutoTokenizer, + skip_prompt: bool = False, + timeout=None, + chat_template=None, + **decode_kwargs): + super().__init__(tokenizer, skip_prompt, timeout, **decode_kwargs) + self.chat_template = chat_template + + def on_finalized_text(self, text: str, stream_end: bool = False): + + for word in self.chat_template.stop_words: + text = text.rstrip(word) + super().on_finalized_text(text, stream_end) + + +class HFTextStreamer(TextStreamer): + + def __init__(self, + tokenizer: AutoTokenizer, + skip_prompt: bool = False, + chat_template=None, + **decode_kwargs): + super().__init__(tokenizer, skip_prompt, **decode_kwargs) + self.chat_template = chat_template + + def on_finalized_text(self, text: str, stream_end: bool = False): + + for word in self.chat_template.stop_words: + text = text.rstrip(word) + super().on_finalized_text(text, stream_end) diff --git a/xtuner/chat/streamer/lmdeploy.py b/xtuner/chat/streamer/lmdeploy.py new file mode 100644 index 000000000..2ec03e482 --- /dev/null +++ b/xtuner/chat/streamer/lmdeploy.py @@ -0,0 +1,49 @@ +from queue import Queue +from typing import Optional + +from transformers.generation.streamers import BaseStreamer + + +class LMDeployTextStreamer(BaseStreamer): + + def put(self, text): + self.on_finalized_text(text) + + def end(self): + """Flushes any remaining cache and prints a newline to stdout.""" + self.on_finalized_text('', stream_end=True) + + def on_finalized_text(self, text: str, stream_end: bool = False): + """Prints the new text to stdout. + + If the stream is ending, also prints a newline. + """ + print(text, flush=True, end='' if not stream_end else None) + + +class LMDeployTextIteratorStreamer(LMDeployTextStreamer): + + def __init__(self, timeout: Optional[float] = None): + super().__init__() + self.text_queue = Queue() + self.stop_signal = None + self.timeout = timeout + + def on_finalized_text(self, text: str, stream_end: bool = False): + """Put the new text in the queue. + + If the stream is ending, also put a stop signal in the queue. 
+ """ + self.text_queue.put(text, timeout=self.timeout) + if stream_end: + self.text_queue.put(self.stop_signal, timeout=self.timeout) + + def __iter__(self): + return self + + def __next__(self): + value = self.text_queue.get(timeout=self.timeout) + if value == self.stop_signal: + raise StopIteration() + else: + return value diff --git a/xtuner/dataset/hybrid/dataset.py b/xtuner/dataset/hybrid/dataset.py index e8f127fc6..b2699e048 100644 --- a/xtuner/dataset/hybrid/dataset.py +++ b/xtuner/dataset/hybrid/dataset.py @@ -287,7 +287,6 @@ def img_sample_counter(item): def img_counter(item): return len(item['image_urls']) - with ThreadPoolExecutor(max_workers=self.num_workers) as executor: images = list( tqdm( @@ -403,8 +402,10 @@ def __getitem__(self, item: int) -> Dict[str, List]: assistant='{assistant}<|im_end|>\n', stop_words=['<|im_end|>'], image_token='', - function_call='{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 - function_result='<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 + function_call= + '{assistant}<|action_start|><|plugin|>\n{function_call}<|action_end|><|im_end|>\n', # noqa: E501, E251 + function_result= + '<|im_start|>environment name=<|plugin|>\n{function_result}<|im_end|>\n<|im_start|>assistant\n', # noqa: E501, E251 functions='<|im_start|>system name=<|plugin|>\n{functions}<|im_end|>\n' ) diff --git a/xtuner/types/__init__.py b/xtuner/types/__init__.py index cc230e8f8..79ea745af 100644 --- a/xtuner/types/__init__.py +++ b/xtuner/types/__init__.py @@ -1,6 +1,11 @@ +from .chat import (ChatMsg, HybridChatMessages, ImageContentItem, + TextContentItem) from .chat_template import HybridChatTemplate +from .sample_params import SampleParams from .train import RawTrainingData, TrainingHybridChatMessages __all__ = [ - 'HybridChatTemplate', 'RawTrainingData', 'TrainingHybridChatMessages' + 'ChatMsg', 'HybridChatMessages', 'ImageContentItem', 'TextContentItem', + 'HybridChatTemplate', 'SampleParams', 'RawTrainingData', + 'TrainingHybridChatMessages' ] diff --git a/xtuner/types/chat.py b/xtuner/types/chat.py index 74ac5e30e..c5d67d1e5 100644 --- a/xtuner/types/chat.py +++ b/xtuner/types/chat.py @@ -6,7 +6,7 @@ class TextContentItem(BaseModel): - type: Literal['text'] + type: Literal['text'] = 'text' text: str def format_content(self, chat_template: HybridChatTemplate) -> str: @@ -14,7 +14,7 @@ def format_content(self, chat_template: HybridChatTemplate) -> str: class ImageContentItem(BaseModel): - type: Literal['image_url'] + type: Literal['image_url'] = 'image_url' image_url: str def format_content(self, chat_template: HybridChatTemplate) -> str: diff --git a/xtuner/types/sample_params.py b/xtuner/types/sample_params.py new file mode 100644 index 000000000..137809648 --- /dev/null +++ b/xtuner/types/sample_params.py @@ -0,0 +1,14 @@ +from typing import Optional + +from pydantic import BaseModel + + +class SampleParams(BaseModel): + + max_new_tokens: int = 512 + temperature: float = 0.1 + top_k: int = 40 + top_p: float = 0.75 + repetition_penalty: float = 1.0 + stop_words: list = [] + seed: Optional[int] = None