From 4580755def6793e9146621d67dea77809cb96c98 Mon Sep 17 00:00:00 2001
From: Hemu
Date: Mon, 24 Jun 2024 17:13:31 -0700
Subject: [PATCH] General linting and setup update

Sort and group imports, remove unused imports, normalize string
literals to double quotes, replace a type() comparison with
isinstance(), and pin pydantic==1.10.11 in the install requirements.
---
 chatify/cache.py      | 13 ++-----
 chatify/chains.py     | 52 +++++++++++--------------
 chatify/llm_models.py | 83 ++++++++++++++++++++--------------------
 chatify/main.py       | 13 +++----
 chatify/utils.py      |  8 ++--
 chatify/widgets.py    |  4 +-
 setup.py              | 88 +++++++++++++++++++++----------------------
 7 files changed, 121 insertions(+), 140 deletions(-)

diff --git a/chatify/cache.py b/chatify/cache.py
index d0581b0..b7a8db6 100644
--- a/chatify/cache.py
+++ b/chatify/cache.py
@@ -1,15 +1,8 @@
-from gptcache.adapter.langchain_models import LangChainLLMs
 from gptcache import Cache
-from gptcache.processor.pre import get_prompt
-
-
-from gptcache.manager import get_data_manager, CacheBase, VectorBase
-
-
+from gptcache.adapter.langchain_models import LangChainLLMs
 from gptcache.embedding import Onnx
-from gptcache.embedding.string import to_embeddings as string_embedding
-
-
+from gptcache.manager import CacheBase, VectorBase, get_data_manager
+from gptcache.processor.pre import get_prompt
 from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
 from gptcache.similarity_evaluation.exact_match import ExactMatchEvaluation
 
diff --git a/chatify/chains.py b/chatify/chains.py
index d48a941..53b0063 100644
--- a/chatify/chains.py
+++ b/chatify/chains.py
@@ -1,19 +1,13 @@
 from typing import Any, Dict, List, Optional
 
 import requests
-
-
-from langchain.prompts import PromptTemplate
+from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.chains import LLMChain, LLMMathChain
 from langchain.chains.base import Chain
+from langchain.prompts import PromptTemplate
 
-
-from typing import Any, Dict, Optional
-from langchain.callbacks.manager import CallbackManagerForChainRun
-
-from .llm_models import ModelsFactory
 from .cache import LLMCacher
-
+from .llm_models import ModelsFactory
 from .utils import compress_code
 
 
@@ -21,10 +15,10 @@ class RequestChain(Chain):
     llm_chain: LLMChain = None
     prompt: Optional[Dict[str, Any]]
     headers: Optional[Dict[str, str]] = {
-        'accept': 'application/json',
-        'Content-Type': 'application/json',
+        "accept": "application/json",
+        "Content-Type": "application/json",
     }
-    input_key: str = 'text'
+    input_key: str = "text"
     url: str = "url"  #: :meta private:
     output_key: str = "text"  #: :meta private:
 
@@ -50,14 +44,14 @@ def _call(
         self,
         inputs: Dict[str, Any],
         run_manager: Optional[CallbackManagerForChainRun] = None,
    ) -> Dict[str, Any]:
         # Prepare data
-        if self.url != '/':
-            self.url += '/'
-        combined_url = self.url + self.prompt['prompt_id'] + '/response'
-        data = {'user_text': inputs[self.input_key]}
+        if self.url != "/":
+            self.url += "/"
+        combined_url = self.url + self.prompt["prompt_id"] + "/response"
+        data = {"user_text": inputs[self.input_key]}
 
         # Send the request
         response = requests.post(url=combined_url, headers=self.headers, json=data)
-        output = eval(response.content.decode('utf-8'))
+        output = eval(response.content.decode("utf-8"))
 
         return {self.output_key: output}
 
@@ -77,16 +71,16 @@ def __init__(self, config):
             None
         """
         self.config = config
-        self.chain_config = config['chain_config']
+        self.chain_config = config["chain_config"]
         self.llm_model = None
         self.llm_models_factory = ModelsFactory()
-        self.cache = config['cache_config']['cache']
+        self.cache = config["cache_config"]["cache"]
         self.cacher = LLMCacher(config)
 
         # Setup model and chain factory
-        self._setup_llm_model(config['model_config'])
+        self._setup_llm_model(config["model_config"])
         self._setup_chain_factory()
 
         return None
 
@@ -112,9 +106,9 @@ def _setup_chain_factory(self):
             None
         """
         self.chain_factory = {
-            'math': LLMMathChain,
-            'default': LLMChain,
-            'proxy': RequestChain,
+            "math": LLMMathChain,
+            "default": LLMChain,
+            "proxy": RequestChain,
         }
 
@@ -129,7 +123,7 @@ def create_prompt(self, prompt):
             PROMPT (PromptTemplate): Prompt template object.
         """
         PROMPT = PromptTemplate(
-            template=prompt['content'], input_variables=prompt['input_variables']
+            template=prompt["content"], input_variables=prompt["input_variables"]
         )
 
         return PROMPT
 
@@ -145,15 +139,15 @@ def create_chain(self, model_config=None, prompt_template=None):
         -------
         chain (LLMChain): LLM chain object.
         """
-        if self.config['chain_config']['chain_type'] == 'proxy':
+        if self.config["chain_config"]["chain_type"] == "proxy":
             chain = RequestChain(
-                url=self.config['model_config']['proxy_url'], prompt=prompt_template
+                url=self.config["model_config"]["proxy_url"], prompt=prompt_template
             )
         else:
             try:
-                chain_type = self.chain_config['chain_type']
+                chain_type = self.chain_config["chain_type"]
             except KeyError:
-                chain_type = 'default'
+                chain_type = "default"
 
             chain = self.chain_factory[chain_type](
                 llm=self.llm_model, prompt=self.create_prompt(prompt_template)
@@ -179,6 +173,6 @@ def execute(self, chain, inputs, *args, **kwargs):
             output = chain.llm(inputs, cache_obj=self.cacher.llm_cache)
             self.cacher.llm_cache.flush()
         else:
-            output = chain(inputs)['text']
+            output = chain(inputs)["text"]
 
         return output
diff --git a/chatify/llm_models.py b/chatify/llm_models.py
index 4b0f084..9d54f99 100644
--- a/chatify/llm_models.py
+++ b/chatify/llm_models.py
@@ -3,11 +3,10 @@
 with warnings.catch_warnings():
     # catch warnings about accelerate library
     warnings.simplefilter("ignore")
-    from langchain.llms import OpenAI, HuggingFacePipeline, LlamaCpp
-    from langchain.llms.base import LLM
-    from langchain.chat_models import ChatOpenAI
     from langchain.callbacks.manager import CallbackManager
     from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
+    from langchain.chat_models import ChatOpenAI
+    from langchain.llms import HuggingFacePipeline, LlamaCpp, OpenAI
 
 try:
     from huggingface_hub import hf_hub_download
@@ -54,23 +53,23 @@ def get_model(self, model_config):
         RuntimeError
             If the specified model is not supported.
         """
-        model_ = model_config['model']
+        model_ = model_config["model"]
 
         # Collect all the models
         models = {
-            'open_ai_model': OpenAIModel,
-            'open_ai_chat_model': OpenAIChatModel,
-            'fake_model': FakeLLMModel,
-            'cached_model': CachedLLMModel,
-            'huggingface_model': HuggingFaceModel,
-            'llama_model': LlamaModel,
-            'proxy': ProxyModel,
+            "open_ai_model": OpenAIModel,
+            "open_ai_chat_model": OpenAIChatModel,
+            "fake_model": FakeLLMModel,
+            "cached_model": CachedLLMModel,
+            "huggingface_model": HuggingFaceModel,
+            "llama_model": LlamaModel,
+            "proxy": ProxyModel,
         }
 
         if model_ in models.keys():
             with warnings.catch_warnings():
                 warnings.simplefilter("ignore")
-                if type(models[model_]) == str:
+                if isinstance(models[model_], str):
                     return models[model_]
                 else:
                     return models[model_](model_config).init_model()
@@ -139,17 +138,17 @@ def init_model(self):
         llm_model : ChatOpenAI
             Initialized OpenAI Chat Model.
""" - if self.model_config['open_ai_key'] is None: - raise ValueError(f'openai_api_key value cannot be None') + if self.model_config["open_ai_key"] is None: + raise ValueError("openai_api_key value cannot be None") - os.environ["OPENAI_API_KEY"] = self.model_config['open_ai_key'] + os.environ["OPENAI_API_KEY"] = self.model_config["open_ai_key"] llm_model = OpenAI( temperature=0.85, - openai_api_key=self.model_config['open_ai_key'], - model_name=self.model_config['model_name'], + openai_api_key=self.model_config["open_ai_key"], + model_name=self.model_config["model_name"], presence_penalty=0.1, - max_tokens=self.model_config['max_tokens'], + max_tokens=self.model_config["max_tokens"], ) return llm_model @@ -179,15 +178,15 @@ def init_model(self): llm_model : ChatOpenAI Initialized OpenAI Chat Model. """ - if self.model_config['open_ai_key'] is None: - raise ValueError(f'openai_api_key value cannot be None') + if self.model_config["open_ai_key"] is None: + raise ValueError("openai_api_key value cannot be None") llm_model = ChatOpenAI( temperature=0.85, - openai_api_key=self.model_config['open_ai_key'], - model_name=self.model_config['model_name'], + openai_api_key=self.model_config["open_ai_key"], + model_name=self.model_config["model_name"], presence_penalty=0.1, - max_tokens=self.model_config['max_tokens'], + max_tokens=self.model_config["max_tokens"], ) return llm_model @@ -216,7 +215,7 @@ def init_model(self): Initialized Fake Chat Model. """ responses = [ - 'The explanation you requested has not been included in Chatify\'s cache. You\'ll need to enable interactive mode to generate a response. Please see the [Chatify GitHub repository](https://github.com/ContextLab/chatify) for instructions. Note that generating responses to uncached content will require an [OpenAI API Key](https://platform.openai.com/account/api-keys).' + "The explanation you requested has not been included in Chatify's cache. You'll need to enable interactive mode to generate a response. Please see the [Chatify GitHub repository](https://github.com/ContextLab/chatify) for instructions. Note that generating responses to uncached content will require an [OpenAI API Key](https://platform.openai.com/account/api-keys)." ] llm_model = FakeListLLM(responses=responses) return llm_model @@ -247,7 +246,7 @@ def init_model(self): """ llm_model = FakeListLLM( responses=[ - f'The explanation you requested has not been included in Chatify\'s cache. You\'ll need to enable interactive mode to generate a response. Please see the [Chatify GitHub repository](https://github.com/ContextLab/chatify) for instructions. Note that generating responses to uncached content will require an [OpenAI API Key](https://platform.openai.com/account/api-keys).' + "The explanation you requested has not been included in Chatify's cache. You'll need to enable interactive mode to generate a response. Please see the [Chatify GitHub repository](https://github.com/ContextLab/chatify) for instructions. Note that generating responses to uncached content will require an [OpenAI API Key](https://platform.openai.com/account/api-keys)." ] ) return llm_model @@ -276,27 +275,27 @@ def init_model(self): llm_model : HuggingFaceModel Initialized Hugging Face Chat Model. 
""" - self.proxy = self.model_config['proxy'] - self.proxy_port = self.model_config['proxy_port'] + self.proxy = self.model_config["proxy"] + self.proxy_port = self.model_config["proxy_port"] with warnings.catch_warnings(): warnings.simplefilter("ignore") try: llm = HuggingFacePipeline.from_model_id( - model_id=self.model_config['model_name'], - task='text-generation', + model_id=self.model_config["model_name"], + task="text-generation", device=0, - model_kwargs={'max_length': self.model_config['max_tokens']}, + model_kwargs={"max_length": self.model_config["max_tokens"]}, ) except: llm = HuggingFacePipeline.from_model_id( - model_id=self.model_config['model_name'], - task='text-generation', + model_id=self.model_config["model_name"], + task="text-generation", model_kwargs={ - 'max_length': self.model_config['max_tokens'], - 'temperature': 0.85, - 'presence_penalty': 0.1, + "max_length": self.model_config["max_tokens"], + "temperature": 0.85, + "presence_penalty": 0.1, }, ) return llm @@ -326,8 +325,8 @@ def init_model(self): Initialized Hugging Face Chat Model. """ self.model_path = hf_hub_download( - repo_id=self.model_config['model_name'], - filename=self.model_config['weights_fname'], + repo_id=self.model_config["model_name"], + filename=self.model_config["weights_fname"], ) with warnings.catch_warnings(): @@ -337,17 +336,17 @@ def init_model(self): try: llm = LlamaCpp( model_path=self.model_path, - max_tokens=self.model_config['max_tokens'], - n_gpu_layers=self.model_config['n_gpu_layers'], - n_batch=self.model_config['n_batch'], + max_tokens=self.model_config["max_tokens"], + n_gpu_layers=self.model_config["n_gpu_layers"], + n_batch=self.model_config["n_batch"], callback_manager=callback_manager, verbose=True, ) except: llm = LlamaCpp( model_path=self.model_path, - max_tokens=self.model_config['max_tokens'], - n_batch=self.model_config['n_batch'], + max_tokens=self.model_config["max_tokens"], + n_batch=self.model_config["n_batch"], callback_manager=callback_manager, verbose=True, ) diff --git a/chatify/main.py b/chatify/main.py index 3c2f9d1..ccd883c 100644 --- a/chatify/main.py +++ b/chatify/main.py @@ -1,17 +1,14 @@ -import yaml - import pathlib -import requests - -from IPython.display import display -from IPython.core.magic import Magics, magics_class, cell_magic import ipywidgets as widgets +import yaml +from IPython.core.magic import Magics, cell_magic, magics_class +from IPython.display import display from .chains import CreateLLMChain -from .widgets import option_widget, button_widget, text_widget, thumbs, loading_widget - from .utils import check_dev_config, get_html +from .widgets import (button_widget, loading_widget, option_widget, + text_widget, thumbs) @magics_class diff --git a/chatify/utils.py b/chatify/utils.py index 0ee824c..3454002 100644 --- a/chatify/utils.py +++ b/chatify/utils.py @@ -1,14 +1,12 @@ -from typing import Any, List, Mapping, Optional - import random import urllib +from typing import Any, List, Mapping, Optional +from langchain.llms.base import LLM from markdown_it import MarkdownIt +from pygments import highlight from pygments.formatters import HtmlFormatter from pygments.lexers import get_lexer_by_name -from pygments import highlight - -from langchain.llms.base import LLM def highlight_code(code, name, attrs): diff --git a/chatify/widgets.py b/chatify/widgets.py index 6c793cc..ba8322c 100644 --- a/chatify/widgets.py +++ b/chatify/widgets.py @@ -1,7 +1,7 @@ -import ipywidgets as widgets - import pathlib +import ipywidgets as widgets + def 
 
 def option_widget(config):
     """Create an options dropdown widget based on the given configuration.
diff --git a/setup.py b/setup.py
index f85eb38..84adac2 100644
--- a/setup.py
+++ b/setup.py
@@ -2,73 +2,73 @@
 
 """The setup script."""
 
-from setuptools import setup, find_packages
-from glob import glob
+from setuptools import find_packages, setup
 
-with open('README.md') as readme_file:
+with open("README.md") as readme_file:
     readme = readme_file.read()
 
-with open('HISTORY.rst') as history_file:
+with open("HISTORY.rst") as history_file:
     history = history_file.read()
 
 requirements = [
-    'gptcache<=0.1.35',
-    'langchain<=0.0.226',
-    'openai',
-    'markdown',
-    'ipywidgets',
-    'requests',
-    'markdown-it-py[linkify,plugins]',
-    'pygments',
+    "gptcache<=0.1.35",
+    "langchain<=0.0.226",
+    "openai",
+    "markdown",
+    "ipywidgets",
+    "requests",
+    "markdown-it-py[linkify,plugins]",
+    "pygments",
+    "pydantic==1.10.11",
 ]
 
 extras = [
-    'transformers',
-    'torch>=2.0',
-    'tensorflow>=2.0',
-    'flax',
-    'einops',
-    'accelerate',
-    'xformers',
-    'bitsandbytes',
-    'sentencepiece',
-    'llama-cpp-python',
+    "transformers",
+    "torch>=2.0",
+    "tensorflow>=2.0",
+    "flax",
+    "einops",
+    "accelerate",
+    "xformers",
+    "bitsandbytes",
+    "sentencepiece",
+    "llama-cpp-python",
 ]
 
 test_requirements = [
-    'pytest>=3',
+    "pytest>=3",
 ]
 
 setup(
     author="Contextual Dynamics Lab",
-    author_email='contextualdynamics@gmail.com',
-    python_requires='>=3.6',
+    author_email="contextualdynamics@gmail.com",
+    python_requires=">=3.6",
     classifiers=[
-        'Development Status :: 2 - Pre-Alpha',
-        'Intended Audience :: Developers',
-        'License :: OSI Approved :: MIT License',
-        'Natural Language :: English',
-        'Programming Language :: Python :: 3',
-        'Programming Language :: Python :: 3.6',
-        'Programming Language :: Python :: 3.7',
-        'Programming Language :: Python :: 3.8',
+        "Development Status :: 2 - Pre-Alpha",
+        "Intended Audience :: Developers",
+        "License :: OSI Approved :: MIT License",
+        "Natural Language :: English",
+        "Programming Language :: Python :: 3",
+        "Programming Language :: Python :: 3.6",
+        "Programming Language :: Python :: 3.7",
+        "Programming Language :: Python :: 3.8",
     ],
     description="A python package that adds a magic command to Jupyter notebooks to enable LLM interactions with code cells.",
-    description_content_type='text/markdown',
+    description_content_type="text/markdown",
     install_requires=requirements,
     extras_require={
-        'hf': extras,
+        "hf": extras,
     },
     license="MIT license",
-    long_description=readme + '\n\n' + history,
-    long_description_content_type='text/markdown',
+    long_description=readme + "\n\n" + history,
+    long_description_content_type="text/markdown",
     include_package_data=True,
-    keywords='chatify',
-    name='chatify',
-    packages=find_packages(include=['chatify', 'chatify.*']),
-    test_suite='tests',
+    keywords="chatify",
+    name="chatify",
+    packages=find_packages(include=["chatify", "chatify.*"]),
+    test_suite="tests",
     tests_require=test_requirements,
-    package_data={'': ['**/*.yaml', '**/*.gif']},
-    url='https://github.com/ContextLab/chatify',
-    version='0.2.1',
+    package_data={"": ["**/*.yaml", "**/*.gif"]},
+    url="https://github.com/ContextLab/chatify",
+    version="0.2.1",
     zip_safe=False,
 )