Skip to content

Commit

Permalink
feat(langchain): allow passing trace attributes on chain invocation (#…
Browse files Browse the repository at this point in the history
  • Loading branch information
hassiebp authored Sep 25, 2024
1 parent 1aaeb3f commit 3df6394
Show file tree
Hide file tree
Showing 2 changed files with 77 additions and 11 deletions.
40 changes: 34 additions & 6 deletions langfuse/callback/langchain.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from collections import defaultdict
import httpx
import logging
import typing
Expand Down Expand Up @@ -110,6 +111,7 @@ def __init__(

self.runs = {}
self.prompt_to_parent_run_map = {}
self.trace_updates = defaultdict(dict)

if stateful_client and isinstance(stateful_client, StatefulSpanClient):
self.runs[stateful_client.id] = stateful_client
Expand Down Expand Up @@ -214,6 +216,17 @@ def on_chain_start(
run_id=run_id, parent_run_id=parent_run_id, metadata=metadata
)

# Update trace-level information if this is a root-level chain (no parent)
# and if tags or metadata are provided
if parent_run_id is None and (tags or metadata):
self.trace_updates[run_id].update(
{
"tags": tags and [str(tag) for tag in tags],
"session_id": metadata and metadata.get("langfuse_session_id"),
"user_id": metadata and metadata.get("langfuse_user_id"),
}
)

content = {
"id": self.next_span_id,
"trace_id": self.trace.id,
Expand Down Expand Up @@ -885,32 +898,42 @@ def _update_trace_and_remove_state(
**kwargs: Any,
):
"""Update the trace with the output of the current run. Called at every finish callback event."""
chain_trace_updates = self.trace_updates.pop(run_id, {})

if (
parent_run_id
is None # If we are at the root of the langchain execution -> reached the end of the root
and self.trace is not None # We do have a trace available
and self.trace.id
== str(run_id) # The trace was generated by langchain and not by the user
):
self.trace = self.trace.update(output=output, **kwargs)
self.trace = self.trace.update(
output=output, **chain_trace_updates, **kwargs
)

elif (
parent_run_id is None
and self.trace is not None # We have a user-provided parent
and self.update_stateful_client
):
if self.root_span is not None:
self.root_span = self.root_span.update(output=output, **kwargs)
self.root_span = self.root_span.update(
output=output, **kwargs
) # No trace updates if root_span was user provided
else:
self.trace = self.trace.update(output=output, **kwargs)
self.trace = self.trace.update(
output=output, **chain_trace_updates, **kwargs
)

elif parent_run_id is None and self.langfuse is not None:
"""
For batch runs, self.trace.id == str(run_id) only for the last run
For the rest of the runs, the trace must be manually updated
The check for self.langfuse ensures that no stateful client was provided by the user
"""
self.langfuse.trace(id=str(run_id)).update(output=output, **kwargs)
self.langfuse.trace(id=str(run_id)).update(
output=output, **chain_trace_updates, **kwargs
)

if not keep_state:
del self.runs[run_id]
Expand Down Expand Up @@ -1081,10 +1104,15 @@ def _strip_langfuse_keys_from_dict(metadata: Optional[Dict[str, Any]]):
if metadata is None or not isinstance(metadata, dict):
return metadata

langfuse_keys = ["langfuse_prompt"]
langfuse_metadata_keys = [
"langfuse_prompt",
"langfuse_session_id",
"langfuse_user_id",
]

metadata_copy = metadata.copy()

for key in langfuse_keys:
for key in langfuse_metadata_keys:
metadata_copy.pop(key, None)

return metadata_copy
48 changes: 43 additions & 5 deletions tests/test_langchain.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,9 @@
)
from langchain.chains.openai_functions import create_openai_fn_chain
from langchain.chains.summarize import load_summarize_chain
from langchain_community.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain.document_loaders import TextLoader
from langchain.embeddings.openai import OpenAIEmbeddings
from langchain.llms import OpenAI
from langchain_openai import OpenAI, AzureChatOpenAI, ChatOpenAI
from langchain.memory import ConversationBufferMemory
from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain.schema import Document
Expand Down Expand Up @@ -1767,6 +1766,8 @@ def test_disabled_langfuse():
def test_link_langfuse_prompts_invoke():
langfuse = Langfuse()
trace_name = "test_link_langfuse_prompts_invoke"
session_id = "session_" + create_uuid()[:8]
user_id = "user_" + create_uuid()[:8]

# Create prompts
joke_prompt_name = "joke_prompt_" + create_uuid()[:8]
Expand Down Expand Up @@ -1819,12 +1820,23 @@ def test_link_langfuse_prompts_invoke():
config={
"callbacks": [langfuse_handler],
"run_name": trace_name,
"tags": ["langchain-tag"],
"metadata": {
"langfuse_session_id": session_id,
"langfuse_user_id": user_id,
},
},
)

langfuse_handler.flush()

observations = get_api().trace.get(langfuse_handler.get_trace_id()).observations
trace = get_api().trace.get(langfuse_handler.get_trace_id())

assert trace.tags == ["langchain-tag"]
assert trace.session_id == session_id
assert trace.user_id == user_id

observations = trace.observations

generations = sorted(
list(filter(lambda x: x.type == "GENERATION", observations)),
Expand All @@ -1847,6 +1859,8 @@ def test_link_langfuse_prompts_invoke():
def test_link_langfuse_prompts_stream():
langfuse = Langfuse(debug=True)
trace_name = "test_link_langfuse_prompts_stream"
session_id = "session_" + create_uuid()[:8]
user_id = "user_" + create_uuid()[:8]

# Create prompts
joke_prompt_name = "joke_prompt_" + create_uuid()[:8]
Expand Down Expand Up @@ -1899,6 +1913,11 @@ def test_link_langfuse_prompts_stream():
config={
"callbacks": [langfuse_handler],
"run_name": trace_name,
"tags": ["langchain-tag"],
"metadata": {
"langfuse_session_id": session_id,
"langfuse_user_id": user_id,
},
},
)

Expand All @@ -1908,7 +1927,13 @@ def test_link_langfuse_prompts_stream():

langfuse_handler.flush()

observations = get_api().trace.get(langfuse_handler.get_trace_id()).observations
trace = get_api().trace.get(langfuse_handler.get_trace_id())

assert trace.tags == ["langchain-tag"]
assert trace.session_id == session_id
assert trace.user_id == user_id

observations = trace.observations

generations = sorted(
list(filter(lambda x: x.type == "GENERATION", observations)),
Expand All @@ -1934,6 +1959,8 @@ def test_link_langfuse_prompts_stream():
def test_link_langfuse_prompts_batch():
langfuse = Langfuse()
trace_name = "test_link_langfuse_prompts_batch_" + create_uuid()[:8]
session_id = "session_" + create_uuid()[:8]
user_id = "user_" + create_uuid()[:8]

# Create prompts
joke_prompt_name = "joke_prompt_" + create_uuid()[:8]
Expand Down Expand Up @@ -1986,6 +2013,11 @@ def test_link_langfuse_prompts_batch():
config={
"callbacks": [langfuse_handler],
"run_name": trace_name,
"tags": ["langchain-tag"],
"metadata": {
"langfuse_session_id": session_id,
"langfuse_user_id": user_id,
},
},
)

Expand All @@ -1996,7 +2028,13 @@ def test_link_langfuse_prompts_batch():
assert len(traces) == 3

for trace in traces:
observations = get_api().trace.get(trace.id).observations
trace = get_api().trace.get(trace.id)

assert trace.tags == ["langchain-tag"]
assert trace.session_id == session_id
assert trace.user_id == user_id

observations = trace.observations

generations = sorted(
list(filter(lambda x: x.type == "GENERATION", observations)),
Expand Down

0 comments on commit 3df6394

Please sign in to comment.