Skip to content

Commit

Permalink
Merge branch 'refactor'
Browse files Browse the repository at this point in the history
  • Loading branch information
frankroeder committed Jun 28, 2024
2 parents b6aacee + 17f46b7 commit 2719099
Show file tree
Hide file tree
Showing 5 changed files with 184 additions and 85 deletions.
2 changes: 1 addition & 1 deletion lua/parrot/config.lua
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ local config = {
api_key = "",
endpoint = "https://api.perplexity.ai/chat/completions",
topic_prompt = topic_prompt,
topic_model = "mistral-7b-instruct",
topic_model = "llama-3-8b-instruct",
},
openai = {
api_key = "",
Expand Down
144 changes: 63 additions & 81 deletions lua/parrot/init.lua
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,15 @@ local utils = require("parrot.utils")
local futils = require("parrot.file_utils")
local Pool = require("parrot.pool")
local Queries = require("parrot.queries")
local State = require("parrot.state")
local ui = require("parrot.ui")
local pft = require("plenary.filetype")
local scan = require("plenary.scandir")
local init_provider = require("parrot.provider").init_provider
local Job = require("plenary.job")

local _H = {}
local M = {
_H = _H, -- helper functions
_plugin_name = "parrot.nvim",
_queries = {}, -- table of latest queries
_state = {}, -- table of state variables
providers = {},
agents = { -- table of agents
chat = {},
Expand All @@ -24,7 +21,6 @@ local M = {
config = {}, -- config variables
hooks = {}, -- user defined command functions
logger = require("parrot.logger"),
ui = ui,
}
local pool = Pool:new()
local queries = Queries:new()
Expand Down Expand Up @@ -116,7 +112,7 @@ M.setup = function(opts)
M.config[k] = v
end

-- make sure _dirs exists
-- make sure config director matching "*_dir" exist
for k, v in pairs(M.config) do
if k:match("_dir$") and type(v) == "string" then
local dir = v:gsub("/$", "")
Expand All @@ -127,29 +123,54 @@ M.setup = function(opts)
end
end

-- remove invalid agents
for name, agent in pairs(M.agents.chat) do
if type(agent) ~= "table" or not agent.model or not agent.provider then
M.logger.warning("Removing invalid agent " .. name .. " " .. vim.inspect(agent))
M.agents.chat[name] = nil
local function is_valid_provider(name, provider)
if type(provider) ~= "table" then
M.logger.warning(string.format("Removing provider %s: not a table", name))
return false
end
end
for name, agent in pairs(M.agents.command) do
if type(agent) ~= "table" or not agent.model or not agent.provider then
M.logger.warning("Removing invalid agent " .. name .. " " .. vim.inspect(agent))
M.agents.command[name] = nil
if not provider.endpoint then
M.logger.warning(string.format("Removing provider %s: endpoint missing or empty", name))
return false
end
if provider.api_key == "" then
M.logger.warning(string.format("Removing provider %s: api_key missing or empty", name))
return false
end
return true
end

-- remove invalid providers
for name, _provider in pairs(M.providers) do
if type(_provider) ~= "table" or not _provider.endpoint then
M.logger.warning("Removing invalid provider " .. name .. " " .. vim.inspect(_provider))
M.providers[name] = nil
local filtered_providers = {}
for name, provider in pairs(M.providers) do
if is_valid_provider(name, provider) then
filtered_providers[name] = provider
else
M.logger.warning(string.format("Removing provider %s: invalid configuration", name))
end
end
M.providers = filtered_providers

local filter_valid_agents = function(agents, atype)
for name, agent in pairs(agents) do
if type(agent) ~= "table" then
M.logger.warning("Removing " .. atype .. " agent " .. name .. " because it is not a table")
agents[name] = nil
elseif not agent.provider then
M.logger.warning("Removing " .. atype .. " agent " .. name .. ", provider missing")
agents[name] = nil
elseif M.providers[agent.provider] == nil then
M.logger.warning("Removing " .. atype .. " agent " .. name .. ", invalid provider")
agents[name] = nil
elseif not agent.model then
M.logger.warning("Removing " .. atype .. " agent " .. name .. ", model missing")
agents[name] = nil
end
end
return agents
end

M.agents.chat = filter_valid_agents(M.agents.chat, "chat")
M.agents.command = filter_valid_agents(M.agents.command, "command")

-- prepare agent completions
M._chat_agents = {}
M._command_agents = {}
M._available_providers = {}
Expand Down Expand Up @@ -181,7 +202,10 @@ M.setup = function(opts)
table.sort(M._available_providers)
table.sort(M._available_provider_agents)

M.refresh_state()
-- global state
Pstate = State:new(M.config.state_dir)
Pstate:refresh(M._available_providers, M._available_provider_agents)
M.prepare_commands()

-- register user commands
for hook, _ in pairs(M.hooks) do
Expand Down Expand Up @@ -250,51 +274,6 @@ M.setup = function(opts)
end
end

M.refresh_state = function()
local state_file = M.config.state_dir .. "/state.json"
local state = {}
if vim.fn.filereadable(state_file) ~= 0 then
state = futils.file_to_table(state_file) or {}
end

if next(state) == nil then
for _, prov in pairs(M._available_providers) do
state[prov] = { chat_agent = nil, command_agent = nil }
end
end

for _, prov in pairs(M._available_providers) do
if not M._state[prov] then
M._state[prov] = { chat_agent = nil, command_agent = nil }
end

if M._state[prov].chat_agent == nil then
if state[prov] == nil or state[prov].chat_agent == nil then
M._state[prov].chat_agent = M._available_provider_agents[prov].chat[1]
else
M._state[prov].chat_agent = state[prov].chat_agent
end
end

if M._state[prov].command_agent == nil then
if state[prov] == nil or state[prov].command_agent == nil then
M._state[prov].command_agent = M._available_provider_agents[prov].command[1]
else
M._state[prov].command_agent = state[prov].command_agent
end
end
end

M._state.provider = M._state.provider or state.provider or nil
if M._state.provider == nil then
M._state.provider = M._available_providers[1]
end

futils.table_to_file(M._state, state_file)

M.prepare_commands()
end

-- creates prompt commands for each target
M.prepare_commands = function()
for name, target in pairs(ui.Target) do
Expand Down Expand Up @@ -1145,7 +1124,7 @@ M.chat_respond = function(params)
messages[1] = { role = "system", content = content }
end

local messages = prov:preprocess_messages(messages)
messages = prov:preprocess_messages(messages)
-- strip whitespace from ends of content
for _, message in ipairs(messages) do
message.content = message.content:gsub("^%s*(.-)%s*$", "%1")
Expand Down Expand Up @@ -1350,8 +1329,9 @@ M.switch_provider = function(selected_prov)
end

if M.providers[selected_prov] then
M._state.provider = selected_prov
M.refresh_state()
Pstate:set_provider(selected_prov)
Pstate:refresh(M._available_providers, M._available_provider_agents)
M.prepare_commands()
M.logger.info("Switched to provider: " .. selected_prov)
return
else
Expand Down Expand Up @@ -1389,19 +1369,20 @@ M.switch_agent = function(is_chat, selected_agent, prov)
end

if is_chat and M.agents.chat[selected_agent] then
M._state[prov.name].chat_agent = selected_agent
M.logger.info("Chat agent: " .. M._state[prov.name].chat_agent)
Pstate:set_agent(prov.name, selected_agent, "chat")
M.logger.info("Chat agent: " .. Pstate:get_agent(prov.name, "chat"))
prov:check(M.agents.chat[selected_agent])
elseif is_chat then
M.logger.warning(selected_agent .. " is not a Chat agent")
elseif M.agents.command[selected_agent] then
M._state[prov.name].command_agent = selected_agent
M.logger.info("Command agent: " .. M._state[prov.name].command_agent)
Pstate:set_agent(prov.name, selected_agent, "command")
M.logger.info("Command agent: " .. Pstate:get_agent(prov.name, "command"))
prov:check(M.agents.command[selected_agent])
else
M.logger.warning(selected_agent .. " is not a Command agent")
end
M.refresh_state()
Pstate:refresh(M._available_providers, M._available_provider_agents)
M.prepare_commands()
end

M.cmd.Agent = function(params)
Expand Down Expand Up @@ -1451,8 +1432,9 @@ end
M.get_command_agent = function()
local template = M.config.command_prompt_prefix_template
local prov = M.get_provider()
local cmd_prefix = utils.template_render_from_list(template, { ["{{agent}}"] = M._state[prov.name].command_agent })
local name = M._state[prov.name].command_agent
local cmd_prefix =
utils.template_render_from_list(template, { ["{{agent}}"] = Pstate:get_agent(prov.name, "command") })
local name = Pstate:get_agent(prov.name, "command")
local model = M.agents.command[name].model
local system_prompt = M.agents.command[name].system_prompt
return {
Expand All @@ -1468,8 +1450,8 @@ end
M.get_chat_agent = function()
local template = M.config.command_prompt_prefix_template
local prov = M.get_provider()
local cmd_prefix = utils.template_render_from_list(template, { ["{{agent}}"] = M._state[prov.name].chat_agent })
local name = M._state[prov.name].chat_agent
local cmd_prefix = utils.template_render_from_list(template, { ["{{agent}}"] = Pstate:get_agent(prov.name, "chat") })
local name = Pstate:get_agent(prov.name, "chat")
local model = M.agents.chat[name].model
local system_prompt = M.agents.chat[name].system_prompt
return {
Expand All @@ -1482,7 +1464,7 @@ M.get_chat_agent = function()
end

M.get_provider = function()
local _state_prov = M._state["provider"]
local _state_prov = Pstate:get_provider()
local endpoint = M.providers[_state_prov].endpoint
local api_key = M.providers[_state_prov].api_key
return init_provider(_state_prov, endpoint, api_key)
Expand Down
108 changes: 108 additions & 0 deletions lua/parrot/state.lua
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
local futils = require("parrot.file_utils")
local utils = require("parrot.utils")

local State = {}
State.__index = State

---@param state_dir string # directory where the state file is located
---@return table # returns a new state instance
function State:new(state_dir)
local state_file = state_dir .. "/state.json"
local file_state = vim.fn.filereadable(state_file) ~= 0 and futils.file_to_table(state_file) or {}
return setmetatable({ state_file = state_file, file_state = file_state, _state = {} }, self)
end

--- Initializes file state for each provider if it's empty
---@param available_providers table # A table of available providers
---@return nil
function State:init_file_state(available_providers)
if next(self.file_state) == nil then
for _, prov in ipairs(available_providers) do
self.file_state[prov] = { chat_agent = nil, command_agent = nil }
end
end
end

---@param provider string # provider name to initialize state
---@return nil
function State:init_provider_state(provider)
self._state[provider] = self._state[provider] or { chat_agent = nil, command_agent = nil }
end

---@param provider string # Name of the provider
---@param agent_type string # Type of agent (e.g., "chat_agent", "command_agent")
---@param available_provider_agents table # A table containing available agents for all providers
function State:load_agents(provider, agent_type, available_provider_agents)
local state_agent = self.file_state and self.file_state[provider] and self.file_state[provider][agent_type]

local is_valid_agent = false
if agent_type == "chat_agent" then
is_valid_agent = utils.contains(available_provider_agents[provider].chat, state_agent)
elseif agent_type == "command_agent" then
is_valid_agent = utils.contains(available_provider_agents[provider].command, state_agent)
end

if self._state[provider][agent_type] == nil then
if state_agent and is_valid_agent then
self._state[provider][agent_type] = state_agent
else
if agent_type == "chat_agent" then
self._state[provider][agent_type] = available_provider_agents[provider].chat[1]
elseif agent_type == "command_agent" then
self._state[provider][agent_type] = available_provider_agents[provider].command[1]
end
end
end
end

---@param available_providers table # available providers
---@param available_provider_agents table # available provider agents
function State:refresh(available_providers, available_provider_agents)
self:init_file_state(available_providers)
for _, provider in ipairs(available_providers) do
self:init_provider_state(provider)
self:load_agents(provider, "chat_agent", available_provider_agents)
self:load_agents(provider, "command_agent", available_provider_agents)
end
self._state.provider = self._state.provider or self.file_state.provider or available_providers[1]
self:save()
end

---@return nil
function State:save()
futils.table_to_file(self._state, self.state_file)
end

---@param provider string # Name of the provider to set
function State:set_provider(provider)
self._state.provider = provider
end

---@param provider string # provider name
---@param agent table # agent details
---@param atype string # type of the agent ('chat' or 'command')
function State:set_agent(provider, agent, atype)
if atype == "chat" then
self._state[provider].chat_agent = agent
elseif atype == "command" then
self._state[provider].command_agent = agent
end
end

---@return string | nil # returns the current provider name, or nil if not set
function State:get_provider()
return self._state.provider
end

---@param provider string # provider name
---@param atype string # type of agent ('chat' or 'command')
---@return table | nil # returns the agent table or nil if not found
function State:get_agent(provider, atype)
if atype == "chat" then
return self._state[provider].chat_agent
elseif atype == "command" then
return self._state[provider].command_agent
end
end

return State
4 changes: 2 additions & 2 deletions lua/parrot/ui.lua
Original file line number Diff line number Diff line change
Expand Up @@ -215,7 +215,7 @@ M.input = function(opts, on_confirm)
})
opts = (opts and not vim.tbl_isempty(opts)) and opts or vim.empty_dict()

local prompt = opts.prompt or "Enter text here..."
local prompt = opts.prompt or "Enter text here... "
local hint = "(confirm with CTRL-W_q or CTRL-C)"

-- Create a new buffer
Expand All @@ -233,7 +233,7 @@ M.input = function(opts, on_confirm)
-- Add prompt and hint as virtual text
local ns_id = vim.api.nvim_create_namespace("input_prompt")
vim.api.nvim_buf_set_extmark(buf, ns_id, 0, 0, {
virt_text = { { prompt .. " " .. hint, "Comment" } },
virt_text = { { prompt .. hint, "Comment" } },
virt_text_pos = "overlay",
})

Expand Down
Loading

0 comments on commit 2719099

Please sign in to comment.