From a9af53c35f4fa208436fbc5892dfe66b0279ebb1 Mon Sep 17 00:00:00 2001 From: Pavel Jbanov Date: Thu, 1 Aug 2024 15:30:41 -0400 Subject: [PATCH] fix(go): fixed googleai/vertexai system prompt handling (#732) --- go/plugins/googleai/googleai.go | 29 ++++++++++++++++++++++++++--- go/plugins/vertexai/vertexai.go | 29 ++++++++++++++++++++++++++--- 2 files changed, 52 insertions(+), 6 deletions(-) diff --git a/go/plugins/googleai/googleai.go b/go/plugins/googleai/googleai.go index ae81ce210..0a5bc6067 100644 --- a/go/plugins/googleai/googleai.go +++ b/go/plugins/googleai/googleai.go @@ -232,7 +232,10 @@ func generate( input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error, ) (*ai.GenerateResponse, error) { - gm := newModel(client, model, input) + gm, err := newModel(client, model, input) + if err != nil { + return nil, err + } cs, err := startChat(gm, input) if err != nil { return nil, err @@ -300,7 +303,7 @@ func generate( return r, nil } -func newModel(client *genai.Client, model string, input *ai.GenerateRequest) *genai.GenerativeModel { +func newModel(client *genai.Client, model string, input *ai.GenerateRequest) (*genai.GenerativeModel, error) { gm := client.GenerativeModel(model) gm.SetCandidateCount(int32(input.Candidates)) if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil { @@ -320,7 +323,21 @@ func newModel(client *genai.Client, model string, input *ai.GenerateRequest) *ge gm.SetTopP(float32(c.TopP)) } } - return gm + for _, m := range input.Messages { + systemParts, err := convertParts(m.Content) + if err != nil { + return nil, err + + } + // system prompts go into GenerativeModel.SystemInstruction field. + if m.Role == ai.RoleSystem { + gm.SystemInstruction = &genai.Content{ + Parts: systemParts, + Role: string(m.Role), + } + } + } + return gm, nil } // startChat starts a chat session and configures it with the input messages. @@ -332,6 +349,12 @@ func startChat(gm *genai.GenerativeModel, input *ai.GenerateRequest) (*genai.Cha for len(messages) > 1 { m := messages[0] messages = messages[1:] + + // skip system prompt message, it's handled separately. + if m.Role == ai.RoleSystem { + continue + } + parts, err := convertParts(m.Content) if err != nil { return nil, err diff --git a/go/plugins/vertexai/vertexai.go b/go/plugins/vertexai/vertexai.go index ef60e4677..c13881509 100644 --- a/go/plugins/vertexai/vertexai.go +++ b/go/plugins/vertexai/vertexai.go @@ -239,7 +239,10 @@ func generate( input *ai.GenerateRequest, cb func(context.Context, *ai.GenerateResponseChunk) error, ) (*ai.GenerateResponse, error) { - gm := newModel(client, model, input) + gm, err := newModel(client, model, input) + if err != nil { + return nil, err + } cs, err := startChat(gm, input) if err != nil { return nil, err @@ -307,7 +310,7 @@ func generate( return r, nil } -func newModel(client *genai.Client, model string, input *ai.GenerateRequest) *genai.GenerativeModel { +func newModel(client *genai.Client, model string, input *ai.GenerateRequest) (*genai.GenerativeModel, error) { gm := client.GenerativeModel(model) gm.SetCandidateCount(int32(input.Candidates)) if c, ok := input.Config.(*ai.GenerationCommonConfig); ok && c != nil { @@ -327,7 +330,21 @@ func newModel(client *genai.Client, model string, input *ai.GenerateRequest) *ge gm.SetTopP(float32(c.TopP)) } } - return gm + for _, m := range input.Messages { + systemParts, err := convertParts(m.Content) + if err != nil { + return nil, err + + } + // system prompts go into GenerativeModel.SystemInstruction field. + if m.Role == ai.RoleSystem { + gm.SystemInstruction = &genai.Content{ + Parts: systemParts, + Role: string(m.Role), + } + } + } + return gm, nil } // startChat starts a chat session and configures it with the input messages. @@ -339,6 +356,12 @@ func startChat(gm *genai.GenerativeModel, input *ai.GenerateRequest) (*genai.Cha for len(messages) > 1 { m := messages[0] messages = messages[1:] + + // skip system prompt message, it's handled separately. + if m.Role == ai.RoleSystem { + continue + } + parts, err := convertParts(m.Content) if err != nil { return nil, err