Skip to content

Commit

Permalink
Cleanup and Update Generic API (#17)
Browse files Browse the repository at this point in the history
* update api

* update readme

---------

Co-authored-by: Robby <[email protected]>
  • Loading branch information
h0rv and h0rv authored May 20, 2024
1 parent 92b5b03 commit d5c9535
Show file tree
Hide file tree
Showing 14 changed files with 163 additions and 130 deletions.
8 changes: 5 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,16 +38,17 @@ type Person struct {
func main() {
ctx := context.Background()

client, err := instructor.FromOpenAI[Person](
client, err := instructor.FromOpenAI(
openai.NewClient(os.Getenv("OPENAI_API_KEY")),
instructor.WithMode(instructor.ModeJSON),
instructor.WithMaxRetries(5),
instructor.WithMaxRetries(3),
)
if err != nil {
panic(err)
}

person, err := client.CreateChatCompletion(
var person Person
err = client.CreateChatCompletion(
ctx,
instructor.Request{
Model: openai.GPT4Turbo20240409,
Expand All @@ -58,6 +59,7 @@ func main() {
},
},
},
&person,
)
if err != nil {
panic(err)
Expand Down
17 changes: 12 additions & 5 deletions examples/classifcation/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,17 @@ type Prediction struct {
func classify(data string) *Prediction {
ctx := context.Background()

client, err := instructor.FromAnthropic[Prediction](
client, err := instructor.FromAnthropic(
anthropic.NewClient(os.Getenv("ANTHROPIC_API_KEY")),
instructor.WithMode(instructor.ModeToolCall),
instructor.WithMaxRetries(1),
instructor.WithMaxRetries(3),
)
if err != nil {
panic(err)
}

prediction, err := client.CreateChatCompletion(
var prediction Prediction
err = client.CreateChatCompletion(
ctx,
instructor.Request{
Model: anthropic.ModelClaude3Haiku20240307,
Expand All @@ -48,12 +49,13 @@ func classify(data string) *Prediction {
},
},
},
&prediction,
)
if err != nil {
panic(err)
}

return prediction
return &prediction
}

func main() {
Expand All @@ -62,8 +64,13 @@ func main() {
prediction := classify(ticket)

assert(prediction.contains(LabelTechIssue), "Expected ticket to be related to tech issue")
assert(prediction.contains(LabelTechIssue), "Expected ticket to be related to billing")
assert(prediction.contains(LabelBilling), "Expected ticket to be related to billing")
assert(!prediction.contains(LabelGeneralQuery), "Expected ticket NOT to be a general query")

fmt.Printf("%+v\n", prediction)
/*
&{Labels:[{Type:tech_issue} {Type:billing}]}
*/
}

/******/
Expand Down
14 changes: 10 additions & 4 deletions examples/function_calling/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,13 @@ func (s *Search) execute() {

type Searches = []Search

// type Searches struct {
// Items []Search `json:"searches" jsonschema:"title=Searches,description=A list of search results"`
// }

func segment(ctx context.Context, data string) *Searches {

client, err := instructor.FromOpenAI[Searches](
client, err := instructor.FromOpenAI(
openai.NewClient(os.Getenv("OPENAI_API_KEY")),
instructor.WithMode(instructor.ModeToolCall),
instructor.WithMaxRetries(3),
Expand All @@ -40,23 +44,25 @@ func segment(ctx context.Context, data string) *Searches {
panic(err)
}

searches, err := client.CreateChatCompletion(
var searches Searches
err = client.CreateChatCompletion(
ctx,
instructor.Request{
Model: openai.GPT4Turbo20240409,
Model: openai.GPT4o,
Messages: []instructor.Message{
{
Role: instructor.RoleUser,
Content: fmt.Sprintf("Consider the data below: '\n%s' and segment it into multiple search queries", data),
},
},
},
&searches,
)
if err != nil {
panic(err)
}

return searches
return &searches
}

func main() {
Expand Down
6 changes: 4 additions & 2 deletions examples/images/anthropic/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ func (bc *MovieCatalog) PrintCatalog() {
func main() {
ctx := context.Background()

client, err := instructor.FromAnthropic[MovieCatalog](
client, err := instructor.FromAnthropic(
anthropic.NewClient(os.Getenv("ANTHROPIC_API_KEY")),
instructor.WithMode(instructor.ModeJSONSchema),
instructor.WithMaxRetries(3),
Expand All @@ -43,7 +43,8 @@ func main() {

url := "https://utfs.io/f/bd0dbae6-27e3-4604-b640-fd2ffea891b8-fxyywt.jpeg"

movieCatalog, err := client.CreateChatCompletion(
var movieCatalog MovieCatalog
err = client.CreateChatCompletion(
ctx,
instructor.Request{
Model: "claude-3-haiku-20240307",
Expand All @@ -65,6 +66,7 @@ func main() {
},
},
},
&movieCatalog,
)
if err != nil {
panic(err)
Expand Down
6 changes: 4 additions & 2 deletions examples/images/openai/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func (bc *BookCatalog) PrintCatalog() {
func main() {
ctx := context.Background()

client, err := instructor.FromOpenAI[BookCatalog](
client, err := instructor.FromOpenAI(
openai.NewClient(os.Getenv("OPENAI_API_KEY")),
instructor.WithMode(instructor.ModeJSON),
instructor.WithMaxRetries(3),
Expand All @@ -41,7 +41,8 @@ func main() {

url := "https://utfs.io/f/fe55d6bd-e920-4a6f-8e93-a4c9dd851b90-eivhb2.png"

bookCatalog, err := client.CreateChatCompletion(
var bookCatalog BookCatalog
err = client.CreateChatCompletion(
ctx,
instructor.Request{
Model: openai.GPT4o,
Expand All @@ -63,6 +64,7 @@ func main() {
},
},
},
&bookCatalog,
)

if err != nil {
Expand Down
6 changes: 4 additions & 2 deletions examples/user/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ type Person struct {
func main() {
ctx := context.Background()

client, err := instructor.FromOpenAI[Person](
client, err := instructor.FromOpenAI(
openai.NewClient(os.Getenv("OPENAI_API_KEY")),
instructor.WithMode(instructor.ModeJSON),
instructor.WithMaxRetries(3),
Expand All @@ -26,7 +26,8 @@ func main() {
panic(err)
}

person, err := client.CreateChatCompletion(
var person Person
err = client.CreateChatCompletion(
ctx,
instructor.Request{
Model: openai.GPT4Turbo20240409,
Expand All @@ -37,6 +38,7 @@ func main() {
},
},
},
&person,
)
if err != nil {
panic(err)
Expand Down
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ go 1.21.8

require (
github.com/invopop/jsonschema v0.12.0
github.com/liushuangls/go-anthropic/v2 v2.0.3
github.com/liushuangls/go-anthropic/v2 v2.1.0
github.com/sashabaranov/go-openai v1.24.0
)

Expand Down
2 changes: 2 additions & 0 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@ github.com/invopop/jsonschema v0.12.0/go.mod h1:ffZ5Km5SWWRAIN6wbDXItl95euhFz2uO
github.com/josharian/intern v1.0.0/go.mod h1:5DoeVV0s6jJacbCEi61lwdGj/aVlrQvzHFFd8Hwg//Y=
github.com/liushuangls/go-anthropic/v2 v2.0.3 h1:vNA74jYpBxqXxpj3b/+iLtvfQ6fwCY56pseIAPCItQs=
github.com/liushuangls/go-anthropic/v2 v2.0.3/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek=
github.com/liushuangls/go-anthropic/v2 v2.1.0 h1:5ntOeehozlMin0+hgnhxbTru+tmBH84ADaSPelG5fPg=
github.com/liushuangls/go-anthropic/v2 v2.1.0/go.mod h1:8BKv/fkeTaL5R9R9bGkaknYBueyw2WxY20o7bImbOek=
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
Expand Down
34 changes: 13 additions & 21 deletions pkg/instructor/anthropic.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,46 +9,38 @@ import (
anthropic "github.com/liushuangls/go-anthropic/v2"
)

type AnthropicClient[T any] struct {
type AnthropicClient struct {
Name string

client *anthropic.Client
schema *Schema[T]
mode Mode
}

var _ Client[any] = &AnthropicClient[any]{}
var _ Client = &AnthropicClient{}

func NewAnthropicClient[T any](client *anthropic.Client, schema *Schema[T], mode Mode) (*AnthropicClient[T], error) {
o := &AnthropicClient[T]{
func NewAnthropicClient(client *anthropic.Client) (*AnthropicClient, error) {
o := &AnthropicClient{
Name: "Anthropic",
client: client,
schema: schema,
mode: mode,
}
return o, nil
}

func (a *AnthropicClient[T]) CreateChatCompletion(ctx context.Context, request Request) (string, error) {
return a.completionModeHandler(ctx, request)
}

func (a *AnthropicClient[any]) completionModeHandler(ctx context.Context, request Request) (string, error) {
switch a.mode {
func (a *AnthropicClient) CreateChatCompletion(ctx context.Context, request Request, mode Mode, schema *Schema) (string, error) {
switch mode {
case ModeToolCall:
return a.completionToolCall(ctx, request)
return a.completionToolCall(ctx, request, schema)
case ModeJSONSchema:
return a.completionJSONSchema(ctx, request)
return a.completionJSONSchema(ctx, request, schema)
default:
return "", fmt.Errorf("mode '%s' is not supported for %s", a.mode, a.Name)
return "", fmt.Errorf("mode '%s' is not supported for %s", mode, a.Name)
}
}

func (a *AnthropicClient[any]) completionToolCall(ctx context.Context, request Request) (string, error) {
func (a *AnthropicClient) completionToolCall(ctx context.Context, request Request, schema *Schema) (string, error) {

tools := []anthropic.ToolDefinition{}

for _, function := range a.schema.Functions {
for _, function := range schema.Functions {
t := anthropic.ToolDefinition{
Name: function.Name,
Description: function.Description,
Expand Down Expand Up @@ -92,15 +84,15 @@ func (a *AnthropicClient[any]) completionToolCall(ctx context.Context, request R

}

func (a *AnthropicClient[any]) completionJSONSchema(ctx context.Context, request Request) (string, error) {
func (a *AnthropicClient) completionJSONSchema(ctx context.Context, request Request, schema *Schema) (string, error) {

system := fmt.Sprintf(`
Please responsd with json in the following json_schema:
%s
Make sure to return an instance of the JSON, not the schema itself.
`, a.schema.String)
`, schema.String)

messages, err := toAnthropicMessages(&request)
if err != nil {
Expand Down
4 changes: 3 additions & 1 deletion pkg/instructor/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,12 @@ import (
"context"
)

type Client[T any] interface {
type Client interface {
CreateChatCompletion(
ctx context.Context,
request Request,
mode Mode,
schema *Schema,
) (string, error)

// TODO: implement streaming
Expand Down
Loading

0 comments on commit d5c9535

Please sign in to comment.