diff --git a/internal/models/models.go b/internal/models/models.go index fa7dfb8..68d48b7 100644 --- a/internal/models/models.go +++ b/internal/models/models.go @@ -38,12 +38,6 @@ type CompletionEvent any type NoopEvent struct{} -type Chat struct { - Created time.Time `json:"created,omitempty"` - ID string `json:"id"` - Messages []Message `json:"messages"` -} - type Message struct { Role string `json:"role"` Content string `json:"content,omitempty"` @@ -51,6 +45,12 @@ type Message struct { ToolCallID string `json:"tool_call_id,omitempty"` } +type Chat struct { + Created time.Time `json:"created,omitempty"` + ID string `json:"id"` + Messages []Message `json:"messages"` +} + // FirstSystemMessage returns the first encountered Message with role 'system' func (c *Chat) FirstSystemMessage() (Message, error) { for _, msg := range c.Messages { diff --git a/internal/text/querier_setup.go b/internal/text/querier_setup.go index 04f7bbb..88f2614 100644 --- a/internal/text/querier_setup.go +++ b/internal/text/querier_setup.go @@ -74,6 +74,8 @@ func NewQuerier[C models.StreamCompleter](userConf Configurations, dfault C) (Qu toolBox.RegisterTool(tools.Find) toolBox.RegisterTool(tools.WebsiteText) toolBox.RegisterTool(tools.RipGrep) + toolBox.RegisterTool(tools.Go) + toolBox.RegisterTool(tools.WriteFile) } err = modelConf.Setup() diff --git a/internal/tools/handler.go b/internal/tools/handler.go index 0fbe5fd..f7b3217 100644 --- a/internal/tools/handler.go +++ b/internal/tools/handler.go @@ -26,7 +26,12 @@ func Invoke(call Call) string { out, err = WebsiteText.Call(call.Inputs) case "rg": out, err = RipGrep.Call(call.Inputs) + case "go": + out, err = Go.Call(call.Inputs) + case "write_file": + out, err = WriteFile.Call(call.Inputs) default: + // This error is a string as it's being returned to the LLM return "ERROR: unknown tool call: " + call.Name } if err != nil { @@ -47,6 +52,14 @@ func UserFunctionFromName(name string) UserFunction { return FileType.UserFunction() case "ls": return LS.UserFunction() + case "website_text": + return WebsiteText.UserFunction() + case "rg": + return RipGrep.UserFunction() + case "go": + return Go.UserFunction() + case "write_file": + return WriteFile.UserFunction() default: return UserFunction{} } diff --git a/internal/tools/models.go b/internal/tools/models.go index a05fbf0..164b37b 100644 --- a/internal/tools/models.go +++ b/internal/tools/models.go @@ -33,7 +33,7 @@ type Call struct { } func (c Call) Json() string { - json, err := json.Marshal(c) + json, err := json.MarshalIndent(c, "", " ") if err != nil { return fmt.Sprintf("ERROR: Failed to unmarshal: %v", err) } diff --git a/internal/tools/programming_tool_go.go b/internal/tools/programming_tool_go.go new file mode 100644 index 0000000..e738f0d --- /dev/null +++ b/internal/tools/programming_tool_go.go @@ -0,0 +1,62 @@ +package tools + +import ( + "fmt" + "os/exec" + "strings" +) + +type GoTool UserFunction + +var Go = GoTool{ + Name: "go", + Description: "Run Go commands like 'go test' and 'go run' to compile, test, and run Go programs. Run 'go help' to get details of this tool.", + Inputs: &InputSchema{ + Type: "object", + Properties: map[string]ParameterObject{ + "command": { + Type: "string", + Description: "The Go command to run (e.g., 'run', 'test', 'build').", + }, + "args": { + Type: "string", + Description: "Additional arguments for the Go command (e.g., file names, flags).", + }, + "dir": { + Type: "string", + Description: "The directory to run the command in (optional, defaults to current directory).", + }, + }, + Required: []string{"command"}, + }, +} + +func (g GoTool) Call(input Input) (string, error) { + command, ok := input["command"].(string) + if !ok { + return "", fmt.Errorf("command must be a string") + } + + args := []string{command} + + if inputArgs, ok := input["args"].(string); ok { + args = append(args, strings.Fields(inputArgs)...) + } + + cmd := exec.Command("go", args...) + + if dir, ok := input["dir"].(string); ok { + cmd.Dir = dir + } + + output, err := cmd.CombinedOutput() + if err != nil { + return "", fmt.Errorf("failed to run go command: %w, output: %v", err, string(output)) + } + + return string(output), nil +} + +func (g GoTool) UserFunction() UserFunction { + return UserFunction(Go) +} diff --git a/internal/tools/programming_tool_write_file.go b/internal/tools/programming_tool_write_file.go new file mode 100644 index 0000000..bdbb1ab --- /dev/null +++ b/internal/tools/programming_tool_write_file.go @@ -0,0 +1,82 @@ +package tools + +import ( + "fmt" + "os" + "path/filepath" +) + +type WriteFileTool UserFunction + +var WriteFile = WriteFileTool{ + Name: "write_file", + Description: "Write content to a file. Creates the file if it doesn't exist, or overwrites it if it does.", + Inputs: &InputSchema{ + Type: "object", + Properties: map[string]ParameterObject{ + "file_path": { + Type: "string", + Description: "The path to the file to write to.", + }, + "content": { + Type: "string", + Description: "The content to write to the file.", + }, + "append": { + Type: "boolean", + Description: "If true, append to the file instead of overwriting it.", + }, + }, + Required: []string{"file_path", "content"}, + }, +} + +func (w WriteFileTool) Call(input Input) (string, error) { + filePath, ok := input["file_path"].(string) + if !ok { + return "", fmt.Errorf("file_path must be a string") + } + + content, ok := input["content"].(string) + if !ok { + return "", fmt.Errorf("content must be a string") + } + + append := false + if input["append"] != nil { + append, ok = input["append"].(bool) + if !ok { + return "", fmt.Errorf("append must be a boolean") + } + } + + // Ensure the directory exists + dir := filepath.Dir(filePath) + if err := os.MkdirAll(dir, 0o755); err != nil { + return "", fmt.Errorf("failed to create directory: %w", err) + } + + var flag int + if append { + flag = os.O_APPEND | os.O_CREATE | os.O_WRONLY + } else { + flag = os.O_TRUNC | os.O_CREATE | os.O_WRONLY + } + + file, err := os.OpenFile(filePath, flag, 0o644) + if err != nil { + return "", fmt.Errorf("failed to open file: %w", err) + } + defer file.Close() + + _, err = file.WriteString(content) + if err != nil { + return "", fmt.Errorf("failed to write to file: %w", err) + } + + return fmt.Sprintf("Successfully wrote %d bytes to %s", len(content), filePath), nil +} + +func (w WriteFileTool) UserFunction() UserFunction { + return UserFunction(WriteFile) +} diff --git a/internal/tools/programming_tool_write_file_test.go b/internal/tools/programming_tool_write_file_test.go new file mode 100644 index 0000000..7ccadfe --- /dev/null +++ b/internal/tools/programming_tool_write_file_test.go @@ -0,0 +1,140 @@ +package tools + +import ( + "os" + "path/filepath" + "testing" +) + +func TestWriteFileTool_Call(t *testing.T) { + tempDir := t.TempDir() + + tests := []struct { + name string + input Input + wantErr bool + check func(t *testing.T, filePath string) + }{ + { + name: "write new file", + input: Input{ + "file_path": filepath.Join(tempDir, "test1.txt"), + "content": "Hello, World!", + }, + wantErr: false, + check: func(t *testing.T, filePath string) { + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read file: %v", err) + } + if string(content) != "Hello, World!" { + t.Errorf("Expected content 'Hello, World!', got '%s'", string(content)) + } + }, + }, + { + name: "overwrite existing file", + input: Input{ + "file_path": filepath.Join(tempDir, "test2.txt"), + "content": "New content", + }, + wantErr: false, + check: func(t *testing.T, filePath string) { + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read file: %v", err) + } + if string(content) != "New content" { + t.Errorf("Expected content 'New content', got '%s'", string(content)) + } + }, + }, + { + name: "append to existing file", + input: Input{ + "file_path": filepath.Join(tempDir, "test3.txt"), + "content": " Appended content", + "append": true, + }, + wantErr: false, + check: func(t *testing.T, filePath string) { + content, err := os.ReadFile(filePath) + if err != nil { + t.Fatalf("Failed to read file: %v", err) + } + if string(content) != "Initial content Appended content" { + t.Errorf("Expected content 'Initial content Appended content', got '%s'", string(content)) + } + }, + }, + { + name: "missing file_path", + input: Input{ + "content": "Some content", + }, + wantErr: true, + }, + { + name: "missing content", + input: Input{ + "file_path": filepath.Join(tempDir, "test4.txt"), + }, + wantErr: true, + }, + { + name: "invalid append type", + input: Input{ + "file_path": filepath.Join(tempDir, "test5.txt"), + "content": "Some content", + "append": "true", + }, + wantErr: true, + }, + } + + writeTool := WriteFileTool{} + + // Set up file for append test + if err := os.WriteFile(filepath.Join(tempDir, "test3.txt"), []byte("Initial content"), 0o644); err != nil { + t.Fatalf("Failed to set up append test: %v", err) + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result, err := writeTool.Call(tt.input) + if (err != nil) != tt.wantErr { + t.Errorf("WriteFileTool.Call() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !tt.wantErr { + if result == "" { + t.Errorf("WriteFileTool.Call() returned empty result") + } + if tt.check != nil { + tt.check(t, tt.input["file_path"].(string)) + } + } + }) + } +} + +func TestWriteFileTool_UserFunction(t *testing.T) { + writeTool := WriteFileTool{} + userFunc := writeTool.UserFunction() + + if userFunc.Name != "write_file" { + t.Errorf("Expected name 'write_file', got '%s'", userFunc.Name) + } + + if userFunc.Description != "Write content to a file. Creates the file if it doesn't exist, or overwrites it if it does." { + t.Errorf("Unexpected description: %s", userFunc.Description) + } + + if len(userFunc.Inputs.Required) != 2 || userFunc.Inputs.Required[0] != "file_path" || userFunc.Inputs.Required[1] != "content" { + t.Errorf("Unexpected required inputs: %v", userFunc.Inputs.Required) + } + + if len(userFunc.Inputs.Properties) != 3 { + t.Errorf("Expected 3 properties, got %d", len(userFunc.Inputs.Properties)) + } +}