feat(go/plugins/vertexai): add context caching to vertexai #1478

Open · wants to merge 5 commits into base: main
Changes from 3 commits
227 changes: 227 additions & 0 deletions go/plugins/vertexai/cache.go
Collaborator:

The Go style guide uses tabs (rendered as 4 spaces) instead of 2 spaces. Please update here and in the other PR.
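(For reference, gofmt indents with tabs; running it with the -w flag over the plugin directory, e.g. gofmt -w go/plugins/vertexai, would apply this automatically.)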

@@ -0,0 +1,227 @@
// Copyright 2024 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package vertexai

import (
"context"
"crypto/sha256"
"encoding/hex"
"fmt"
"strings"
"time"

"cloud.google.com/go/vertexai/genai"
"github.com/firebase/genkit/go/ai"
)

// CacheConfigDetails holds configuration details for caching.
// Adjust fields as needed for your use case.
type CacheConfigDetails struct {
	// TTLSeconds is how long to keep the cached content.
	// If zero, defaults to 60 minutes.
	TTLSeconds int
}

var (
	INVALID_ARGUMENT_MESSAGES = struct {
		modelVersion string
		tools        string
	}{
		modelVersion: "Invalid modelVersion specified.",
		tools:        "Tools are not supported with context caching.",
	}
)

// getContentForCache inspects the request and modelVersion, and constructs a
// genai.CachedContent that should be cached.
// This is where you decide what goes into the cache: large documents, system instructions, etc.
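// Only the first user message in the request is included in the cached contents.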
func getContentForCache(
	request *ai.ModelRequest,
	modelVersion string,
	cacheConfig *CacheConfigDetails,
) (*genai.CachedContent, error) {
	var systemInstruction string
	var userParts []*genai.Content

	for _, m := range request.Messages {
		if m.Role == ai.RoleSystem {
			sysParts := []string{}
			for _, p := range m.Content {
				if p.IsText() {
					sysParts = append(sysParts, p.Text)
				}
			}
			if len(sysParts) > 0 {
				systemInstruction = strings.Join(sysParts, "\n")
			}
		}
	}

	if len(request.Messages) > 0 {
		for _, m := range request.Messages {
			if m.Role == ai.RoleUser {
				parts, err := convertParts(m.Content)
				if err != nil {
					return nil, err
				}
				userParts = append(userParts, &genai.Content{
					Role:  "user",
					Parts: parts,
				})
				break
			}
		}
	}

	if systemInstruction == "" && len(userParts) == 0 {
		return nil, fmt.Errorf("no content to cache")
	}

	content := &genai.CachedContent{
		Model: modelVersion,
		SystemInstruction: &genai.Content{
			Role:  "system",
			Parts: []genai.Part{genai.Text(systemInstruction)},
		},
		Contents: userParts,
	}

	return content, nil
}

// generateCacheKey creates a deterministic key for the cached content by hashing
// the system instruction, the model version, and the user content parts.
func generateCacheKey(content *genai.CachedContent) string {
	hash := sha256.New()
	if content.SystemInstruction != nil {
		for _, p := range content.SystemInstruction.Parts {
			if t, ok := p.(genai.Text); ok {
				hash.Write([]byte(t))
			}
		}
	}
	hash.Write([]byte(content.Model))

	// Also incorporate any user content parts to ensure uniqueness
	for _, c := range content.Contents {
		for _, p := range c.Parts {
			switch v := p.(type) {
			case genai.Text:
				hash.Write([]byte(v))
			case genai.Blob:
				hash.Write([]byte(v.MIMEType))
				hash.Write(v.Data)
			}
		}
	}

	return hex.EncodeToString(hash.Sum(nil))
}

// calculateTTL returns the TTL as a time.Duration.
func calculateTTL(cacheConfig *CacheConfigDetails) time.Duration {
	if cacheConfig == nil || cacheConfig.TTLSeconds <= 0 {
		return 60 * time.Minute
	}
	return time.Duration(cacheConfig.TTLSeconds) * time.Second
}

// getKeysFrom returns the keys of the given map as a slice of strings; it is used to list the supported models.
func getKeysFrom(m map[string]ai.ModelCapabilities) []string {
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	return keys
}

// contains checks if a slice contains a given string.
func contains(slice []string, target string) bool {
	for _, s := range slice {
		if s == target {
			return true
		}
	}
	return false
}

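// countTokensInMessages approximates the token count by counting whitespace-separated
// words in the request's text parts; it is a rough heuristic, not the model's tokenizer.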
func countTokensInMessages(messages []*ai.Message) int {
	totalTokens := 0
	for _, msg := range messages {
		for _, part := range msg.Content {
			if part.IsText() {
				words := strings.Fields(part.Text)
				totalTokens += len(words)
			}
		}
	}
	return totalTokens
}

// validateContextCacheRequest decides if we should try caching for this request.
// It checks that the model version supports caching, that the request has no tools,
// and that the content meets the minimum token count for caching.
func validateContextCacheRequest(request *ai.ModelRequest, modelVersion string) error {
	models := getKeysFrom(knownCaps)
	if modelVersion == "" || !contains(models, modelVersion) {
		return fmt.Errorf(INVALID_ARGUMENT_MESSAGES.modelVersion)
	}
	if len(request.Tools) > 0 {
		return fmt.Errorf(INVALID_ARGUMENT_MESSAGES.tools)
	}

	tokenCount := countTokensInMessages(request.Messages)
	// The minimum input token count for context caching is 32,768, and the maximum
	// is the same as the maximum for the given model.
	// https://ai.google.dev/gemini-api/docs/caching?lang=go
	const minTokens = 32768
	if tokenCount < minTokens {
		return fmt.Errorf("the cached content is of %d tokens. The minimum token count to start caching is %d.", tokenCount, minTokens)
	}

	// If we reach here, request is valid for context caching
	return nil
}

// handleCacheIfNeeded checks if caching should be used, attempts to find or create the cache,
// and returns the cached content if applicable.
func handleCacheIfNeeded(
	ctx context.Context,
	client *genai.Client,
	request *ai.ModelRequest,
	modelVersion string,
	cacheConfig *CacheConfigDetails,
) (*genai.CachedContent, error) {

	if cacheConfig == nil || validateContextCacheRequest(request, modelVersion) != nil {
		return nil, nil
	}
	cachedContent, err := getContentForCache(request, modelVersion, cacheConfig)
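	// If cacheable content cannot be assembled, fall back to an uncached
	// request instead of failing the generation.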
	if err != nil {
		return nil, nil
	}

	cachedContent.Model = modelVersion
	cacheKey := generateCacheKey(cachedContent)

	cachedContent.Expiration = genai.ExpireTimeOrTTL{TTL: calculateTTL(cacheConfig)}
	cachedContent.Name = cacheKey
	newCache, err := client.CreateCachedContent(ctx, cachedContent)
	if err != nil {
		return nil, fmt.Errorf("failed to create cache: %w", err)
	}

	return newCache, nil
}
19 changes: 16 additions & 3 deletions go/plugins/vertexai/vertexai.go
@@ -39,9 +39,9 @@ const (

var (
	knownCaps = map[string]ai.ModelCapabilities{
-		"gemini-1.0-pro":   gemini.BasicText,
-		"gemini-1.5-pro":   gemini.Multimodal,
-		"gemini-1.5-flash": gemini.Multimodal,
+		"gemini-1.0-pro":       gemini.BasicText,
+		"gemini-1.5-pro":       gemini.Multimodal,
+		"gemini-1.5-flash-002": gemini.Multimodal,
Contributor Author:

Context caching works only with stable Gemini model versions, which use a numeric version suffix (e.g. -002); see https://cloud.google.com/vertex-ai/generative-ai/docs/learn/model-versions#stable-versions-available
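For illustration only (hypothetical, not part of this PR; the helper name and regexp are assumptions): caching could be gated on stable versions with a suffix check while the plain aliases stay registered:

package vertexai

import "regexp"

// stableVersionSuffix matches model names ending in a numeric version suffix
// such as "-002", which stable Gemini releases use.
var stableVersionSuffix = regexp.MustCompile(`-\d+$`)

// isStableGeminiVersion reports whether context caching should be attempted
// for the given model version. Sketch for discussion only.
func isStableGeminiVersion(modelVersion string) bool {
	return stableVersionSuffix.MatchString(modelVersion)
}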

Collaborator:

We don't want to get rid of the plain gemini-1.5-flash. @hugoaguirre is working on adding model version support per model family; please work with him on how to combine the two efforts.

	}

	knownEmbedders = []string{
@@ -238,10 +238,23 @@ func generate(
	input *ai.ModelRequest,
	cb func(context.Context, *ai.ModelResponseChunk) error,
) (*ai.ModelResponse, error) {
+	cacheConfig := &CacheConfigDetails{
+		TTLSeconds: 3600, // hardcoded to 1 hour
+	}
+
+	// Attempt to handle caching before creating the model.
+	cache, err := handleCacheIfNeeded(ctx, client, input, model, cacheConfig)
+	if err != nil {
+		return nil, err
+	}
+
	gm, err := newModel(client, model, input)
	if err != nil {
		return nil, err
	}
+	if cache != nil {
+		gm.CachedContentName = cache.Name
+	}
	cs, err := startChat(gm, input)
	if err != nil {
		return nil, err