Add genie tool support and add the trace details in databricks mlflow tracing #22

Open · wants to merge 15 commits into base: main
Conversation

@stikkireddy (Contributor) commented on Nov 13, 2024

Items:

  • tested with agent executor tools
  • tested with langgraph tools
  • tests for bridge library
  • tests for the langchain integration
  • README update

It is implemented by creating two nested tool calls, simulating a tool invocation that Databricks tracing picks up. This lets you expose only the query results to the LLM while still surfacing the query description, the generated query, and the results in the tracing UI (see the usage and screenshots below). The tool is declared with response_format: str = "content_and_artifact", so each invocation produces a trace for the Genie tool call with a detailed nested trace directly underneath it.
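For reference, here is a minimal sketch of that pattern (query_genie and run_genie_query are hypothetical names for illustration, not the implementation in this PR):

from typing import Dict, Tuple
from langchain_core.tools import tool

def run_genie_query(question: str) -> Tuple[str, str, str]:
    # Hypothetical stand-in for the Genie API round trip.
    return ("what the query computes", "SELECT ...", "query results as markdown")

@tool(response_format="content_and_artifact")
def query_genie(question: str) -> Tuple[str, Dict[str, str]]:
    """Ask the Genie space a question; only the results are returned to the LLM."""
    description, query, result = run_genie_query(question)
    # The first element (content) is all the LLM sees; the second (artifact)
    # carries the query description, query, and results into the trace.
    return result, {"description": description, "query": query, "result": result}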

You can test this in Databricks, with the preview enabled, via:

%pip install -U --force-reinstall --no-deps git+https://github.com/stikkireddy/databricks-ai-bridge.git@feat-return-query-and-response
%pip install -U --force-reinstall --no-deps git+https://github.com/stikkireddy/databricks-ai-bridge.git@feat-return-query-and-response#subdirectory=integrations/langchain
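The examples below read their settings from a config.yml; a hypothetical one might look like this (all values are placeholders):

llm_endpoint: databricks-meta-llama-3-1-70b-instruct
genie_space_id: <your-genie-space-id>
genie_agent_name: CO2 Emissions Agent
genie_space_description: Answers questions about global CO2 emissions data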

AgentExecutor Usage:

import mlflow
from mlflow.models import ModelConfig

mlflow.langchain.autolog()
config = ModelConfig(development_config="config.yml")

from langchain_databricks import ChatDatabricks

# Create the llm
llm = ChatDatabricks(endpoint=config.get("llm_endpoint"))

from databricks_langchain.genie import GenieTool

tools = [
  GenieTool(
    genie_space_id=genie_space_id, 
    genie_agent_name=genie_agent_name, 
    genie_space_description=genie_space_description
  )
]

from langchain import hub

# Get the prompt to use - can be replaced with any prompt that includes variables "agent_scratchpad" and "input"!
prompt = hub.pull("hwchase17/openai-tools-agent")

from langchain.agents import AgentExecutor, create_tool_calling_agent
agent = create_tool_calling_agent(llm, tools, prompt)
agent_executor = AgentExecutor(agent=agent, tools=tools, verbose=True)

from mlflow.langchain.output_parsers import ChatCompletionsOutputParser
output_parser = ChatCompletionsOutputParser()


from langchain_core.runnables import RunnableLambda

# needed because agent executor returns {"input": "...", "output": "..."}
def agent_executor_to_just_response(inp):
    return inp["output"]
  
def pre_process_input(inp):
    # this is needed to conform to agent executor input which requires input and agent_scratchpad
    return {
        "input": inp
    }

chain = RunnableLambda(pre_process_input) | agent_executor | RunnableLambda(agent_executor_to_just_response) | output_parser

chain.invoke("what are the top 10 locations that show the most co2 emissions?")

LangGraph Usage:

import mlflow
from mlflow.models import ModelConfig

mlflow.langchain.autolog()
config = ModelConfig(development_config="config.yml")

from langchain_databricks import ChatDatabricks

# Create the llm
llm = ChatDatabricks(endpoint=config.get("llm_endpoint"))

import re

def clean_string(input_string: str) -> str:
    cleaned = re.sub(r'[^a-zA-Z0-9\s]', '_', input_string)
    cleaned = re.sub(r'\s+', '_', cleaned)
    cleaned = re.sub(r'_+', '_', cleaned)
    return cleaned.strip('_').lower()
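# For example (hypothetical input): clean_string("My Genie Agent!") returns
# "my_genie_agent"; punctuation and whitespace collapse to single underscores,
# then the result is trimmed of underscores and lowercased.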


genie_space_id = config.get("genie_space_id")
_genie_agent_name = config.get("genie_agent_name")
genie_space_description = config.get("genie_space_description")

assert genie_space_id, f"Configure genie_space_id in config.yml; it is currently: {genie_space_id}"
assert _genie_agent_name, f"Configure genie_agent_name in config.yml; it is currently: {_genie_agent_name}"
assert genie_space_description, f"Configure genie_space_description in config.yml; it is currently: {genie_space_description}"

genie_agent_name = clean_string(_genie_agent_name)

from databricks_langchain.genie import GenieTool

tools = [
  GenieTool(
    genie_space_id=genie_space_id, 
    genie_agent_name=genie_agent_name, 
    genie_space_description=genie_space_description
  )
]

import json
from typing import Iterator, Dict, Any

from langgraph.prebuilt import create_react_agent
from langchain_core.runnables import RunnableGenerator
from langchain_core.messages import (
    AIMessage,
    HumanMessage,
    ToolMessage,
    MessageLikeRepresentation,
)
from mlflow.langchain.output_parsers import ChatCompletionsOutputParser

agent = create_react_agent(
    llm,
    tools,
    state_modifier="You are a helpful assistant. Make sure to use the tool for information.",
)

def stringify_tool_call(tool_call: Dict[str, Any]) -> str:
    """
    Convert a raw tool call into the formatted string that the playground UI
    expects, if the tool_call carries enough information.
    """
    try:
        request = json.dumps(
            {
                "id": tool_call.get("id"),
                "name": tool_call.get("name"),
                "arguments": json.dumps(tool_call.get("args", {})),
            },
            indent=2,
        )
        return f"<tool_call>{request}</tool_call>"
    except Exception:
        # Fall back to the raw representation if the call is not JSON-serializable.
        return str(tool_call)
    
def stringify_tool_result(tool_msg: ToolMessage) -> str:
    """
    Convert a ToolMessage into the formatted string that the playground UI
    expects, if the ToolMessage carries enough information.
    """
    try:
        result = json.dumps(
            {"id": tool_msg.tool_call_id, "content": tool_msg.content}, indent=2
        )
        return f"<tool_call_result>{result}</tool_call_result>"
    except Exception:
        # Fall back to the raw representation if the message is not JSON-serializable.
        return str(tool_msg)
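# Example outputs (hypothetical ids and content):
# stringify_tool_call({"id": "call_1", "name": "genie", "args": {"q": "co2"}})
#   returns a <tool_call>...</tool_call> string wrapping the pretty-printed request.
# stringify_tool_result(ToolMessage(content="42", tool_call_id="call_1"))
#   returns a <tool_call_result>...</tool_call_result> string wrapping the id and content.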


def parse_message(msg) -> str:
    """Parse different message types into their string representations"""
    # tool call result
    if isinstance(msg, ToolMessage):
        return stringify_tool_result(msg)
    # tool call
    elif isinstance(msg, AIMessage) and msg.tool_calls:
        tool_call_results = [stringify_tool_call(call) for call in msg.tool_calls]
        return "".join(tool_call_results)
    # normal HumanMessage or AIMessage (reasoning or final answer)
    elif isinstance(msg, (AIMessage, HumanMessage)):
        return msg.content
    else:
        print(f"Unexpected message type: {type(msg)}")
        return str(msg)

def wrap_output(stream: Iterator[MessageLikeRepresentation]) -> Iterator[str]:
    """
    Process and yield formatted outputs from the message stream.
    The invoke and stream langchain functions produce different output formats.
    This function handles both cases.
    """
    for event in stream:
        # the agent was called with invoke()
        if "messages" in event:
            for msg in event["messages"]:
                yield parse_message(msg) + "\n\n"
        # the agent was called with stream()
        else:
            for node in event:
                for key, messages in event[node].items():
                    if isinstance(messages, list):
                        for msg in messages:
                            yield parse_message(msg) + "\n\n"
                    else:
                        print("Unexpected value {messages} for key {key}. Expected a list of `MessageLikeRepresentation`'s")
                        yield str(messages)


# TODO: modify the input wrapping to make this simpler
chain = agent | RunnableGenerator(wrap_output) | ChatCompletionsOutputParser()

chain.invoke({"messages": [{"role": "user", "content": "what are the top 10 locations that show the most co2 emissions?"}]})

[Screenshots: the Genie tool call and its nested query-detail trace in the Databricks MLflow tracing UI]

On Nov 13, 2024, @stikkireddy changed the title from "WIP: Add genie tool support and add the trace details in databricks mlflow tracing" to "Add genie tool support and add the trace details in databricks mlflow tracing".