
Commit

Merge branch 'main' into Quansight#391-reverse-chat-order
arjxn-py authored Jul 11, 2024
2 parents 0871e37 + 55f7fc5 commit e8ef0de
Showing 20 changed files with 88 additions and 74 deletions.
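The change running through most of this merge is the assistant interface: `Assistant.answer()` now receives the whole chat history as `list[Message]` instead of a bare prompt plus its sources. As a rough before/after sketch based on the signatures in the diff below (the class names are illustrative, not part of this commit, and assume ragna at this revision is installed):

from typing import Iterator

from ragna.core import Assistant, Message, Source


class OldStyleAssistant(Assistant):
    # Before this merge: the prompt and its sources were passed separately.
    def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
        yield f"Answered {prompt!r} from {len(sources)} source(s)."


class NewStyleAssistant(Assistant):
    # After this merge: the whole chat history arrives; the last message is
    # the current user prompt with its retrieved sources attached.
    def answer(self, messages: list[Message]) -> Iterator[str]:
        prompt, sources = (message := messages[-1]).content, message.sources
        yield f"Answered {prompt!r} from {len(sources)} source(s)."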
5 changes: 0 additions & 5 deletions .github/actions/setup-env/action.yml
@@ -53,11 +53,6 @@ runs:
           mamba env update --file environment-dev.yml
           git checkout -- environment-dev.yml
 
-      - name: Install redis-server if necessary
-        if: (steps.cache.outputs.cache-hit != 'true') && (runner.os != 'Windows')
-        shell: bash -el {0}
-        run: mamba install --yes --channel conda-forge redis-server
-
       - name: Install playwright
         shell: bash -el {0}
         run: playwright install
4 changes: 2 additions & 2 deletions docs/examples/gallery_streaming.py
@@ -44,8 +44,8 @@
 
 
 class DemoStreamingAssistant(assistants.RagnaDemoAssistant):
-    def answer(self, prompt, sources):
-        content = next(super().answer(prompt, sources))
+    def answer(self, messages):
+        content = next(super().answer(messages))
         for chunk in content.split(" "):
             yield f"{chunk} "
 
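Aside from the new signature, the streaming behavior is unchanged: the demo yields the answer one word at a time. That chunking is plain Python and can be tried standalone (a tiny sketch, no ragna needed):

def stream_words(content):
    # Same idea as DemoStreamingAssistant: re-yield the answer word by word.
    for chunk in content.split(" "):
        yield f"{chunk} "


assert "".join(stream_words("hello streaming world")) == "hello streaming world "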
20 changes: 10 additions & 10 deletions docs/tutorials/gallery_custom_components.py
@@ -30,7 +30,7 @@
 
 import uuid
 
-from ragna.core import Document, Source, SourceStorage
+from ragna.core import Document, Source, SourceStorage, Message
 
 
 class TutorialSourceStorage(SourceStorage):
@@ -61,9 +61,9 @@ def retrieve(
 # %%
 # ### Assistant
 #
-# [ragna.core.Assistant][]s are objects that take a user prompt and relevant
-# [ragna.core.Source][]s and generate a response form that. Usually, assistants are
-# LLMs.
+# [ragna.core.Assistant][]s are objects that take the chat history as list of
+# [ragna.core.Message][]s and their relevant [ragna.core.Source][]s and generate a
+# response from that. Usually, assistants are LLMs.
 #
 # In this tutorial, we define a minimal `TutorialAssistant` that is similar to
 # [ragna.assistants.RagnaDemoAssistant][]. In `.answer()` we mirror back the user
@@ -82,8 +82,11 @@ def retrieve(
 
 
 class TutorialAssistant(Assistant):
-    def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
+    def answer(self, messages: list[Message]) -> Iterator[str]:
         print(f"Running {type(self).__name__}().answer()")
+        # For simplicity, we only deal with the last message here, i.e. the latest user
+        # prompt.
+        prompt, sources = (message := messages[-1]).content, message.sources
         yield (
             f"To answer the user prompt '{prompt}', "
             f"I was given {len(sources)} source(s)."
@@ -254,8 +257,7 @@ def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
 class ElaborateTutorialAssistant(Assistant):
     def answer(
         self,
-        prompt: str,
-        sources: list[Source],
+        messages: list[Message],
         *,
         my_required_parameter: int,
         my_optional_parameter: str = "foo",
@@ -393,9 +395,7 @@ def answer(
 
 
 class AsyncAssistant(Assistant):
-    async def answer(
-        self, prompt: str, sources: list[Source]
-    ) -> AsyncIterator[str]:
+    async def answer(self, messages: list[Message]) -> AsyncIterator[str]:
         print(f"Running {type(self).__name__}().answer()")
         start = time.perf_counter()
         await asyncio.sleep(0.3)
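The tutorial's new idiom — unpack the last message, then answer only the latest prompt — can be exercised without a document pipeline. Below is a self-contained toy where `FakeMessage` is a hypothetical stand-in mimicking just the `.content` and `.sources` attributes of `ragna.core.Message` used above:

from dataclasses import dataclass, field


@dataclass
class FakeMessage:  # hypothetical stand-in for ragna.core.Message
    content: str
    sources: list = field(default_factory=list)


def answer(messages, *, my_required_parameter: int, my_optional_parameter: str = "foo"):
    # Only the last message -- the current user prompt -- is used, exactly as
    # in TutorialAssistant above.
    prompt, sources = (message := messages[-1]).content, message.sources
    yield (
        f"To answer the user prompt '{prompt}', I was given {len(sources)} "
        f"source(s) and my_required_parameter={my_required_parameter}."
    )


print(next(answer([FakeMessage("Hi!")], my_required_parameter=3)))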
1 change: 0 additions & 1 deletion pyproject.toml
@@ -33,7 +33,6 @@ dependencies = [
     "pydantic-settings>=2",
     "PyJWT",
     "python-multipart",
-    "redis",
     "questionary",
     "rich",
     "sqlalchemy>=2",
22 changes: 14 additions & 8 deletions ragna-docker.toml
@@ -3,22 +3,28 @@ authentication = "ragna.deploy.RagnaDemoAuthentication"
 document = "ragna.core.LocalDocument"
 source_storages = [
     "ragna.source_storages.Chroma",
-    "ragna.source_storages.RagnaDemoSourceStorage",
     "ragna.source_storages.LanceDB"
 ]
 assistants = [
-    "ragna.assistants.Jurassic2Ultra",
-    "ragna.assistants.Claude",
-    "ragna.assistants.ClaudeInstant",
+    "ragna.assistants.ClaudeHaiku",
+    "ragna.assistants.ClaudeOpus",
+    "ragna.assistants.ClaudeSonnet",
     "ragna.assistants.Command",
     "ragna.assistants.CommandLight",
-    "ragna.assistants.RagnaDemoAssistant",
     "ragna.assistants.GeminiPro",
     "ragna.assistants.GeminiUltra",
-    "ragna.assistants.Mpt7bInstruct",
-    "ragna.assistants.Mpt30bInstruct",
-    "ragna.assistants.Gpt4",
+    "ragna.assistants.OllamaGemma2B",
+    "ragna.assistants.OllamaPhi2",
+    "ragna.assistants.OllamaLlama2",
+    "ragna.assistants.OllamaLlava",
+    "ragna.assistants.OllamaMistral",
+    "ragna.assistants.OllamaMixtral",
+    "ragna.assistants.OllamaOrcaMini",
     "ragna.assistants.Gpt35Turbo16k",
+    "ragna.assistants.Gpt4",
+    "ragna.assistants.Jurassic2Ultra",
+    "ragna.assistants.LlamafileAssistant",
+    "ragna.assistants.RagnaDemoAssistant",
 ]
 
 [api]
5 changes: 3 additions & 2 deletions ragna/assistants/_ai21labs.py
@@ -1,6 +1,6 @@
 from typing import AsyncIterator, cast
 
-from ragna.core import Source
+from ragna.core import Message, Source
 
 from ._http_api import HttpApiAssistant
 
@@ -23,11 +23,12 @@ def _make_system_content(self, sources: list[Source]) -> str:
         return instruction + "\n\n".join(source.content for source in sources)
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
         # See https://docs.ai21.com/reference/j2-chat-api#chat-api-parameters
         # See https://docs.ai21.com/reference/j2-complete-api-ref#api-parameters
         # See https://docs.ai21.com/reference/j2-chat-api#understanding-the-response
+        prompt, sources = (message := messages[-1]).content, message.sources
         async for data in self._call_api(
             "POST",
             f"https://api.ai21.com/studio/v1/j2-{self._MODEL_TYPE}/chat",
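The one-liner added here reappears verbatim in every adapter below. Spelled out without the assignment expression, it is equivalent to this hypothetical helper:

from ragna.core import Message, Source


def split_last_message(messages: list[Message]) -> tuple[str, list[Source]]:
    # Equivalent to: prompt, sources = (message := messages[-1]).content, message.sources
    message = messages[-1]  # the current user prompt ...
    return message.content, message.sources  # ... with its retrieved sources attached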
5 changes: 3 additions & 2 deletions ragna/assistants/_anthropic.py
@@ -1,6 +1,6 @@
 from typing import AsyncIterator, cast
 
-from ragna.core import PackageRequirement, RagnaException, Requirement, Source
+from ragna.core import Message, PackageRequirement, RagnaException, Requirement, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
@@ -37,10 +37,11 @@ def _instructize_system_prompt(self, sources: list[Source]) -> str:
         )
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
         # See https://docs.anthropic.com/claude/reference/messages_post
         # See https://docs.anthropic.com/claude/reference/streaming
+        prompt, sources = (message := messages[-1]).content, message.sources
         async for data in self._call_api(
             "POST",
             "https://api.anthropic.com/v1/messages",
5 changes: 3 additions & 2 deletions ragna/assistants/_cohere.py
@@ -1,6 +1,6 @@
 from typing import AsyncIterator, cast
 
-from ragna.core import RagnaException, Source
+from ragna.core import Message, RagnaException, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
@@ -25,11 +25,12 @@ def _make_source_documents(self, sources: list[Source]) -> list[dict[str, str]]:
         return [{"title": source.id, "snippet": source.content} for source in sources]
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
         # See https://docs.cohere.com/docs/cochat-beta
         # See https://docs.cohere.com/reference/chat
         # See https://docs.cohere.com/docs/retrieval-augmented-generation-rag
+        prompt, sources = (message := messages[-1]).content, message.sources
         async for event in self._call_api(
             "POST",
             "https://api.cohere.ai/v1/chat",
26 changes: 17 additions & 9 deletions ragna/assistants/_demo.py
@@ -1,8 +1,7 @@
-import re
 import textwrap
 from typing import Iterator
 
-from ragna.core import Assistant, Source
+from ragna.core import Assistant, Message, MessageRole
 
 
 class RagnaDemoAssistant(Assistant):
@@ -22,11 +21,11 @@ class RagnaDemoAssistant(Assistant):
     def display_name(cls) -> str:
         return "Ragna/DemoAssistant"
 
-    def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
-        if re.search("markdown", prompt, re.IGNORECASE):
+    def answer(self, messages: list[Message]) -> Iterator[str]:
+        if "markdown" in messages[-1].content.lower():
             yield self._markdown_answer()
         else:
-            yield self._default_answer(prompt, sources)
+            yield self._default_answer(messages)
 
     def _markdown_answer(self) -> str:
         return textwrap.dedent(
@@ -39,7 +38,8 @@ def _markdown_answer(self) -> str:
             """
         ).strip()
 
-    def _default_answer(self, prompt: str, sources: list[Source]) -> str:
+    def _default_answer(self, messages: list[Message]) -> str:
+        prompt, sources = (message := messages[-1]).content, message.sources
         sources_display = []
         for source in sources:
             source_display = f"- {source.document.name}"
@@ -50,13 +50,16 @@ def _markdown_answer(self) -> str:
         if len(sources) > 3:
             sources_display.append("[...]")
 
+        n_messages = len([m for m in messages if m.role == MessageRole.USER])
         return (
             textwrap.dedent(
                 """
-                I'm a demo assistant and can be used to try Ragnas workflow.
+                I'm a demo assistant and can be used to try Ragna's workflow.
                 I will only mirror back my inputs.
+
+                So far I have received {n_messages} messages.
 
-                Your prompt was:
+                Your last prompt was:
 
                 > {prompt}
 
@@ -66,5 +69,10 @@ def _default_answer(self, prompt: str, sources: list[Source]) -> str:
             """
             )
             .strip()
-            .format(name=str(self), prompt=prompt, sources="\n".join(sources_display))
+            .format(
+                name=str(self),
+                n_messages=n_messages,
+                prompt=prompt,
+                sources="\n".join(sources_display),
+            )
         )
5 changes: 3 additions & 2 deletions ragna/assistants/_google.py
@@ -1,6 +1,6 @@
 from typing import AsyncIterator
 
-from ragna.core import Source
+from ragna.core import Message, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
@@ -26,8 +26,9 @@ def _instructize_prompt(self, prompt: str, sources: list[Source]) -> str:
         )
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
+        prompt, sources = (message := messages[-1]).content, message.sources
         async for chunk in self._call_api(
             "POST",
             f"https://generativelanguage.googleapis.com/v1beta/models/{self._MODEL}:streamGenerateContent",
5 changes: 3 additions & 2 deletions ragna/assistants/_ollama.py
@@ -2,7 +2,7 @@
 from functools import cached_property
 from typing import AsyncIterator, cast
 
-from ragna.core import RagnaException, Source
+from ragna.core import Message, RagnaException
 
 from ._http_api import HttpStreamingProtocol
 from ._openai import OpenaiLikeHttpApiAssistant
@@ -30,8 +30,9 @@ def _url(self) -> str:
         return f"{base_url}/api/chat"
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
     ) -> AsyncIterator[str]:
+        prompt, sources = (message := messages[-1]).content, message.sources
         async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
             # Modeled after
             # https://github.com/ollama/ollama/blob/06a1508bfe456e82ba053ea554264e140c5057b5/examples/python-loganalysis/readme.md?plain=1#L57-L62
5 changes: 3 additions & 2 deletions ragna/assistants/_openai.py
@@ -2,7 +2,7 @@
 from functools import cached_property
 from typing import Any, AsyncIterator, Optional, cast
 
-from ragna.core import Source
+from ragna.core import Message, Source
 
 from ._http_api import HttpApiAssistant, HttpStreamingProtocol
 
@@ -55,8 +55,9 @@ def _stream(
         return self._call_api("POST", self._url, headers=headers, json=json_)
 
     async def answer(
-        self, prompt: str, sources: list[Source], *, max_new_tokens: int = 256
+        self, messages: list[Message], *, max_new_tokens: int = 256
    ) -> AsyncIterator[str]:
+        prompt, sources = (message := messages[-1]).content, message.sources
         async for data in self._stream(prompt, sources, max_new_tokens=max_new_tokens):
             choice = data["choices"][0]
             if choice["finish_reason"] is not None:
10 changes: 5 additions & 5 deletions ragna/core/_components.py
@@ -147,7 +147,7 @@ def retrieve(self, documents: list[Document], prompt: str) -> list[Source]:
         ...
 
 
-class MessageRole(enum.Enum):
+class MessageRole(str, enum.Enum):
     """Message role
 
     Attributes:
@@ -238,12 +238,12 @@ class Assistant(Component, abc.ABC):
     __ragna_protocol_methods__ = ["answer"]
 
     @abc.abstractmethod
-    def answer(self, prompt: str, sources: list[Source]) -> Iterator[str]:
-        """Answer a prompt given some sources.
+    def answer(self, messages: list[Message]) -> Iterator[str]:
+        """Answer a prompt given the chat history.
 
         Args:
-            prompt: Prompt to be answered.
-            sources: Sources to use when answering answer the prompt.
+            messages: List of messages in the chat history. The last item is the current
+                user prompt and has the relevant sources attached to it.
 
         Returns:
             Answer.
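Mixing `str` into `MessageRole` makes the members behave as plain strings wherever one is expected, which simplifies comparisons and JSON serialization. A sketch of the pattern with a hypothetical enum (the member value here is illustrative, not necessarily ragna's):

import enum
import json


class Role(str, enum.Enum):  # illustrative stand-in for the str mix-in pattern
    USER = "user"


# str-mixed members compare equal to, and serialize as, their string values:
assert Role.USER == "user"
assert json.dumps({"role": Role.USER}) == '{"role": "user"}'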
7 changes: 4 additions & 3 deletions ragna/core/_rag.py
@@ -220,12 +220,13 @@ async def answer(self, prompt: str, *, stream: bool = False) -> Message:
                 detail=RagnaException.EVENT,
             )
 
-        self._messages.append(Message(content=prompt, role=MessageRole.USER))
-
         sources = await self._run(self.source_storage.retrieve, self.documents, prompt)
 
+        question = Message(content=prompt, role=MessageRole.USER, sources=sources)
+        self._messages.append(question)
+
         answer = Message(
-            content=self._run_gen(self.assistant.answer, prompt, sources),
+            content=self._run_gen(self.assistant.answer, self._messages.copy()),
             role=MessageRole.ASSISTANT,
             sources=sources,
         )
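Two effects of this hunk: the user message is now created after retrieval, so its sources can be attached to it, and the assistant receives `self._messages.copy()` rather than the live list, so an assistant cannot accidentally grow the chat's own history. The copy is shallow, so only the list itself is protected, as this minimal standalone sketch illustrates:

history = ["first prompt", "first answer"]


def careless_assistant(messages):
    messages.append("scratch note")  # mutates whatever list it was handed
    return "an answer"


careless_assistant(history.copy())  # the copy absorbs the append ...
assert history == ["first prompt", "first answer"]  # ... the real history is intact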
14 changes: 7 additions & 7 deletions ragna/deploy/_api/core.py
@@ -287,12 +287,15 @@ async def answer(
 ) -> schemas.Message:
     with get_session() as session:
         chat = database.get_chat(session, user=user, id=id)
-        chat.messages.append(
-            schemas.Message(content=prompt, role=ragna.core.MessageRole.USER)
-        )
         core_chat = schema_to_core_chat(session, user=user, chat=chat)
 
         core_answer = await core_chat.answer(prompt, stream=stream)
+        sources = [schemas.Source.from_core(source) for source in core_answer.sources]
+        chat.messages.append(
+            schemas.Message(
+                content=prompt, role=ragna.core.MessageRole.USER, sources=sources
+            )
+        )
 
         if stream:
 
@@ -303,10 +306,7 @@ async def message_chunks() -> AsyncIterator[BaseModel]:
             answer = schemas.Message(
                 content=content_chunk,
                 role=core_answer.role,
-                sources=[
-                    schemas.Source.from_core(source)
-                    for source in core_answer.sources
-                ],
+                sources=sources,
             )
             yield answer
2 changes: 0 additions & 2 deletions requirements-docker.lock
@@ -304,8 +304,6 @@ questionary==2.0.1
     # via Ragna (pyproject.toml)
 ratelimiter==1.2.0.post0
     # via lancedb
-redis==5.0.1
-    # via Ragna (pyproject.toml)
 regex==2023.12.25
     # via tiktoken
 requests==2.31.0