Skip to content

Commit

Permalink
General linting and setup update (#55)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeremymanning authored Jun 25, 2024
2 parents 824b0b3 + 4580755 commit 4759f8c
Show file tree
Hide file tree
Showing 7 changed files with 121 additions and 140 deletions.
13 changes: 3 additions & 10 deletions chatify/cache.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,8 @@
from gptcache.adapter.langchain_models import LangChainLLMs
from gptcache import Cache
from gptcache.processor.pre import get_prompt


from gptcache.manager import get_data_manager, CacheBase, VectorBase


from gptcache.adapter.langchain_models import LangChainLLMs
from gptcache.embedding import Onnx
from gptcache.embedding.string import to_embeddings as string_embedding


from gptcache.manager import CacheBase, VectorBase, get_data_manager
from gptcache.processor.pre import get_prompt
from gptcache.similarity_evaluation.distance import SearchDistanceEvaluation
from gptcache.similarity_evaluation.exact_match import ExactMatchEvaluation

Expand Down
52 changes: 23 additions & 29 deletions chatify/chains.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,24 @@
from typing import Any, Dict, List, Optional

import requests


from langchain.prompts import PromptTemplate
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains import LLMChain, LLMMathChain
from langchain.chains.base import Chain
from langchain.prompts import PromptTemplate


from typing import Any, Dict, Optional
from langchain.callbacks.manager import CallbackManagerForChainRun

from .llm_models import ModelsFactory
from .cache import LLMCacher

from .llm_models import ModelsFactory
from .utils import compress_code


class RequestChain(Chain):
llm_chain: LLMChain = None
prompt: Optional[Dict[str, Any]]
headers: Optional[Dict[str, str]] = {
'accept': 'application/json',
'Content-Type': 'application/json',
"accept": "application/json",
"Content-Type": "application/json",
}
input_key: str = 'text'
input_key: str = "text"
url: str = "url" #: :meta private:
output_key: str = "text" #: :meta private:

Expand All @@ -50,14 +44,14 @@ def _call(
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
# Prepare data
if self.url != '/':
self.url += '/'
combined_url = self.url + self.prompt['prompt_id'] + '/response'
data = {'user_text': inputs[self.input_key]}
if self.url != "/":
self.url += "/"
combined_url = self.url + self.prompt["prompt_id"] + "/response"
data = {"user_text": inputs[self.input_key]}

# Send the request
response = requests.post(url=combined_url, headers=self.headers, json=data)
output = eval(response.content.decode('utf-8'))
output = eval(response.content.decode("utf-8"))

return {self.output_key: output}

Expand All @@ -77,16 +71,16 @@ def __init__(self, config):
None
"""
self.config = config
self.chain_config = config['chain_config']
self.chain_config = config["chain_config"]

self.llm_model = None
self.llm_models_factory = ModelsFactory()

self.cache = config['cache_config']['cache']
self.cache = config["cache_config"]["cache"]
self.cacher = LLMCacher(config)

# Setup model and chain factory
self._setup_llm_model(config['model_config'])
self._setup_llm_model(config["model_config"])
self._setup_chain_factory()

return None
Expand All @@ -112,9 +106,9 @@ def _setup_chain_factory(self):
None
"""
self.chain_factory = {
'math': LLMMathChain,
'default': LLMChain,
'proxy': RequestChain,
"math": LLMMathChain,
"default": LLMChain,
"proxy": RequestChain,
}

def create_prompt(self, prompt):
Expand All @@ -129,7 +123,7 @@ def create_prompt(self, prompt):
PROMPT (PromptTemplate): Prompt template object.
"""
PROMPT = PromptTemplate(
template=prompt['content'], input_variables=prompt['input_variables']
template=prompt["content"], input_variables=prompt["input_variables"]
)
return PROMPT

Expand All @@ -145,15 +139,15 @@ def create_chain(self, model_config=None, prompt_template=None):
-------
chain (LLMChain): LLM chain object.
"""
if self.config['chain_config']['chain_type'] == 'proxy':
if self.config["chain_config"]["chain_type"] == "proxy":
chain = RequestChain(
url=self.config['model_config']['proxy_url'], prompt=prompt_template
url=self.config["model_config"]["proxy_url"], prompt=prompt_template
)
else:
try:
chain_type = self.chain_config['chain_type']
chain_type = self.chain_config["chain_type"]
except KeyError:
chain_type = 'default'
chain_type = "default"

chain = self.chain_factory[chain_type](
llm=self.llm_model, prompt=self.create_prompt(prompt_template)
Expand All @@ -179,6 +173,6 @@ def execute(self, chain, inputs, *args, **kwargs):
output = chain.llm(inputs, cache_obj=self.cacher.llm_cache)
self.cacher.llm_cache.flush()
else:
output = chain(inputs)['text']
output = chain(inputs)["text"]

return output
83 changes: 41 additions & 42 deletions chatify/llm_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,10 @@

with warnings.catch_warnings(): # catch warnings about accelerate library
warnings.simplefilter("ignore")
from langchain.llms import OpenAI, HuggingFacePipeline, LlamaCpp
from langchain.llms.base import LLM
from langchain.chat_models import ChatOpenAI
from langchain.callbacks.manager import CallbackManager
from langchain.callbacks.streaming_stdout import StreamingStdOutCallbackHandler
from langchain.chat_models import ChatOpenAI
from langchain.llms import HuggingFacePipeline, LlamaCpp, OpenAI

try:
from huggingface_hub import hf_hub_download
Expand Down Expand Up @@ -54,23 +53,23 @@ def get_model(self, model_config):
RuntimeError
If the specified model is not supported.
"""
model_ = model_config['model']
model_ = model_config["model"]

# Collect all the models
models = {
'open_ai_model': OpenAIModel,
'open_ai_chat_model': OpenAIChatModel,
'fake_model': FakeLLMModel,
'cached_model': CachedLLMModel,
'huggingface_model': HuggingFaceModel,
'llama_model': LlamaModel,
'proxy': ProxyModel,
"open_ai_model": OpenAIModel,
"open_ai_chat_model": OpenAIChatModel,
"fake_model": FakeLLMModel,
"cached_model": CachedLLMModel,
"huggingface_model": HuggingFaceModel,
"llama_model": LlamaModel,
"proxy": ProxyModel,
}

if model_ in models.keys():
with warnings.catch_warnings():
warnings.simplefilter("ignore")
if type(models[model_]) == str:
if isinstance(models[model_], str):
return models[model_]
else:
return models[model_](model_config).init_model()
Expand Down Expand Up @@ -139,17 +138,17 @@ def init_model(self):
llm_model : ChatOpenAI
Initialized OpenAI Chat Model.
"""
if self.model_config['open_ai_key'] is None:
raise ValueError(f'openai_api_key value cannot be None')
if self.model_config["open_ai_key"] is None:
raise ValueError("openai_api_key value cannot be None")

os.environ["OPENAI_API_KEY"] = self.model_config['open_ai_key']
os.environ["OPENAI_API_KEY"] = self.model_config["open_ai_key"]

llm_model = OpenAI(
temperature=0.85,
openai_api_key=self.model_config['open_ai_key'],
model_name=self.model_config['model_name'],
openai_api_key=self.model_config["open_ai_key"],
model_name=self.model_config["model_name"],
presence_penalty=0.1,
max_tokens=self.model_config['max_tokens'],
max_tokens=self.model_config["max_tokens"],
)
return llm_model

Expand Down Expand Up @@ -179,15 +178,15 @@ def init_model(self):
llm_model : ChatOpenAI
Initialized OpenAI Chat Model.
"""
if self.model_config['open_ai_key'] is None:
raise ValueError(f'openai_api_key value cannot be None')
if self.model_config["open_ai_key"] is None:
raise ValueError("openai_api_key value cannot be None")

llm_model = ChatOpenAI(
temperature=0.85,
openai_api_key=self.model_config['open_ai_key'],
model_name=self.model_config['model_name'],
openai_api_key=self.model_config["open_ai_key"],
model_name=self.model_config["model_name"],
presence_penalty=0.1,
max_tokens=self.model_config['max_tokens'],
max_tokens=self.model_config["max_tokens"],
)
return llm_model

Expand Down Expand Up @@ -216,7 +215,7 @@ def init_model(self):
Initialized Fake Chat Model.
"""
responses = [
'The explanation you requested has not been included in Chatify\'s cache. You\'ll need to enable interactive mode to generate a response. Please see the [Chatify GitHub repository](https://github.com/ContextLab/chatify) for instructions. Note that generating responses to uncached content will require an [OpenAI API Key](https://platform.openai.com/account/api-keys).'
"The explanation you requested has not been included in Chatify's cache. You'll need to enable interactive mode to generate a response. Please see the [Chatify GitHub repository](https://github.com/ContextLab/chatify) for instructions. Note that generating responses to uncached content will require an [OpenAI API Key](https://platform.openai.com/account/api-keys)."
]
llm_model = FakeListLLM(responses=responses)
return llm_model
Expand Down Expand Up @@ -247,7 +246,7 @@ def init_model(self):
"""
llm_model = FakeListLLM(
responses=[
f'The explanation you requested has not been included in Chatify\'s cache. You\'ll need to enable interactive mode to generate a response. Please see the [Chatify GitHub repository](https://github.com/ContextLab/chatify) for instructions. Note that generating responses to uncached content will require an [OpenAI API Key](https://platform.openai.com/account/api-keys).'
"The explanation you requested has not been included in Chatify's cache. You'll need to enable interactive mode to generate a response. Please see the [Chatify GitHub repository](https://github.com/ContextLab/chatify) for instructions. Note that generating responses to uncached content will require an [OpenAI API Key](https://platform.openai.com/account/api-keys)."
]
)
return llm_model
Expand Down Expand Up @@ -276,27 +275,27 @@ def init_model(self):
llm_model : HuggingFaceModel
Initialized Hugging Face Chat Model.
"""
self.proxy = self.model_config['proxy']
self.proxy_port = self.model_config['proxy_port']
self.proxy = self.model_config["proxy"]
self.proxy_port = self.model_config["proxy_port"]

with warnings.catch_warnings():
warnings.simplefilter("ignore")

try:
llm = HuggingFacePipeline.from_model_id(
model_id=self.model_config['model_name'],
task='text-generation',
model_id=self.model_config["model_name"],
task="text-generation",
device=0,
model_kwargs={'max_length': self.model_config['max_tokens']},
model_kwargs={"max_length": self.model_config["max_tokens"]},
)
except:
llm = HuggingFacePipeline.from_model_id(
model_id=self.model_config['model_name'],
task='text-generation',
model_id=self.model_config["model_name"],
task="text-generation",
model_kwargs={
'max_length': self.model_config['max_tokens'],
'temperature': 0.85,
'presence_penalty': 0.1,
"max_length": self.model_config["max_tokens"],
"temperature": 0.85,
"presence_penalty": 0.1,
},
)
return llm
Expand Down Expand Up @@ -326,8 +325,8 @@ def init_model(self):
Initialized Hugging Face Chat Model.
"""
self.model_path = hf_hub_download(
repo_id=self.model_config['model_name'],
filename=self.model_config['weights_fname'],
repo_id=self.model_config["model_name"],
filename=self.model_config["weights_fname"],
)

with warnings.catch_warnings():
Expand All @@ -337,17 +336,17 @@ def init_model(self):
try:
llm = LlamaCpp(
model_path=self.model_path,
max_tokens=self.model_config['max_tokens'],
n_gpu_layers=self.model_config['n_gpu_layers'],
n_batch=self.model_config['n_batch'],
max_tokens=self.model_config["max_tokens"],
n_gpu_layers=self.model_config["n_gpu_layers"],
n_batch=self.model_config["n_batch"],
callback_manager=callback_manager,
verbose=True,
)
except:
llm = LlamaCpp(
model_path=self.model_path,
max_tokens=self.model_config['max_tokens'],
n_batch=self.model_config['n_batch'],
max_tokens=self.model_config["max_tokens"],
n_batch=self.model_config["n_batch"],
callback_manager=callback_manager,
verbose=True,
)
Expand Down
13 changes: 5 additions & 8 deletions chatify/main.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
import yaml

import pathlib
import requests

from IPython.display import display
from IPython.core.magic import Magics, magics_class, cell_magic

import ipywidgets as widgets
import yaml
from IPython.core.magic import Magics, cell_magic, magics_class
from IPython.display import display

from .chains import CreateLLMChain
from .widgets import option_widget, button_widget, text_widget, thumbs, loading_widget

from .utils import check_dev_config, get_html
from .widgets import (button_widget, loading_widget, option_widget,
text_widget, thumbs)


@magics_class
Expand Down
8 changes: 3 additions & 5 deletions chatify/utils.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,12 @@
from typing import Any, List, Mapping, Optional

import random
import urllib
from typing import Any, List, Mapping, Optional

from langchain.llms.base import LLM
from markdown_it import MarkdownIt
from pygments import highlight
from pygments.formatters import HtmlFormatter
from pygments.lexers import get_lexer_by_name
from pygments import highlight

from langchain.llms.base import LLM


def highlight_code(code, name, attrs):
Expand Down
4 changes: 2 additions & 2 deletions chatify/widgets.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ipywidgets as widgets

import pathlib

import ipywidgets as widgets


def option_widget(config):
"""Create an options dropdown widget based on the given configuration.
Expand Down
Loading

0 comments on commit 4759f8c

Please sign in to comment.