diff --git a/langfuse/callback/langchain.py b/langfuse/callback/langchain.py index 1d0f73c0..24e45c8e 100644 --- a/langfuse/callback/langchain.py +++ b/langfuse/callback/langchain.py @@ -1,3 +1,4 @@ +from collections import defaultdict import httpx import logging import typing @@ -110,6 +111,7 @@ def __init__( self.runs = {} self.prompt_to_parent_run_map = {} + self.trace_updates = defaultdict(dict) if stateful_client and isinstance(stateful_client, StatefulSpanClient): self.runs[stateful_client.id] = stateful_client @@ -214,6 +216,17 @@ def on_chain_start( run_id=run_id, parent_run_id=parent_run_id, metadata=metadata ) + # Update trace-level information if this is a root-level chain (no parent) + # and if tags or metadata are provided + if parent_run_id is None and (tags or metadata): + self.trace_updates[run_id].update( + { + "tags": tags and [str(tag) for tag in tags], + "session_id": metadata and metadata.get("langfuse_session_id"), + "user_id": metadata and metadata.get("langfuse_user_id"), + } + ) + content = { "id": self.next_span_id, "trace_id": self.trace.id, @@ -885,6 +898,8 @@ def _update_trace_and_remove_state( **kwargs: Any, ): """Update the trace with the output of the current run. Called at every finish callback event.""" + chain_trace_updates = self.trace_updates.pop(run_id, {}) + if ( parent_run_id is None # If we are at the root of the langchain execution -> reached the end of the root @@ -892,7 +907,9 @@ def _update_trace_and_remove_state( and self.trace.id == str(run_id) # The trace was generated by langchain and not by the user ): - self.trace = self.trace.update(output=output, **kwargs) + self.trace = self.trace.update( + output=output, **chain_trace_updates, **kwargs + ) elif ( parent_run_id is None @@ -900,9 +917,13 @@ def _update_trace_and_remove_state( and self.update_stateful_client ): if self.root_span is not None: - self.root_span = self.root_span.update(output=output, **kwargs) + self.root_span = self.root_span.update( + output=output, **kwargs + ) # No trace updates if root_span was user provided else: - self.trace = self.trace.update(output=output, **kwargs) + self.trace = self.trace.update( + output=output, **chain_trace_updates, **kwargs + ) elif parent_run_id is None and self.langfuse is not None: """ @@ -910,7 +931,9 @@ def _update_trace_and_remove_state( For the rest of the runs, the trace must be manually updated The check for self.langfuse ensures that no stateful client was provided by the user """ - self.langfuse.trace(id=str(run_id)).update(output=output, **kwargs) + self.langfuse.trace(id=str(run_id)).update( + output=output, **chain_trace_updates, **kwargs + ) if not keep_state: del self.runs[run_id] @@ -1081,10 +1104,15 @@ def _strip_langfuse_keys_from_dict(metadata: Optional[Dict[str, Any]]): if metadata is None or not isinstance(metadata, dict): return metadata - langfuse_keys = ["langfuse_prompt"] + langfuse_metadata_keys = [ + "langfuse_prompt", + "langfuse_session_id", + "langfuse_user_id", + ] + metadata_copy = metadata.copy() - for key in langfuse_keys: + for key in langfuse_metadata_keys: metadata_copy.pop(key, None) return metadata_copy diff --git a/tests/test_langchain.py b/tests/test_langchain.py index 00f6ada2..577a3df1 100644 --- a/tests/test_langchain.py +++ b/tests/test_langchain.py @@ -14,10 +14,9 @@ ) from langchain.chains.openai_functions import create_openai_fn_chain from langchain.chains.summarize import load_summarize_chain -from langchain_community.chat_models import AzureChatOpenAI, ChatOpenAI from langchain.document_loaders import TextLoader from langchain.embeddings.openai import OpenAIEmbeddings -from langchain.llms import OpenAI +from langchain_openai import OpenAI, AzureChatOpenAI, ChatOpenAI from langchain.memory import ConversationBufferMemory from langchain.prompts import ChatPromptTemplate, PromptTemplate from langchain.schema import Document @@ -1767,6 +1766,8 @@ def test_disabled_langfuse(): def test_link_langfuse_prompts_invoke(): langfuse = Langfuse() trace_name = "test_link_langfuse_prompts_invoke" + session_id = "session_" + create_uuid()[:8] + user_id = "user_" + create_uuid()[:8] # Create prompts joke_prompt_name = "joke_prompt_" + create_uuid()[:8] @@ -1819,12 +1820,23 @@ def test_link_langfuse_prompts_invoke(): config={ "callbacks": [langfuse_handler], "run_name": trace_name, + "tags": ["langchain-tag"], + "metadata": { + "langfuse_session_id": session_id, + "langfuse_user_id": user_id, + }, }, ) langfuse_handler.flush() - observations = get_api().trace.get(langfuse_handler.get_trace_id()).observations + trace = get_api().trace.get(langfuse_handler.get_trace_id()) + + assert trace.tags == ["langchain-tag"] + assert trace.session_id == session_id + assert trace.user_id == user_id + + observations = trace.observations generations = sorted( list(filter(lambda x: x.type == "GENERATION", observations)), @@ -1847,6 +1859,8 @@ def test_link_langfuse_prompts_invoke(): def test_link_langfuse_prompts_stream(): langfuse = Langfuse(debug=True) trace_name = "test_link_langfuse_prompts_stream" + session_id = "session_" + create_uuid()[:8] + user_id = "user_" + create_uuid()[:8] # Create prompts joke_prompt_name = "joke_prompt_" + create_uuid()[:8] @@ -1899,6 +1913,11 @@ def test_link_langfuse_prompts_stream(): config={ "callbacks": [langfuse_handler], "run_name": trace_name, + "tags": ["langchain-tag"], + "metadata": { + "langfuse_session_id": session_id, + "langfuse_user_id": user_id, + }, }, ) @@ -1908,7 +1927,13 @@ def test_link_langfuse_prompts_stream(): langfuse_handler.flush() - observations = get_api().trace.get(langfuse_handler.get_trace_id()).observations + trace = get_api().trace.get(langfuse_handler.get_trace_id()) + + assert trace.tags == ["langchain-tag"] + assert trace.session_id == session_id + assert trace.user_id == user_id + + observations = trace.observations generations = sorted( list(filter(lambda x: x.type == "GENERATION", observations)), @@ -1934,6 +1959,8 @@ def test_link_langfuse_prompts_stream(): def test_link_langfuse_prompts_batch(): langfuse = Langfuse() trace_name = "test_link_langfuse_prompts_batch_" + create_uuid()[:8] + session_id = "session_" + create_uuid()[:8] + user_id = "user_" + create_uuid()[:8] # Create prompts joke_prompt_name = "joke_prompt_" + create_uuid()[:8] @@ -1986,6 +2013,11 @@ def test_link_langfuse_prompts_batch(): config={ "callbacks": [langfuse_handler], "run_name": trace_name, + "tags": ["langchain-tag"], + "metadata": { + "langfuse_session_id": session_id, + "langfuse_user_id": user_id, + }, }, ) @@ -1996,7 +2028,13 @@ def test_link_langfuse_prompts_batch(): assert len(traces) == 3 for trace in traces: - observations = get_api().trace.get(trace.id).observations + trace = get_api().trace.get(trace.id) + + assert trace.tags == ["langchain-tag"] + assert trace.session_id == session_id + assert trace.user_id == user_id + + observations = trace.observations generations = sorted( list(filter(lambda x: x.type == "GENERATION", observations)),