Skip to content

Commit

Permalink
Added tools write_file and go (#15)
Browse files Browse the repository at this point in the history
These tools are _really_ cool. The test programming_tool_write_file_test.go
 was written by simply running the command:

go run . -t q Write tests for the write_file tool and verify its working

After that, it iteratively wrote and updated the test file
on its own. Which is huge!
  • Loading branch information
baalimago authored Jun 27, 2024
1 parent f566d2a commit c561591
Show file tree
Hide file tree
Showing 7 changed files with 306 additions and 7 deletions.
12 changes: 6 additions & 6 deletions internal/models/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,19 +38,19 @@ type CompletionEvent any

type NoopEvent struct{}

type Chat struct {
Created time.Time `json:"created,omitempty"`
ID string `json:"id"`
Messages []Message `json:"messages"`
}

type Message struct {
Role string `json:"role"`
Content string `json:"content,omitempty"`
ToolCalls []tools.Call `json:"tool_calls,omitempty"`
ToolCallID string `json:"tool_call_id,omitempty"`
}

type Chat struct {
Created time.Time `json:"created,omitempty"`
ID string `json:"id"`
Messages []Message `json:"messages"`
}

// FirstSystemMessage returns the first encountered Message with role 'system'
func (c *Chat) FirstSystemMessage() (Message, error) {
for _, msg := range c.Messages {
Expand Down
2 changes: 2 additions & 0 deletions internal/text/querier_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@ func NewQuerier[C models.StreamCompleter](userConf Configurations, dfault C) (Qu
toolBox.RegisterTool(tools.Find)
toolBox.RegisterTool(tools.WebsiteText)
toolBox.RegisterTool(tools.RipGrep)
toolBox.RegisterTool(tools.Go)
toolBox.RegisterTool(tools.WriteFile)
}

err = modelConf.Setup()
Expand Down
13 changes: 13 additions & 0 deletions internal/tools/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,12 @@ func Invoke(call Call) string {
out, err = WebsiteText.Call(call.Inputs)
case "rg":
out, err = RipGrep.Call(call.Inputs)
case "go":
out, err = Go.Call(call.Inputs)
case "write_file":
out, err = WriteFile.Call(call.Inputs)
default:
// This error is a string as it's being returned to the LLM
return "ERROR: unknown tool call: " + call.Name
}
if err != nil {
Expand All @@ -47,6 +52,14 @@ func UserFunctionFromName(name string) UserFunction {
return FileType.UserFunction()
case "ls":
return LS.UserFunction()
case "website_text":
return WebsiteText.UserFunction()
case "rg":
return RipGrep.UserFunction()
case "go":
return Go.UserFunction()
case "write_file":
return WriteFile.UserFunction()
default:
return UserFunction{}
}
Expand Down
2 changes: 1 addition & 1 deletion internal/tools/models.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ type Call struct {
}

func (c Call) Json() string {
json, err := json.Marshal(c)
json, err := json.MarshalIndent(c, "", " ")
if err != nil {
return fmt.Sprintf("ERROR: Failed to unmarshal: %v", err)
}
Expand Down
62 changes: 62 additions & 0 deletions internal/tools/programming_tool_go.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
package tools

import (
"fmt"
"os/exec"
"strings"
)

type GoTool UserFunction

var Go = GoTool{
Name: "go",
Description: "Run Go commands like 'go test' and 'go run' to compile, test, and run Go programs. Run 'go help' to get details of this tool.",
Inputs: &InputSchema{
Type: "object",
Properties: map[string]ParameterObject{
"command": {
Type: "string",
Description: "The Go command to run (e.g., 'run', 'test', 'build').",
},
"args": {
Type: "string",
Description: "Additional arguments for the Go command (e.g., file names, flags).",
},
"dir": {
Type: "string",
Description: "The directory to run the command in (optional, defaults to current directory).",
},
},
Required: []string{"command"},
},
}

func (g GoTool) Call(input Input) (string, error) {
command, ok := input["command"].(string)
if !ok {
return "", fmt.Errorf("command must be a string")
}

args := []string{command}

if inputArgs, ok := input["args"].(string); ok {
args = append(args, strings.Fields(inputArgs)...)
}

cmd := exec.Command("go", args...)

if dir, ok := input["dir"].(string); ok {
cmd.Dir = dir
}

output, err := cmd.CombinedOutput()
if err != nil {
return "", fmt.Errorf("failed to run go command: %w, output: %v", err, string(output))
}

return string(output), nil
}

func (g GoTool) UserFunction() UserFunction {
return UserFunction(Go)
}
82 changes: 82 additions & 0 deletions internal/tools/programming_tool_write_file.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
package tools

import (
"fmt"
"os"
"path/filepath"
)

type WriteFileTool UserFunction

var WriteFile = WriteFileTool{
Name: "write_file",
Description: "Write content to a file. Creates the file if it doesn't exist, or overwrites it if it does.",
Inputs: &InputSchema{
Type: "object",
Properties: map[string]ParameterObject{
"file_path": {
Type: "string",
Description: "The path to the file to write to.",
},
"content": {
Type: "string",
Description: "The content to write to the file.",
},
"append": {
Type: "boolean",
Description: "If true, append to the file instead of overwriting it.",
},
},
Required: []string{"file_path", "content"},
},
}

func (w WriteFileTool) Call(input Input) (string, error) {
filePath, ok := input["file_path"].(string)
if !ok {
return "", fmt.Errorf("file_path must be a string")
}

content, ok := input["content"].(string)
if !ok {
return "", fmt.Errorf("content must be a string")
}

append := false
if input["append"] != nil {
append, ok = input["append"].(bool)
if !ok {
return "", fmt.Errorf("append must be a boolean")
}
}

// Ensure the directory exists
dir := filepath.Dir(filePath)
if err := os.MkdirAll(dir, 0o755); err != nil {
return "", fmt.Errorf("failed to create directory: %w", err)
}

var flag int
if append {
flag = os.O_APPEND | os.O_CREATE | os.O_WRONLY
} else {
flag = os.O_TRUNC | os.O_CREATE | os.O_WRONLY
}

file, err := os.OpenFile(filePath, flag, 0o644)
if err != nil {
return "", fmt.Errorf("failed to open file: %w", err)
}
defer file.Close()

_, err = file.WriteString(content)
if err != nil {
return "", fmt.Errorf("failed to write to file: %w", err)
}

return fmt.Sprintf("Successfully wrote %d bytes to %s", len(content), filePath), nil
}

func (w WriteFileTool) UserFunction() UserFunction {
return UserFunction(WriteFile)
}
140 changes: 140 additions & 0 deletions internal/tools/programming_tool_write_file_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
package tools

import (
"os"
"path/filepath"
"testing"
)

func TestWriteFileTool_Call(t *testing.T) {
tempDir := t.TempDir()

tests := []struct {
name string
input Input
wantErr bool
check func(t *testing.T, filePath string)
}{
{
name: "write new file",
input: Input{
"file_path": filepath.Join(tempDir, "test1.txt"),
"content": "Hello, World!",
},
wantErr: false,
check: func(t *testing.T, filePath string) {
content, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
if string(content) != "Hello, World!" {
t.Errorf("Expected content 'Hello, World!', got '%s'", string(content))
}
},
},
{
name: "overwrite existing file",
input: Input{
"file_path": filepath.Join(tempDir, "test2.txt"),
"content": "New content",
},
wantErr: false,
check: func(t *testing.T, filePath string) {
content, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
if string(content) != "New content" {
t.Errorf("Expected content 'New content', got '%s'", string(content))
}
},
},
{
name: "append to existing file",
input: Input{
"file_path": filepath.Join(tempDir, "test3.txt"),
"content": " Appended content",
"append": true,
},
wantErr: false,
check: func(t *testing.T, filePath string) {
content, err := os.ReadFile(filePath)
if err != nil {
t.Fatalf("Failed to read file: %v", err)
}
if string(content) != "Initial content Appended content" {
t.Errorf("Expected content 'Initial content Appended content', got '%s'", string(content))
}
},
},
{
name: "missing file_path",
input: Input{
"content": "Some content",
},
wantErr: true,
},
{
name: "missing content",
input: Input{
"file_path": filepath.Join(tempDir, "test4.txt"),
},
wantErr: true,
},
{
name: "invalid append type",
input: Input{
"file_path": filepath.Join(tempDir, "test5.txt"),
"content": "Some content",
"append": "true",
},
wantErr: true,
},
}

writeTool := WriteFileTool{}

// Set up file for append test
if err := os.WriteFile(filepath.Join(tempDir, "test3.txt"), []byte("Initial content"), 0o644); err != nil {
t.Fatalf("Failed to set up append test: %v", err)
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := writeTool.Call(tt.input)
if (err != nil) != tt.wantErr {
t.Errorf("WriteFileTool.Call() error = %v, wantErr %v", err, tt.wantErr)
return
}
if !tt.wantErr {
if result == "" {
t.Errorf("WriteFileTool.Call() returned empty result")
}
if tt.check != nil {
tt.check(t, tt.input["file_path"].(string))
}
}
})
}
}

func TestWriteFileTool_UserFunction(t *testing.T) {
writeTool := WriteFileTool{}
userFunc := writeTool.UserFunction()

if userFunc.Name != "write_file" {
t.Errorf("Expected name 'write_file', got '%s'", userFunc.Name)
}

if userFunc.Description != "Write content to a file. Creates the file if it doesn't exist, or overwrites it if it does." {
t.Errorf("Unexpected description: %s", userFunc.Description)
}

if len(userFunc.Inputs.Required) != 2 || userFunc.Inputs.Required[0] != "file_path" || userFunc.Inputs.Required[1] != "content" {
t.Errorf("Unexpected required inputs: %v", userFunc.Inputs.Required)
}

if len(userFunc.Inputs.Properties) != 3 {
t.Errorf("Expected 3 properties, got %d", len(userFunc.Inputs.Properties))
}
}

0 comments on commit c561591

Please sign in to comment.