diff --git a/apps/language_models/langchain/gen.py b/apps/language_models/langchain/gen.py index 5eb7ad670b..c23c4b3236 100644 --- a/apps/language_models/langchain/gen.py +++ b/apps/language_models/langchain/gen.py @@ -1129,7 +1129,7 @@ def evaluate( max_time=max_time, num_return_sequences=num_return_sequences, ) - for r in run_qa_db( + outr, extra = run_qa_db( query=instruction, iinput=iinput, context=context, @@ -1170,689 +1170,15 @@ def evaluate( auto_reduce_chunks=auto_reduce_chunks, max_chunks=max_chunks, device=self.device, - ): - ( - outr, - extra, - ) = r # doesn't accumulate, new answer every yield, so only save that full answer - yield dict(response=outr, sources=extra) - if save_dir: - extra_dict = gen_hyper_langchain.copy() - extra_dict.update( - prompt_type=prompt_type, - inference_server=inference_server, - langchain_mode=langchain_mode, - langchain_action=langchain_action, - document_choice=document_choice, - num_prompt_tokens=num_prompt_tokens, - instruction=instruction, - iinput=iinput, - context=context, - ) - save_generate_output( - prompt=prompt, - output=outr, - base_model=base_model, - save_dir=save_dir, - where_from="run_qa_db", - extra_dict=extra_dict, - ) - if verbose: - print( - "Post-Generate Langchain: %s decoded_output: %s" - % (str(datetime.now()), len(outr) if outr else -1), - flush=True, - ) + ) + response = dict(response=outr, sources=extra) if outr or base_model in non_hf_types: # if got no response (e.g. not showing sources and got no sources, # so nothing to give to LLM), then slip through and ask LLM # Or if llama/gptj, then just return since they had no response and can't go down below code path # clear before return, since .then() never done if from API clear_torch_cache() - return - - if inference_server.startswith( - "openai" - ) or inference_server.startswith("http"): - if inference_server.startswith("openai"): - import openai - - where_from = "openai_client" - - openai.api_key = os.getenv("OPENAI_API_KEY") - stop_sequences = list( - set(prompter.terminate_response + [prompter.PreResponse]) - ) - stop_sequences = [x for x in stop_sequences if x] - # OpenAI will complain if ask for too many new tokens, takes it as min in some sense, wrongly so. 
- max_new_tokens_openai = min( - max_new_tokens, model_max_length - num_prompt_tokens - ) - gen_server_kwargs = dict( - temperature=temperature if do_sample else 0, - max_tokens=max_new_tokens_openai, - top_p=top_p if do_sample else 1, - frequency_penalty=0, - n=num_return_sequences, - presence_penalty=1.07 - - repetition_penalty - + 0.6, # so good default - ) - if inference_server == "openai": - response = openai.Completion.create( - model=base_model, - prompt=prompt, - **gen_server_kwargs, - stop=stop_sequences, - stream=stream_output, - ) - if not stream_output: - text = response["choices"][0]["text"] - yield dict( - response=prompter.get_response( - prompt + text, - prompt=prompt, - sanitize_bot_response=sanitize_bot_response, - ), - sources="", - ) - else: - collected_events = [] - text = "" - for event in response: - collected_events.append( - event - ) # save the event response - event_text = event["choices"][0][ - "text" - ] # extract the text - text += event_text # append the text - yield dict( - response=prompter.get_response( - prompt + text, - prompt=prompt, - sanitize_bot_response=sanitize_bot_response, - ), - sources="", - ) - elif inference_server == "openai_chat": - response = openai.ChatCompletion.create( - model=base_model, - messages=[ - { - "role": "system", - "content": "You are a helpful assistant.", - }, - { - "role": "user", - "content": prompt, - }, - ], - stream=stream_output, - **gen_server_kwargs, - ) - if not stream_output: - text = response["choices"][0]["message"]["content"] - yield dict( - response=prompter.get_response( - prompt + text, - prompt=prompt, - sanitize_bot_response=sanitize_bot_response, - ), - sources="", - ) - else: - text = "" - for chunk in response: - delta = chunk["choices"][0]["delta"] - if "content" in delta: - text += delta["content"] - yield dict( - response=prompter.get_response( - prompt + text, - prompt=prompt, - sanitize_bot_response=sanitize_bot_response, - ), - sources="", - ) - else: - raise RuntimeError( - "No such OpenAI mode: %s" % inference_server - ) - elif inference_server.startswith("http"): - inference_server, headers = get_hf_server(inference_server) - from gradio_utils.grclient import GradioClient - from text_generation import Client as HFClient - - if isinstance(model, GradioClient): - gr_client = model - hf_client = None - elif isinstance(model, HFClient): - gr_client = None - hf_client = model - else: - ( - inference_server, - gr_client, - hf_client, - ) = self.get_client_from_inference_server( - inference_server, base_model=base_model - ) - - # quick sanity check to avoid long timeouts, just see if can reach server - requests.get( - inference_server, - timeout=int(os.getenv("REQUEST_TIMEOUT_FAST", "10")), - ) - - if gr_client is not None: - # Note: h2oGPT gradio server could handle input token size issues for prompt, - # but best to handle here so send less data to server - - chat_client = False - where_from = "gr_client" - client_langchain_mode = "Disabled" - client_langchain_action = LangChainAction.QUERY.value - gen_server_kwargs = dict( - temperature=temperature, - top_p=top_p, - top_k=top_k, - num_beams=num_beams, - max_new_tokens=max_new_tokens, - min_new_tokens=min_new_tokens, - early_stopping=early_stopping, - max_time=max_time, - repetition_penalty=repetition_penalty, - num_return_sequences=num_return_sequences, - do_sample=do_sample, - chat=chat_client, - ) - # account for gradio into gradio that handles prompting, avoid duplicating prompter prompt injection - if prompt_type in [ - None, - "", - 
PromptType.plain.name, - PromptType.plain.value, - str(PromptType.plain.value), - ]: - # if our prompt is plain, assume either correct or gradio server knows different prompt type, - # so pass empty prompt_Type - gr_prompt_type = "" - gr_prompt_dict = "" - gr_prompt = prompt # already prepared prompt - gr_context = "" - gr_iinput = "" - else: - # if already have prompt_type that is not plain, None, or '', then already applied some prompting - # But assume server can handle prompting, and need to avoid double-up. - # Also assume server can do better job of using stopping.py to stop early, so avoid local prompting, let server handle - # So avoid "prompt" and let gradio server reconstruct from prompt_type we passed - # Note it's ok that prompter.get_response() has prompt+text, prompt=prompt passed, - # because just means extra processing and removal of prompt, but that has no human-bot prompting doesn't matter - # since those won't appear - gr_context = context - gr_prompt = instruction - gr_iinput = iinput - gr_prompt_type = prompt_type - gr_prompt_dict = prompt_dict - client_kwargs = dict( - instruction=gr_prompt - if chat_client - else "", # only for chat=True - iinput=gr_iinput, # only for chat=True - context=gr_context, - # streaming output is supported, loops over and outputs each generation in streaming mode - # but leave stream_output=False for simple input/output mode - stream_output=stream_output, - **gen_server_kwargs, - prompt_type=gr_prompt_type, - prompt_dict=gr_prompt_dict, - instruction_nochat=gr_prompt - if not chat_client - else "", - iinput_nochat=gr_iinput, # only for chat=False - langchain_mode=client_langchain_mode, - langchain_action=client_langchain_action, - top_k_docs=top_k_docs, - chunk=chunk, - chunk_size=chunk_size, - document_choice=[DocumentChoices.All_Relevant.name], - ) - api_name = "/submit_nochat_api" # NOTE: like submit_nochat but stable API for string dict passing - if not stream_output: - res = gr_client.predict( - str(dict(client_kwargs)), api_name=api_name - ) - res_dict = ast.literal_eval(res) - text = res_dict["response"] - sources = res_dict["sources"] - yield dict( - response=prompter.get_response( - prompt + text, - prompt=prompt, - sanitize_bot_response=sanitize_bot_response, - ), - sources=sources, - ) - else: - job = gr_client.submit( - str(dict(client_kwargs)), api_name=api_name - ) - text = "" - sources = "" - res_dict = dict(response=text, sources=sources) - while not job.done(): - outputs_list = job.communicator.job.outputs - if outputs_list: - res = job.communicator.job.outputs[-1] - res_dict = ast.literal_eval(res) - text = res_dict["response"] - sources = res_dict["sources"] - if gr_prompt_type == "plain": - # then gradio server passes back full prompt + text - prompt_and_text = text - else: - prompt_and_text = prompt + text - yield dict( - response=prompter.get_response( - prompt_and_text, - prompt=prompt, - sanitize_bot_response=sanitize_bot_response, - ), - sources=sources, - ) - time.sleep(0.01) - # ensure get last output to avoid race - res_all = job.outputs() - if len(res_all) > 0: - res = res_all[-1] - res_dict = ast.literal_eval(res) - text = res_dict["response"] - sources = res_dict["sources"] - else: - # go with old text if last call didn't work - e = job.future._exception - if e is not None: - stre = str(e) - strex = "".join( - traceback.format_tb(e.__traceback__) - ) - else: - stre = "" - strex = "" - - print( - "Bad final response: %s %s %s %s %s: %s %s" - % ( - base_model, - inference_server, - res_all, - prompt, - text, - 
stre, - strex, - ), - flush=True, - ) - if gr_prompt_type == "plain": - # then gradio server passes back full prompt + text - prompt_and_text = text - else: - prompt_and_text = prompt + text - yield dict( - response=prompter.get_response( - prompt_and_text, - prompt=prompt, - sanitize_bot_response=sanitize_bot_response, - ), - sources=sources, - ) - elif hf_client: - # HF inference server needs control over input tokens - where_from = "hf_client" - - # prompt must include all human-bot like tokens, already added by prompt - # https://github.com/huggingface/text-generation-inference/tree/main/clients/python#types - stop_sequences = list( - set( - prompter.terminate_response - + [prompter.PreResponse] - ) - ) - stop_sequences = [x for x in stop_sequences if x] - gen_server_kwargs = dict( - do_sample=do_sample, - max_new_tokens=max_new_tokens, - # best_of=None, - repetition_penalty=repetition_penalty, - return_full_text=True, - seed=SEED, - stop_sequences=stop_sequences, - temperature=temperature, - top_k=top_k, - top_p=top_p, - # truncate=False, # behaves oddly - # typical_p=top_p, - # watermark=False, - # decoder_input_details=False, - ) - # work-around for timeout at constructor time, will be issue if multi-threading, - # so just do something reasonable or max_time if larger - # lower bound because client is re-used if multi-threading - hf_client.timeout = max(300, max_time) - if not stream_output: - text = hf_client.generate( - prompt, **gen_server_kwargs - ).generated_text - yield dict( - response=prompter.get_response( - text, - prompt=prompt, - sanitize_bot_response=sanitize_bot_response, - ), - sources="", - ) - else: - text = "" - for response in hf_client.generate_stream( - prompt, **gen_server_kwargs - ): - if not response.token.special: - # stop_sequences - text_chunk = response.token.text - text += text_chunk - yield dict( - response=prompter.get_response( - prompt + text, - prompt=prompt, - sanitize_bot_response=sanitize_bot_response, - ), - sources="", - ) - else: - raise RuntimeError( - "Failed to get client: %s" % inference_server - ) - else: - raise RuntimeError( - "No such inference_server %s" % inference_server - ) - - if save_dir and text: - # save prompt + new text - extra_dict = gen_server_kwargs.copy() - extra_dict.update( - dict( - inference_server=inference_server, - num_prompt_tokens=num_prompt_tokens, - ) - ) - save_generate_output( - prompt=prompt, - output=text, - base_model=base_model, - save_dir=save_dir, - where_from=where_from, - extra_dict=extra_dict, - ) - return - else: - assert not inference_server, ( - "inferene_server=%s not supported" % inference_server - ) - - if isinstance(tokenizer, str): - # pipeline - if tokenizer == "summarization": - key = "summary_text" - else: - raise RuntimeError("No such task type %s" % tokenizer) - # NOTE: uses max_length only - yield dict( - response=model(prompt, max_length=max_new_tokens)[0][key], - sources="", - ) - - if "mbart-" in base_model.lower(): - assert src_lang is not None - tokenizer.src_lang = self.languages_covered()[src_lang] - - stopping_criteria = get_stopping( - prompt_type, - prompt_dict, - tokenizer, - self.device, - model_max_length=tokenizer.model_max_length, - ) - - print(prompt) - # exit(0) - inputs = tokenizer(prompt, return_tensors="pt") - if debug and len(inputs["input_ids"]) > 0: - print("input_ids length", len(inputs["input_ids"][0]), flush=True) - input_ids = inputs["input_ids"].to(self.device) - # CRITICAL LIMIT else will fail - max_max_tokens = tokenizer.model_max_length - max_input_tokens 
= max_max_tokens - min_new_tokens - # NOTE: Don't limit up front due to max_new_tokens, let go up to max or reach max_max_tokens in stopping.py - input_ids = input_ids[:, -max_input_tokens:] - # required for falcon if multiple threads or asyncio accesses to model during generation - if use_cache is None: - use_cache = False if "falcon" in base_model else True - gen_config_kwargs = dict( - temperature=float(temperature), - top_p=float(top_p), - top_k=top_k, - num_beams=num_beams, - do_sample=do_sample, - repetition_penalty=float(repetition_penalty), - num_return_sequences=num_return_sequences, - renormalize_logits=True, - remove_invalid_values=True, - use_cache=use_cache, - ) - token_ids = [ - "eos_token_id", - "pad_token_id", - "bos_token_id", - "cls_token_id", - "sep_token_id", - ] - for token_id in token_ids: - if ( - hasattr(tokenizer, token_id) - and getattr(tokenizer, token_id) is not None - ): - gen_config_kwargs.update( - {token_id: getattr(tokenizer, token_id)} - ) - generation_config = GenerationConfig(**gen_config_kwargs) - - gen_kwargs = dict( - input_ids=input_ids, - generation_config=generation_config, - return_dict_in_generate=True, - output_scores=True, - max_new_tokens=max_new_tokens, # prompt + new - min_new_tokens=min_new_tokens, # prompt + new - early_stopping=early_stopping, # False, True, "never" - max_time=max_time, - stopping_criteria=stopping_criteria, - ) - if "gpt2" in base_model.lower(): - gen_kwargs.update( - dict( - bos_token_id=tokenizer.bos_token_id, - pad_token_id=tokenizer.eos_token_id, - ) - ) - elif "mbart-" in base_model.lower(): - assert tgt_lang is not None - tgt_lang = self.languages_covered()[tgt_lang] - gen_kwargs.update( - dict(forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang]) - ) - else: - token_ids = ["eos_token_id", "bos_token_id", "pad_token_id"] - for token_id in token_ids: - if ( - hasattr(tokenizer, token_id) - and getattr(tokenizer, token_id) is not None - ): - gen_kwargs.update({token_id: getattr(tokenizer, token_id)}) - - decoder_kwargs = dict( - skip_special_tokens=True, clean_up_tokenization_spaces=True - ) - - decoder = functools.partial(tokenizer.decode, **decoder_kwargs) - decoder_raw_kwargs = dict( - skip_special_tokens=False, clean_up_tokenization_spaces=True - ) - - decoder_raw = functools.partial(tokenizer.decode, **decoder_raw_kwargs) - - with torch.no_grad(): - have_lora_weights = lora_weights not in [no_lora_str, "", None] - context_class_cast = ( - NullContext - if self.device == "cpu" - or have_lora_weights - or self.device == "mps" - else torch.autocast - ) - with context_class_cast(self.device): - # protection for gradio not keeping track of closed users, - # else hit bitsandbytes lack of thread safety: - # https://github.com/h2oai/h2ogpt/issues/104 - # but only makes sense if concurrency_count == 1 - context_class = NullContext # if concurrency_count > 1 else filelock.FileLock - if verbose: - print("Pre-Generate: %s" % str(datetime.now()), flush=True) - decoded_output = None - with context_class("generate.lock"): - if verbose: - print("Generate: %s" % str(datetime.now()), flush=True) - # decoded tokenized prompt can deviate from prompt due to special characters - inputs_decoded = decoder(input_ids[0]) - inputs_decoded_raw = decoder_raw(input_ids[0]) - if inputs_decoded == prompt: - # normal - pass - elif inputs_decoded.lstrip() == prompt.lstrip(): - # sometimes extra space in front, make prompt same for prompt removal - prompt = inputs_decoded - elif inputs_decoded_raw == prompt: - # some models specify special 
tokens that are part of normal prompt, so can't skip them - inputs_decoded = prompt = inputs_decoded_raw - decoder = decoder_raw - decoder_kwargs = decoder_raw_kwargs - elif inputs_decoded_raw.replace(" ", "").replace( - "", "" - ).replace("\n", " ").replace(" ", "") == prompt.replace( - "\n", " " - ).replace( - " ", "" - ): - inputs_decoded = prompt = inputs_decoded_raw - decoder = decoder_raw - decoder_kwargs = decoder_raw_kwargs - else: - if verbose: - print( - "WARNING: Special characters in prompt", - flush=True, - ) - if stream_output: - skip_prompt = False - streamer = H2OTextIteratorStreamer( - tokenizer, - skip_prompt=skip_prompt, - block=False, - **decoder_kwargs, - ) - gen_kwargs.update(dict(streamer=streamer)) - target = wrapped_partial( - self.generate_with_exceptions, - model.generate, - prompt=prompt, - inputs_decoded=inputs_decoded, - raise_generate_gpu_exceptions=raise_generate_gpu_exceptions, - **gen_kwargs, - ) - bucket = queue.Queue() - thread = EThread( - target=target, streamer=streamer, bucket=bucket - ) - thread.start() - outputs = "" - try: - for new_text in streamer: - if bucket.qsize() > 0 or thread.exc: - thread.join() - outputs += new_text - yield dict( - response=prompter.get_response( - outputs, - prompt=inputs_decoded, - sanitize_bot_response=sanitize_bot_response, - ), - sources="", - ) - except BaseException: - # if any exception, raise that exception if was from thread, first - if thread.exc: - raise thread.exc - raise - finally: - # clear before return, since .then() never done if from API - clear_torch_cache() - # in case no exception and didn't join with thread yet, then join - if not thread.exc: - thread.join() - # in case raise StopIteration or broke queue loop in streamer, but still have exception - if thread.exc: - raise thread.exc - decoded_output = outputs - else: - try: - outputs = model.generate(**gen_kwargs) - finally: - clear_torch_cache() # has to be here for API submit_nochat_api since.then() not called - outputs = [decoder(s) for s in outputs.sequences] - yield dict( - response=prompter.get_response( - outputs, - prompt=inputs_decoded, - sanitize_bot_response=sanitize_bot_response, - ), - sources="", - ) - if outputs and len(outputs) >= 1: - decoded_output = prompt + outputs[0] - if save_dir and decoded_output: - extra_dict = gen_config_kwargs.copy() - extra_dict.update( - dict(num_prompt_tokens=num_prompt_tokens) - ) - save_generate_output( - prompt=prompt, - output=decoded_output, - base_model=base_model, - save_dir=save_dir, - where_from="evaluate_%s" % str(stream_output), - extra_dict=gen_config_kwargs, - ) - if verbose: - print( - "Post-Generate: %s decoded_output: %s" - % ( - str(datetime.now()), - len(decoded_output) if decoded_output else -1, - ), - flush=True, - ) - return outputs[0] + return response inputs_list_names = list(inspect.signature(evaluate).parameters) global inputs_kwargs_list diff --git a/apps/language_models/langchain/gpt_langchain.py b/apps/language_models/langchain/gpt_langchain.py index aea7507a3b..a21529cadc 100644 --- a/apps/language_models/langchain/gpt_langchain.py +++ b/apps/language_models/langchain/gpt_langchain.py @@ -2510,8 +2510,7 @@ def _run_qa_db( formatted_doc_chunks = "\n\n".join( [get_url(x) + "\n\n" + x.page_content for x in docs] ) - yield formatted_doc_chunks, "" - return + return formatted_doc_chunks, "" if not docs and langchain_action in [ LangChainAction.SUMMARIZE_MAP.value, LangChainAction.SUMMARIZE_ALL.value, @@ -2523,8 +2522,7 @@ def _run_qa_db( else "No documents to summarize." 
) extra = "" - yield ret, extra - return + return ret, extra if not docs and langchain_mode not in [ LangChainMode.DISABLED.value, LangChainMode.CHAT_LLM.value, @@ -2536,8 +2534,7 @@ def _run_qa_db( else "No documents to query." ) extra = "" - yield ret, extra - return + return ret, extra if chain is None and model_name not in non_hf_types: # here if no docs at all and not HF type @@ -2561,7 +2558,7 @@ def _run_qa_db( if not use_context: ret = answer["output_text"] extra = "" - yield ret, extra + return ret, extra elif answer is not None: ret, extra = get_sources_answer( query, @@ -2571,7 +2568,7 @@ def _run_qa_db( answer_with_sources, verbose=verbose, ) - yield ret, extra + return ret, extra return diff --git a/apps/stable_diffusion/web/ui/h2ogpt.py b/apps/stable_diffusion/web/ui/h2ogpt.py index b2381dbfda..be61e01048 100644 --- a/apps/stable_diffusion/web/ui/h2ogpt.py +++ b/apps/stable_diffusion/web/ui/h2ogpt.py @@ -164,10 +164,7 @@ def chat(curr_system_message, history, device, precision): model_lock=True, user_path=userpath_selector.value, ) - for partial_text in output: - history[-1][1] = partial_text["response"] - yield history - + history[-1][1] = output["response"] return history
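The three diffs above apply the same refactor in lockstep: `_run_qa_db` stops yielding partial `(response, sources)` pairs and returns the final pair once, `evaluate` unpacks that pair instead of looping over a generator, and the gradio `chat` handler assigns the finished response to `history[-1][1]` in a single step. The sketch below illustrates the streaming-to-single-return pattern in isolation; the names `answer_query_streaming`, `answer_query`, and `chat_turn` are hypothetical stand-ins, not the project's actual APIs.

# Minimal, self-contained sketch of the refactor pattern shown in this diff
# (assumed/illustrative names only; this is not the h2oGPT/SHARK code itself).
from typing import Iterator, List, Optional, Tuple


def answer_query_streaming(query: str) -> Iterator[Tuple[str, str]]:
    # Old shape: a generator that yields the growing partial answer plus
    # sources on every step, so the caller re-renders repeatedly.
    partial = ""
    for token in ("Hello", ", ", "world"):  # stand-in for a model token stream
        partial += token
        yield partial, ""


def answer_query(query: str) -> Tuple[str, str]:
    # New shape: run to completion and return only the final answer/sources pair.
    answer = "".join(("Hello", ", ", "world"))
    sources = ""
    return answer, sources


def chat_turn(
    history: List[List[Optional[str]]], query: str
) -> List[List[Optional[str]]]:
    # Caller-side change, mirroring the edit to chat() in h2ogpt.py:
    #   before: for partial, _ in answer_query_streaming(query):
    #               history[-1][1] = partial
    #               yield history
    #   after:  one call, one assignment, plain return.
    answer, _sources = answer_query(query)
    history[-1][1] = answer
    return history


if __name__ == "__main__":
    print(chat_turn([["hi", None]], "hi"))

The trade-off, as in the diff itself, is simpler control flow and a stable non-streaming return value at the cost of incremental output in the UI.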