Refactor Anthropic system prompt and provider message preprocessing #29

Merged on Jul 12, 2024 (19 commits)
README.md: 5 additions, 0 deletions
@@ -28,6 +28,7 @@ Unlike [gp.nvim](https://github.com/Robitx/gp.nvim), [parrot.nvim](https://githu
+ [perplexity.ai API](https://blog.perplexity.ai/blog/introducing-pplx-api)
+ [OpenAI API](https://platform.openai.com/)
+ [Mistral API](https://docs.mistral.ai/api/)
+ [Gemini API](https://ai.google.dev/gemini-api/docs)
+ Local and offline serving via [ollama](https://github.com/ollama/ollama)
- Custom agent definitions to determine specific prompt and API parameter combinations, similar to [GPTs](https://openai.com/index/introducing-gpts/)
- Flexible support for providing API credentials from various sources, such as environment variables, bash commands, and your favorite password manager CLI
@@ -71,6 +72,7 @@ Let the parrot fix your bugs.
dependencies = { 'ibhagwan/fzf-lua', 'nvim-lua/plenary.nvim' },
config = function()
require("parrot").setup {
-- Providers must be explicitly added to make them available.
providers = {
pplx = {
api_key = os.getenv "PERPLEXITY_API_KEY",
@@ -89,6 +91,9 @@ Let the parrot fix your bugs.
mistral = {
api_key = os.getenv "MISTRAL_API_KEY",
},
gemini = {
api_key = os.getenv "GEMINI_API_KEY",
},
ollama = {} -- provide an empty list to make provider available
},
}
lua/parrot/agents.lua: 52 additions, 12 deletions
@@ -80,6 +80,21 @@ local openai_chat_agents = {
},
}

local gemini_chat_agents = {
{
name = "Gemini-1.5-Flash-Chat",
model = { model = "gemini-1.5-flash", temperature = 1.1, topP = 1, topK = 10, maxOutputTokens = 8192 },
system_prompt = system_chat_prompt,
provider = "gemini",
},
{
name = "Gemini-1.5-Pro-Chat",
model = { model = "gemini-1.5-pro", temperature = 1.1, topP = 1, topK = 10, maxOutputTokens = 8192 },
system_prompt = system_chat_prompt,
provider = "gemini",
},
}

local pplx_chat_agents = {
{
name = "Llama3-Sonar-Small-32k-Chat",
@@ -116,26 +131,26 @@ local pplx_chat_agents = {
local anthropic_chat_agents = {
{
name = "Claude-3.5-Sonnet-Chat",
model = { model = "claude-3-5-sonnet-20240620", max_tokens = 4096, system = system_chat_prompt },
system_prompt = "",
model = { model = "claude-3-5-sonnet-20240620", max_tokens = 4096 },
system_prompt = system_chat_prompt,
provider = "anthropic",
},
{
name = "Claude-3-Opus-Chat",
model = { model = "claude-3-opus-20240229", max_tokens = 4096, system = system_chat_prompt },
system_prompt = "",
model = { model = "claude-3-opus-20240229", max_tokens = 4096 },
system_prompt = system_chat_prompt,
provider = "anthropic",
},
{
name = "Claude-3-Sonnet-Chat",
model = { model = "claude-3-sonnet-20240229", max_tokens = 4096, system = system_chat_prompt },
system_prompt = "",
model = { model = "claude-3-sonnet-20240229", max_tokens = 4096 },
system_prompt = system_chat_prompt,
provider = "anthropic",
},
{
name = "Claude-3-Haiku-Chat",
model = { model = "claude-3-haiku-20240307", max_tokens = 4096, system = system_chat_prompt },
system_prompt = "",
model = { model = "claude-3-haiku-20240307", max_tokens = 4096 },
system_prompt = system_chat_prompt,
provider = "anthropic",
},
}
@@ -250,6 +265,21 @@ local openai_command_agents = {
},
}

local gemini_command_agents = {
{
name = "Gemini-1.5-Flash",
model = { model = "gemini-1.5-flash", temperature = 0.8, topP = 1, topK = 10, maxOutputTokens = 8192 },
system_prompt = system_code_prompt,
provider = "gemini",
},
{
name = "Gemini-1.5-Pro",
model = { model = "gemini-1.5-pro", temperature = 0.8, topP = 1, topK = 10, maxOutputTokens = 8192 },
system_prompt = system_code_prompt,
provider = "gemini",
},
}

local pplx_command_agents = {
{
name = "Llama3-Sonar-Small-32k--Online",
@@ -284,22 +314,26 @@ local pplx_command_agents = {
local anthropic_command_agents = {
{
name = "Claude-3.5-Sonnet",
model = { model = "claude-3-5-sonnet-20240620", max_tokens = 4096, system = system_code_prompt },
model = { model = "claude-3-5-sonnet-20240620", max_tokens = 4096 },
system_prompt = system_code_prompt,
provider = "anthropic",
},
{
name = "Claude-3-Opus",
model = { model = "claude-3-opus-20240229", max_tokens = 4096, system = system_code_prompt },
model = { model = "claude-3-opus-20240229", max_tokens = 4096 },
system_prompt = system_code_prompt,
provider = "anthropic",
},
{
name = "Claude-3-Sonnet",
model = { model = "claude-3-sonnet-20240229", max_tokens = 4096, system = system_code_prompt },
model = { model = "claude-3-sonnet-20240229", max_tokens = 4096 },
system_prompt = system_code_prompt,
provider = "anthropic",
},
{
name = "Claude-3-Haiku",
model = { model = "claude-3-haiku-20240307", max_tokens = 4096, system = system_code_prompt },
model = { model = "claude-3-haiku-20240307", max_tokens = 4096 },
system_prompt = system_code_prompt,
provider = "anthropic",
},
}
@@ -363,6 +397,9 @@ end
for _, agent in ipairs(openai_chat_agents) do
table.insert(M.chat_agents, agent)
end
for _, agent in ipairs(gemini_chat_agents) do
table.insert(M.chat_agents, agent)
end
for _, agent in ipairs(pplx_chat_agents) do
table.insert(M.chat_agents, agent)
end
@@ -380,6 +417,9 @@ end
for _, agent in ipairs(openai_command_agents) do
table.insert(M.command_agents, agent)
end
for _, agent in ipairs(gemini_command_agents) do
table.insert(M.command_agents, agent)
end
for _, agent in ipairs(pplx_command_agents) do
table.insert(M.command_agents, agent)
end
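Across all eight Anthropic agents the edit follows one pattern, sketched below on a single chat agent (an illustrative excerpt, not an additional change in this PR): the system prompt moves out of the model table and into the agent's own system_prompt field, matching how the other providers are defined.

```lua
-- Before: system prompt embedded in the model table, agent field left empty
-- {
--   name = "Claude-3.5-Sonnet-Chat",
--   model = { model = "claude-3-5-sonnet-20240620", max_tokens = 4096, system = system_chat_prompt },
--   system_prompt = "",
--   provider = "anthropic",
-- },

-- After: the model table carries only API parameters; the prompt lives on the agent
{
  name = "Claude-3.5-Sonnet-Chat",
  model = { model = "claude-3-5-sonnet-20240620", max_tokens = 4096 },
  system_prompt = system_chat_prompt,
  provider = "anthropic",
},
```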
lua/parrot/config.lua: 7 additions, 1 deletion
@@ -19,6 +19,12 @@ local config = {
topic_prompt = topic_prompt,
topic_model = "gpt-3.5-turbo",
},
gemini = {
api_key = "",
endpoint = "https://generativelanguage.googleapis.com/v1beta/models/",
topic_prompt = topic_prompt,
topic_model = { model = "gemini-1.5-flash", maxOutputTokens = 64 },
},
ollama = {
endpoint = "http://localhost:11434/api/chat",
topic_prompt = [[
@@ -32,7 +38,7 @@ local config = {
api_key = "",
endpoint = "https://api.anthropic.com/v1/messages",
topic_prompt = "You only respond with 2 to 3 words to summarize the past conversation.",
- topic_model = { model = "claude-3-sonnet-20240229", max_tokens = 32, system = topic_prompt },
+ topic_model = { model = "claude-3-sonnet-20240229", max_tokens = 32 },
},
mistral = {
api_key = "",
lua/parrot/health.lua: 1 addition, 0 deletions
@@ -69,6 +69,7 @@ function M.check()
vim.health.error("require('parrot').setup() has not been called")
end
check_provider(parrot, "openai")
check_provider(parrot, "gemini")
check_provider(parrot, "ollama")
check_provider(parrot, "pplx")
check_provider(parrot, "mistral")
lua/parrot/init.lua: 12 additions, 7 deletions
@@ -253,6 +253,7 @@ M.query = function(buf, provider, payload, handler, on_exit)
queries:cleanup(8, 60)

local curl_params = vim.deepcopy(M.config.curl_params or {})
payload = provider:preprocess_payload(payload)
local args = {
"--no-buffer",
"--silent",
@@ -300,6 +301,7 @@ M.query = function(buf, provider, payload, handler, on_exit)
args = curl_params,
on_exit = function(j, return_val)
for _, result in ipairs(j:result()) do
-- print("EXIT", vim.inspect(result))
if type(result) == "string" then
local success, error_msg = pcall(vim.json.decode, result)
if success then
@@ -325,6 +327,7 @@ M.query = function(buf, provider, payload, handler, on_exit)
pool:remove(j.pid)
end,
on_stdout = function(j, data)
-- print("DATA", vim.inspect(data))
local chunk = process_lines(data)
if chunk then
buffer = buffer .. chunk
@@ -340,6 +343,7 @@ M.query = function(buf, provider, payload, handler, on_exit)
end
end,
on_stderr = function(j, data)
-- print("ERROR", vim.inspect(data))
M.logger.error("Error: " .. vim.inspect(data))
if j ~= nil then
M.logger.error(j:result())
@@ -1020,20 +1024,18 @@ M.chat_respond = function(params)
messages[1] = { role = "system", content = content }
end

- messages = prov:preprocess_messages(messages)
- -- strip whitespace from ends of content
- for _, message in ipairs(messages) do
- message.content = message.content:gsub("^%s*(.-)%s*$", "%1")
- end

-- write assistant prompt
local last_content_line = utils.last_content_line(buf)
vim.api.nvim_buf_set_lines(buf, last_content_line, last_content_line, false, { "", agent_prefix .. agent_suffix, "" })

+ local query_prov =
+ init_provider(agent.provider, M.providers[agent.provider].endpoint, M.providers[agent.provider].api_key)
+ query_prov:set_model(agent.model)

-- call the model and write response
M.query(
buf,
- init_provider(agent.provider, M.providers[agent.provider].endpoint, M.providers[agent.provider].api_key),
+ query_prov,
utils.prepare_payload(messages, headers.model, agent.model),
M.create_handler(buf, win, utils.last_content_line(buf), true, "", not M.config.chat_free_cursor),
vim.schedule_wrap(function(qid)
@@ -1078,6 +1080,7 @@ M.chat_respond = function(params)

local topic_prov = M.get_provider()
topic_prov:check({ model = M.providers[topic_prov.name].topic_model })
topic_prov:set_model(M.providers[topic_prov.name].topic_model)

-- call the model
M.query(
@@ -1688,6 +1691,8 @@ M.Prompt = function(params, target, prompt, model, template, system_template, ag

-- call the model and write the response
local agent = M.get_command_agent()
prov:set_model(agent.model)

M.query(
buf,
prov,
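Taken together, the init.lua changes route every request through the provider object: the provider is built once, told which model to use, and later given the chance to reshape the payload inside M.query. A condensed sketch of the new flow follows; the names come from the diff above, while handler and on_exit stand in for the callbacks built in the real code, and error handling is omitted.

```lua
-- chat_respond / Prompt: construct the provider and hand it the agent's model table
local query_prov = init_provider(
  agent.provider,
  M.providers[agent.provider].endpoint,
  M.providers[agent.provider].api_key
)
query_prov:set_model(agent.model) -- a no-op for Anthropic; other providers can use it

M.query(
  buf,
  query_prov,
  utils.prepare_payload(messages, headers.model, agent.model),
  handler, -- placeholder for M.create_handler(...)
  on_exit  -- placeholder for the scheduled completion callback
)

-- inside M.query, before the curl arguments are assembled:
-- payload = provider:preprocess_payload(payload)
```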
lua/parrot/provider/anthropic.lua: 35 additions, 7 deletions
@@ -1,4 +1,5 @@
local logger = require("parrot.logger")
local utils = require("parrot.utils")

local Anthropic = {}
Anthropic.__index = Anthropic
@@ -10,6 +11,24 @@ local available_model_set = {
["claude-3-haiku-20240307"] = true,
}

-- https://docs.anthropic.com/en/api/messages
local available_api_parameters = {
-- required
["model"] = true,
["messages"] = true,
-- optional
["max_tokens"] = true,
["metadata"] = true,
["stop_sequences"] = true,
["stream"] = true,
["system"] = true,
["temperature"] = true,
["tool_choice"] = true,
["tools"] = true,
["top_k"] = true,
["top_p"] = true,
}

function Anthropic:new(endpoint, api_key)
return setmetatable({
endpoint = endpoint,
@@ -18,6 +37,22 @@ function Anthropic:new(endpoint, api_key)
}, self)
end

function Anthropic:set_model(_) end

function Anthropic:preprocess_payload(payload)
for _, message in ipairs(payload.messages) do
message.content = message.content:gsub("^%s*(.-)%s*$", "%1")
end
if payload.messages[1].role == "system" then
local system_prompt = payload.messages[1].content
-- remove the first message that serves as the system prompt as anthropic
-- expects the system prompt to be part of the curl request and not the messages
table.remove(payload.messages, 1)
payload.system = system_prompt
end
return utils.filter_payload_parameters(available_api_parameters, payload)
end

function Anthropic:curl_params()
return {
self.endpoint,
@@ -40,13 +75,6 @@ function Anthropic:verify()
end
end

- function Anthropic:preprocess_messages(messages)
- -- remove the first message that serves as the system prompt as anthropic
- -- expects the system prompt to be part of the curl request and not the messages
- table.remove(messages, 1)
- return messages
- end

function Anthropic:add_system_prompt(messages, _)
return messages
end
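To make the new Anthropic preprocessing concrete, here is a rough usage sketch. The input payload is hypothetical, the variable anthropic_provider stands for an Anthropic:new(...) instance, and utils.filter_payload_parameters (defined outside this diff) is assumed to drop every key that is not listed in available_api_parameters.

```lua
-- Hypothetical payload, as utils.prepare_payload might produce it for a chat request
local payload = {
  model = "claude-3-5-sonnet-20240620",
  max_tokens = 4096,
  stream = true,
  messages = {
    { role = "system", content = "  You are a helpful assistant.  " },
    { role = "user", content = "Summarize this diff." },
  },
}

payload = anthropic_provider:preprocess_payload(payload)

-- Expected result (assuming the whitelist filtering described above):
--   payload.system   == "You are a helpful assistant."  -- hoisted out of messages, whitespace trimmed
--   payload.messages == { { role = "user", content = "Summarize this diff." } }
--   model, max_tokens, and stream survive; any parameter unknown to the
--   /v1/messages endpoint would be filtered out before the curl call
```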