Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Feature] Add memory properties to StrayCat #870

Open
wants to merge 8 commits into
base: develop
Choose a base branch
from
4 changes: 3 additions & 1 deletion core/cat/looking_glass/cheshire_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

from langchain_core.language_models.llms import BaseLLM
from langchain.base_language import BaseLanguageModel
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_community.llms import Cohere, OpenAI
from langchain_openai import ChatOpenAI
Expand All @@ -17,6 +18,7 @@
from cat.factory.embedder import get_embedder_from_name
import cat.factory.embedder as embedders
from cat.factory.llm import LLMDefaultConfig
from cat.factory.embedder import EmbedderSettings
from cat.factory.llm import get_llm_from_name
from cat.agents.main_agent import MainAgent
from cat.looking_glass.white_rabbit import WhiteRabbit
Expand Down Expand Up @@ -149,7 +151,7 @@ def load_language_model(self) -> BaseLanguageModel:

return llm

def load_language_embedder(self) -> embedders.EmbedderSettings:
def load_language_embedder(self) -> Embeddings:
"""Hook into the embedder selection.

Allows to modify how the Cat selects the embedder at bootstrap time.
Expand Down
51 changes: 40 additions & 11 deletions core/cat/looking_glass/stray_cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,18 +4,27 @@
from typing import Literal, get_args, List, Dict, Union, Any

from langchain.docstore.document import Document
from langchain_core.embeddings import Embeddings
from langchain_core.language_models.chat_models import BaseChatModel
from langchain.base_language import BaseLanguageModel
from qdrant_client import QdrantClient
from cat.memory.vector_memory_collection import VectorMemoryCollection
from cat.looking_glass.white_rabbit import WhiteRabbit
from cat.mad_hatter.mad_hatter import MadHatter
from cat.memory.long_term_memory import LongTermMemory
from cat.rabbit_hole import RabbitHole
from langchain_community.llms import BaseLLM
from langchain_core.messages import AIMessage, HumanMessage, BaseMessage

from cat.factory.embedder import EmbedderSettings
from fastapi import WebSocket


from cat.log import log
from cat.looking_glass.cheshire_cat import CheshireCat
from cat.looking_glass.callbacks import NewTokenHandler
from cat.memory.working_memory import WorkingMemory
from cat.convo.messages import CatMessage, UserMessage, MessageWhy, Role
from cat.agents.base_agent import AgentOutput
from cat.agents.base_agent import AgentOutput, BaseAgent

from cat.utils import levenshtein_distance

Expand Down Expand Up @@ -537,35 +546,55 @@ def langchainfy_chat_history(self, latest_n: int = 5) -> List[BaseMessage]:
return langchain_chat_history

@property
def user_id(self):
def user_id(self) -> str:
return self.__user_id


@property
def user_message(self) -> str:
return self.working_memory.user_message_json.text

@property
def memory_vector_client(self) -> QdrantClient:
return CheshireCat().memory.vectors.vector_db

@property
def episodic_memory(self) -> VectorMemoryCollection:
return CheshireCat().memory.vectors.episodic

@property
def declarative_memory(self) -> VectorMemoryCollection:
return CheshireCat().memory.vectors.declarative

@property
def procedural_memory(self) -> VectorMemoryCollection:
return CheshireCat().memory.vectors.procedural

@property
def _llm(self):
def _llm(self) -> BaseLanguageModel:
return CheshireCat()._llm

@property
def embedder(self):
def embedder(self) -> Embeddings:
return CheshireCat().embedder

@property
def memory(self):
def memory(self) -> LongTermMemory:
return CheshireCat().memory

@property
def rabbit_hole(self):
def rabbit_hole(self) -> RabbitHole:
return CheshireCat().rabbit_hole

@property
def mad_hatter(self):
def mad_hatter(self) -> MadHatter:
return CheshireCat().mad_hatter

@property
def main_agent(self):
def main_agent(self) -> BaseAgent:
return CheshireCat().main_agent

@property
def white_rabbit(self):
def white_rabbit(self) -> WhiteRabbit:
return CheshireCat().white_rabbit

@property
Expand Down
6 changes: 4 additions & 2 deletions core/cat/mad_hatter/decorators/tool.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,18 +99,20 @@ def tool(
) -> Callable:
"""
Make tools out of functions, can be used with or without arguments.

Requires:
- Function must be of type (str, cat) -> str
- Function must have a docstring

Examples:
.. code-block:: python
@tool
def search_api(query: str, cat) -> str:
# Searches the API for the query.
\"\"\"Searches the API for the query.\"\"\"
return "https://api.com/search?q=" + query
@tool("search", return_direct=True)
def search_api(query: str, cat) -> str:
# Searches the API for the query.
\"\"\"Searches the API for the query.\"\"\"
return "https://api.com/search?q=" + query
"""

Expand Down
4 changes: 3 additions & 1 deletion core/cat/memory/vector_memory.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import sys
import socket
from typing import Dict, Literal
from cat.utils import extract_domain_from_url, is_https

from qdrant_client import QdrantClient
Expand All @@ -13,6 +14,8 @@
# @singleton REFACTOR: worth it to have this (or LongTermMemory) as singleton?
class VectorMemory:
local_vector_db = None

collections: Dict[Literal["episodic", "declarative", "procedural"], VectorMemoryCollection] = {}

def __init__(
self,
Expand All @@ -26,7 +29,6 @@ def __init__(
# - Episodic memory will contain user and eventually cat utterances
# - Declarative memory will contain uploaded documents' content
# - Procedural memory will contain tools and knowledge on how to do things
self.collections = {}
for collection_name in ["episodic", "declarative", "procedural"]:
# Instantiate collection
collection = VectorMemoryCollection(
Expand Down
2 changes: 1 addition & 1 deletion core/cat/routes/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,7 @@ async def upload_file(

Note
----------
`chunk_size`, `chunk_overlap` anad `metadata` must be passed as form data.
`chunk_size`, `chunk_overlap` and `metadata` must be passed as form data.
This is necessary because the HTTP protocol does not allow file uploads to be sent as JSON.

Example
Expand Down
Loading