REST API implementation for VertexAI Gemini models. #345

Open · wants to merge 1 commit into base: main
2 changes: 2 additions & 0 deletions langfun/core/llms/__init__.py
@@ -120,6 +120,8 @@
 from langfun.core.llms.groq import GroqWhisper_Large_v3Turbo

 from langfun.core.llms.vertexai import VertexAI
+from langfun.core.llms.vertexai import VertexAIRest
+from langfun.core.llms.vertexai import VertexAIRestGemini1_5
 from langfun.core.llms.vertexai import VertexAIGemini1_5
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5
 from langfun.core.llms.vertexai import VertexAIGeminiPro1_5_Latest
2 changes: 1 addition & 1 deletion langfun/core/llms/rest.py
@@ -26,7 +26,7 @@ class REST(lf.LanguageModel):
   api_endpoint: Annotated[
       str,
       'The endpoint of the REST API.'
-  ]
+  ] = ''

   request: Annotated[
       Callable[[lf.Message, lf.LMSamplingOptions], dict[str, Any]],
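
Note: `api_endpoint` previously had no default, so every `REST` subclass had to pass it at construction time. Defaulting it to `''` lets a subclass compute the endpoint from its other fields instead, which is what `VertexAIRest` does below by overriding `api_endpoint` as a property. A minimal sketch of the pattern (the subclass and URL here are hypothetical, not part of this PR):

# Hypothetical REST subclass deriving its endpoint from another field.
class _MyRestModel(rest.REST):
  region: str = 'us-central1'

  @property
  def api_endpoint(self) -> str:
    # Computed per instance; the base field's '' default is never read.
    return f'https://{self.region}.example.com/v1/models/demo:generate'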
261 changes: 255 additions & 6 deletions langfun/core/llms/vertexai.py
@@ -13,17 +13,21 @@
 # limitations under the License.
 """Vertex AI generative models."""

+import base64
 import functools
 import os
 from typing import Annotated, Any

 import langfun.core as lf
 from langfun.core import modalities as lf_modalities
+from langfun.core.llms import rest
 import pyglove as pg

 try:
   # pylint: disable=g-import-not-at-top
+  from google import auth as google_auth
   from google.auth import credentials as credentials_lib
+  from google.auth.transport import requests as auth_requests
   import vertexai
   from google.cloud.aiplatform import models as aiplatform_models
   from vertexai import generative_models
@@ -32,6 +36,8 @@

   Credentials = credentials_lib.Credentials
 except ImportError:
+  google_auth = None
+  auth_requests = None
   credentials_lib = None  # pylint: disable=invalid-name
   vertexai = None
   generative_models = None
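
Note: the guarded imports keep `google-auth` optional at import time. The names fall back to `None`, and `VertexAIRest._on_bound()` below raises an install hint only when a REST-backed model is actually constructed. The same guard in isolation (standalone sketch; `_check_google_auth` is a hypothetical helper mirroring the `_on_bound` check):

try:
  from google import auth as google_auth  # pylint: disable=g-import-not-at-top
except ImportError:
  google_auth = None  # Degrade gracefully; verified before first use.

def _check_google_auth():
  if google_auth is None:
    raise ValueError(
        'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
    )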
@@ -449,6 +455,241 @@ def _sample_endpoint_model(self, prompt: lf.Message) -> lf.LMSamplingResult:
    ])


@lf.use_init_args(['model'])
class VertexAIRest(rest.REST):
  """Language model served on VertexAI with REST API."""

  model: pg.typing.Annotated[
      pg.typing.Enum(
          pg.MISSING_VALUE, list(SUPPORTED_MODELS_AND_SETTINGS.keys())
      ),
      (
          'Vertex AI model name with REST API support. See '
          'https://cloud.google.com/vertex-ai/generative-ai/docs/'
          'model-reference/inference#supported-models'
          ' for details.'
      ),
  ]

  project: Annotated[
      str | None,
      (
          'Vertex AI project ID. Or set from environment variable '
          'VERTEXAI_PROJECT.'
      ),
  ] = None

  location: Annotated[
      str | None,
      (
          'Vertex AI service location. Or set from environment variable '
          'VERTEXAI_LOCATION.'
      ),
  ] = None

  credentials: Annotated[
      Credentials | None,
      (
          'Credentials to use. If None, the default credentials to the '
          'environment will be used.'
      ),
  ] = None

  supported_modalities: Annotated[
      list[str],
      'A list of MIME types for supported modalities'
  ] = []

  def _on_bound(self):
    super()._on_bound()
    if google_auth is None:
      raise ValueError(
          'Please install "langfun[llm-google-vertex]" to use Vertex AI models.'
      )
    self._project = None
    self._credentials = None

  def _initialize(self):
    project = self.project or os.environ.get('VERTEXAI_PROJECT', None)
    if not project:
      raise ValueError(
          'Please specify `project` during `__init__` or set environment '
          'variable `VERTEXAI_PROJECT` with your Vertex AI project ID.'
      )

    location = self.location or os.environ.get('VERTEXAI_LOCATION', None)
    if not location:
      raise ValueError(
          'Please specify `location` during `__init__` or set environment '
          'variable `VERTEXAI_LOCATION` with your Vertex AI service location.'
      )

    self._project = project
    credentials = self.credentials
    if credentials is None:
      # Use default credentials. google_auth.default() returns a
      # (credentials, project_id) tuple; only the credentials are needed.
      credentials, _ = google_auth.default(
          scopes=['https://www.googleapis.com/auth/cloud-platform']
      )
    self._credentials = credentials

  @property
  def max_concurrency(self) -> int:
    """Returns the maximum number of concurrent requests."""
    return self.rate_to_max_concurrency(
        requests_per_min=SUPPORTED_MODELS_AND_SETTINGS[self.model].rpm,
        tokens_per_min=0,
    )

  def estimate_cost(
      self,
      num_input_tokens: int,
      num_output_tokens: int
  ) -> float | None:
    """Estimate the cost based on usage."""
    cost_per_1k_input_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
        'cost_per_1k_input_chars', None
    )
    cost_per_1k_output_chars = SUPPORTED_MODELS_AND_SETTINGS[self.model].get(
        'cost_per_1k_output_chars', None
    )
    if cost_per_1k_output_chars is None or cost_per_1k_input_chars is None:
      return None
    # Pricing is per 1K characters; token counts are converted to characters
    # via the average characters-per-token constant.
    return (
        cost_per_1k_input_chars * num_input_tokens
        + cost_per_1k_output_chars * num_output_tokens
    ) * AVGERAGE_CHARS_PER_TOEKN / 1000

  @functools.cached_property
  def _session(self):
    assert self._api_initialized
    assert self._credentials is not None
    assert auth_requests is not None
    s = auth_requests.AuthorizedSession(self._credentials)
    s.headers.update(self.headers or {})
    return s

  @property
  def headers(self):
    return {
        'Content-Type': 'application/json; charset=utf-8',
    }

  @property
  def api_endpoint(self) -> str:
    return (
        f'https://{self.location}-aiplatform.googleapis.com/v1/projects/'
        f'{self.project}/locations/{self.location}/publishers/google/'
        f'models/{self.model}:generateContent'
    )

  def request(
      self, prompt: lf.Message, sampling_options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    request = dict(
        generationConfig=self._generation_config(prompt, sampling_options)
    )
    request['contents'] = [self._content_from_message(prompt)]
    return request

  def _generation_config(
      self, prompt: lf.Message, options: lf.LMSamplingOptions
  ) -> dict[str, Any]:
    """Returns a dict as generation config for prompt and LMSamplingOptions."""
    config = dict(
        temperature=options.temperature,
        maxOutputTokens=options.max_tokens,
        candidateCount=options.n,
        topK=options.top_k,
        topP=options.top_p,
        stopSequences=options.stop,
        seed=options.random_seed,
        responseLogprobs=options.logprobs,
        logprobs=options.top_logprobs,
    )

    if json_schema := prompt.metadata.get('json_schema'):
      if not isinstance(json_schema, dict):
        raise ValueError(
            f'`json_schema` must be a dict, got {json_schema!r}.'
        )
      json_schema = pg.to_json(json_schema)
      config['responseSchema'] = json_schema
      config['responseMimeType'] = 'application/json'
      prompt.metadata.formatted_text = (
          prompt.text
          + '\n\n [RESPONSE FORMAT (not part of prompt)]\n'
          + pg.to_json_str(json_schema, json_indent=2)
      )
    return config

  def _content_from_message(self, prompt: lf.Message) -> dict[str, Any]:
    """Gets generation content from langfun message."""
    parts = []
    for lf_chunk in prompt.chunk():
      if isinstance(lf_chunk, str):
        parts.append({'text': lf_chunk})
      elif isinstance(lf_chunk, lf_modalities.Mime):
        try:
          modalities = lf_chunk.make_compatible(
              self.supported_modalities + ['text/plain']
          )
          if isinstance(modalities, lf_modalities.Mime):
            modalities = [modalities]
          for modality in modalities:
            if modality.is_text:
              parts.append({'text': modality.to_text()})
            else:
              parts.append({
                  'inlineData': {
                      'data': base64.b64encode(modality.to_bytes()).decode(),
                      'mimeType': modality.mime_type,
                  }
              })
        except lf.ModalityError as e:
          raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}') from e
      else:
        raise lf.ModalityError(f'Unsupported modality: {lf_chunk!r}')
    return dict(role='user', parts=parts)

  def result(self, json: dict[str, Any]) -> lf.LMSamplingResult:
    messages = [
        self._message_from_content_parts(candidate['content']['parts'])
        for candidate in json['candidates']
    ]
    usage = json['usageMetadata']
    input_tokens = usage['promptTokenCount']
    output_tokens = usage['candidatesTokenCount']
    return lf.LMSamplingResult(
        [lf.LMSample(message) for message in messages],
        usage=lf.LMSamplingUsage(
            prompt_tokens=input_tokens,
            completion_tokens=output_tokens,
            total_tokens=input_tokens + output_tokens,
            estimated_cost=self.estimate_cost(
                num_input_tokens=input_tokens,
                num_output_tokens=output_tokens,
            ),
        ),
    )

  def _message_from_content_parts(
      self, parts: list[dict[str, Any]]
  ) -> lf.Message:
    """Converts Vertex AI's content parts protocol to message."""
    chunks = []
    for part in parts:
      if text_part := part.get('text'):
        chunks.append(text_part)
      elif inline_part := part.get('inlineData'):
        chunks.append(
            lf_modalities.Mime(inline_part['data'], inline_part['mimeType'])
        )
      else:
        raise ValueError(f'Unsupported part: {part}')
    return lf.AIMessage.from_chunks(chunks)
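
Note: for a text-only prompt, `request()` above assembles a `generateContent` body whose shape is sketched below (illustrative values only; key names follow `_generation_config()` and `_content_from_message()`):

# Illustrative sketch of the payload request() returns for prompt 'Hello'
# with candidateCount=1 and temperature=0.7 (other generationConfig keys
# such as topK/topP/stopSequences are omitted here for brevity).
example_request = {
    'generationConfig': {
        'temperature': 0.7,
        'maxOutputTokens': 1024,
        'candidateCount': 1,
    },
    'contents': [
        {'role': 'user', 'parts': [{'text': 'Hello'}]},
    ],
}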


class _ModelHub:
  """Vertex AI model hub."""

@@ -547,13 +788,21 @@ class VertexAIGeminiPro1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   model = 'gemini-1.5-pro'


-class VertexAIGeminiPro1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIRestGemini1_5(VertexAIRest):  # pylint: disable=invalid-name
+  """Vertex AI Gemini 1.5 model with REST API."""
+
+  supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
+      _DOCUMENT_TYPES + _IMAGE_TYPES + _AUDIO_TYPES + _VIDEO_TYPES
+  )
+
+
+class VertexAIGeminiPro1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""

   model = 'gemini-1.5-pro-002'


-class VertexAIGeminiPro1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiPro1_5_001(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Pro model."""

   model = 'gemini-1.5-pro-001'
@@ -583,13 +832,13 @@ class VertexAIGeminiFlash1_5(VertexAIGemini1_5):  # pylint: disable=invalid-name
   model = 'gemini-1.5-flash'


-class VertexAIGeminiFlash1_5_002(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiFlash1_5_002(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""

   model = 'gemini-1.5-flash-002'


-class VertexAIGeminiFlash1_5_001(VertexAIGemini1_5):  # pylint: disable=invalid-name
+class VertexAIGeminiFlash1_5_001(VertexAIRestGemini1_5):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.5 Flash model."""

   model = 'gemini-1.5-flash-001'
@@ -601,14 +850,14 @@ class VertexAIGeminiFlash1_5_0514(VertexAIGemini1_5):  # pylint: disable=invalid-name
   model = 'gemini-1.5-flash-preview-0514'


-class VertexAIGeminiPro1(VertexAI):  # pylint: disable=invalid-name
+class VertexAIGeminiPro1(VertexAIRest):  # pylint: disable=invalid-name
   """Vertex AI Gemini 1.0 Pro model."""

   model = 'gemini-1.0-pro'


 class VertexAIGeminiPro1Vision(VertexAI):  # pylint: disable=invalid-name
-  """Vertex AI Gemini 1.0 Pro model."""
+  """Vertex AI Gemini 1.0 Pro Vision model."""

   model = 'gemini-1.0-pro-vision'
   supported_modalities: pg.typing.List(str).freeze(  # pytype: disable=invalid-annotation
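
Usage-wise, the rewired classes are drop-in replacements for their SDK-backed predecessors. A quick smoke-test sketch (the project and location values are placeholders; requires `langfun[llm-google-vertex]` and application-default credentials with the cloud-platform scope):

from langfun.core import llms

lm = llms.VertexAIGeminiPro1_5_002(
    project='my-gcp-project',   # or set env var VERTEXAI_PROJECT
    location='us-central1',     # or set env var VERTEXAI_LOCATION
)
results = lm.sample(['What is 1 + 1?'])
print(results[0].samples[0].response.text)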