diff --git a/README.md b/README.md
index e401fe30..ed293ec1 100644
--- a/README.md
+++ b/README.md
@@ -415,7 +415,24 @@ window.env = {
   // Alternatively you can set this to be your REST api elb endpoint
   RESTAPI_URI: 'http://localhost:8080/',
   RESTAPI_VERSION: 'v2',
-  SESSION_REST_API_URI: ''
+  SESSION_REST_API_URI: '',
+  "MODELS": [
+    {
+      "model": "streaming-textgen-model",
+      "streaming": true,
+      "modelType": "textgen"
+    },
+    {
+      "model": "non-streaming-textgen-model",
+      "streaming": false,
+      "modelType": "textgen"
+    },
+    {
+      "model": "embedding-model",
+      "streaming": null,
+      "modelType": "embedding"
+    }
+  ]
 }
 ```
diff --git a/example_config.yaml b/example_config.yaml
index 975aa23b..e51e54be 100644
--- a/example_config.yaml
+++ b/example_config.yaml
@@ -237,7 +237,18 @@ dev:
   # Anything within this config is copied to a configuration for starting LiteLLM in the REST API container.
   # It is suggested to put an "ignored" API key so that calls to locally hosted models don't fail on OpenAI calls
   # from LiteLLM.
+  # We added `lisa_params` to add additional metadata for interaction with the Chat UI. Specify if the model is a
+  # textgen or embedding model, and if it is textgen, specify whether it supports streaming. If embedding, then
+  # omit the `streaming` parameter. When defining the model list, the `lisa_params` will be an object in the model
+  # definition that will have the `model_type` and `streaming` fields in it. A commented example is provided below.
   litellmConfig:
     litellm_settings:
       telemetry: false # Don't try to send telemetry to LiteLLM servers.
-    model_list: [] # Add any of your existing (not LISA-hosted) models here.
+    model_list: # Add any of your existing (not LISA-hosted) models here.
+#      - model_name: mymodel
+#        litellm_params:
+#          model: openai/myprovider/mymodel
+#          api_key: ignored
+#        lisa_params:
+#          model_type: textgen
+#          streaming: true
diff --git a/lib/schema.ts b/lib/schema.ts
index f7ccf7c5..7078e718 100644
--- a/lib/schema.ts
+++ b/lib/schema.ts
@@ -677,6 +677,10 @@ const ApiGatewayConfigSchema = z
 /**
  * Configuration for models inside the LiteLLM Config
  * See https://litellm.vercel.app/docs/proxy/configs#all-settings for more details.
+ *
+ * The `lisa_params` are custom for the LISA installation to add model metadata to allow the models to be referenced
+ * correctly within the Chat UI. LiteLLM will ignore these parameters as it is not looking for them, and it will not
+ * fail to initialize as a result of them existing.
  */
 const LiteLLMModel = z.object({
   model_name: z.string(),
@@ -686,6 +690,25 @@ const LiteLLMModel = z.object({
     api_key: z.string().optional(),
     aws_region_name: z.string().optional(),
   }),
+  lisa_params: z
+    .object({
+      streaming: z.boolean().nullable().default(null),
+      model_type: z.nativeEnum(ModelType),
+    })
+    .refine(
+      (data) => {
+        // 'textgen' type must have boolean streaming, 'embedding' type must have null streaming
+        const isValidForTextgen = data.model_type === 'textgen' && typeof data.streaming === 'boolean';
+        const isValidForEmbedding = data.model_type === 'embedding' && data.streaming === null;
+
+        return isValidForTextgen || isValidForEmbedding;
+      },
+      {
+        message: `For 'textgen' models, 'streaming' must be true or false.
+            For 'embedding' models, 'streaming' must not be set.`,
+        path: ['streaming'],
+      },
+    ),
   model_info: z
     .object({
       id: z.string().optional(),
@@ -704,7 +727,7 @@ const LiteLLMModel = z.object({
  */
 const LiteLLMConfig = z.object({
   environment_variables: z.map(z.string(), z.string()).optional(),
-  model_list: z.array(LiteLLMModel).optional(),
+  model_list: z.array(LiteLLMModel).optional().nullable().default([]),
   litellm_settings: z.object({
     // ALL (https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py)
     telemetry: z.boolean().default(false).optional(),
diff --git a/lib/user-interface/index.ts b/lib/user-interface/index.ts
index b7d191c5..2174022a 100644
--- a/lib/user-interface/index.ts
+++ b/lib/user-interface/index.ts
@@ -161,6 +161,24 @@ export class UserInterfaceStack extends Stack {
       },
     );
 
+    const ecsModels = config.ecsModels.map((modelConfig) => {
+      return {
+        model: modelConfig.modelId,
+        streaming: modelConfig.streaming,
+        modelType: modelConfig.modelType,
+      };
+    });
+    const litellmModels = config.litellmConfig.model_list ? config.litellmConfig.model_list : [];
+    const modelsList = ecsModels.concat(
+      litellmModels.map((model) => {
+        return {
+          model: model.model_name,
+          streaming: model.lisa_params.streaming,
+          modelType: model.lisa_params.model_type,
+        };
+      }),
+    );
+
     // Website bucket deployment
     // Copy auth and LISA-Serve info to UI deployment bucket
     const appEnvConfig = {
@@ -179,6 +197,7 @@ export class UserInterfaceStack extends Stack {
         fontColor: config.systemBanner?.fontColor,
       },
       API_BASE_URL: config.apiGatewayConfig?.domainName ? '/' : `/${config.deploymentStage}/`,
+      MODELS: modelsList,
     };
 
     const appEnvSource = Source.data('env.js', `window.env = ${JSON.stringify(appEnvConfig)}`);
diff --git a/lib/user-interface/react/src/components/chatbot/Chat.tsx b/lib/user-interface/react/src/components/chatbot/Chat.tsx
index dfbe89bb..3eb52d37 100644
--- a/lib/user-interface/react/src/components/chatbot/Chat.tsx
+++ b/lib/user-interface/react/src/components/chatbot/Chat.tsx
@@ -136,10 +136,13 @@ export default function Chat({ sessionId }) {
   useEffect(() => {
     if (selectedModelOption) {
       const model = models.filter((model) => model.id === selectedModelOption.value)[0];
-      setModelCanStream(true);
+      if (!model.streaming && model.streaming !== undefined && streamingEnabled) {
+        setStreamingEnabled(false);
+      }
+      setModelCanStream(model.streaming || model.streaming === undefined);
       setSelectedModel(model);
     }
-  }, [selectedModelOption, streamingEnabled]);
+  }, [models, selectedModelOption, streamingEnabled]);
 
   useEffect(() => {
     setModelsOptions(models.map((model) => ({ label: model.id, value: model.id })));
@@ -463,8 +466,7 @@ export default function Chat({ sessionId }) {
 
   const describeTextGenModels = useCallback(async () => {
     setIsLoadingModels(true);
-    const resp = await describeModels(auth.user?.id_token);
-    setModels(resp.data);
+    setModels(await describeModels(auth.user?.id_token, 'textgen'));
     setIsLoadingModels(false);
     // eslint-disable-next-line react-hooks/exhaustive-deps
   }, []);
diff --git a/lib/user-interface/react/src/components/chatbot/RagOptions.tsx b/lib/user-interface/react/src/components/chatbot/RagOptions.tsx
index 074693fb..df68cf75 100644
--- a/lib/user-interface/react/src/components/chatbot/RagOptions.tsx
+++ b/lib/user-interface/react/src/components/chatbot/RagOptions.tsx
@@ -47,8 +47,8 @@ export default function RagControls({ auth, isRunning, setUseRag, setRagConfig }
     setIsLoadingEmbeddingModels(true);
     setIsLoadingRepositories(true);
 
-    describeModels(auth.user?.id_token).then((resp) => {
-      setEmbeddingModels(resp.data);
+    describeModels(auth.user?.id_token, 'embedding').then((resp) => {
+      setEmbeddingModels(resp);
       setIsLoadingEmbeddingModels(false);
     });
diff --git a/lib/user-interface/react/src/components/types.tsx b/lib/user-interface/react/src/components/types.tsx
index ab91b866..979ac513 100644
--- a/lib/user-interface/react/src/components/types.tsx
+++ b/lib/user-interface/react/src/components/types.tsx
@@ -110,6 +110,15 @@ export interface Repository {
 /**
  * Interface for model
  */
 export interface Model {
+  id: string;
+  modelType: ModelTypes;
+  streaming?: boolean;
+}
+
+/**
+ * Interface for OpenAIModel that is used for OpenAI Model Interactions
+ */
+export interface OpenAIModel {
   id: string;
   object: string;
   created: number;
@@ -120,7 +129,7 @@ export interface Model {
  * Interface for the response body received when describing a model
  */
 export interface DescribeModelsResponseBody {
-  data: Model[];
+  data: OpenAIModel[];
 }
 
 /**
diff --git a/lib/user-interface/react/src/components/utils.ts b/lib/user-interface/react/src/components/utils.ts
index 078f159b..dd7f7ebe 100644
--- a/lib/user-interface/react/src/components/utils.ts
+++ b/lib/user-interface/react/src/components/utils.ts
@@ -16,12 +16,13 @@
 import {
   LisaChatSession,
-  DescribeModelsResponseBody,
   LisaChatMessageFields,
   PutSessionRequestBody,
   LisaChatMessage,
   Repository,
+  ModelTypes,
   Model,
+  DescribeModelsResponseBody,
 } from './types';
 
 const stripTrailingSlash = (str) => {
@@ -167,12 +168,28 @@ export const deleteUserSessions = async (idToken: string) => {
 /**
  * Describes all models of a given type which are available to a user
- * @param idToken the user's ID token from authenticating
+ * @param modelType model type we are requesting
  * @returns
  */
-export const describeModels = async (idToken: string): Promise<DescribeModelsResponseBody> => {
+export const describeModels = async (idToken: string, modelType: ModelTypes): Promise<Model[]> => {
   const resp = await sendAuthenticatedRequest(`${RESTAPI_URI}/${RESTAPI_VERSION}/serve/models`, 'GET', idToken);
-  return await resp.json();
+  const modelResponse = (await resp.json()) as DescribeModelsResponseBody;
+
+  return modelResponse.data
+    .filter((openAiModel) => {
+      const configModelMatch = window.env.MODELS.filter((configModel) => configModel.model === openAiModel.id)[0];
+      if (!configModelMatch || configModelMatch.modelType === modelType) {
+        return true;
+      }
+    })
+    .map((openAiModel) => {
+      const configModelMatch = window.env.MODELS.filter((configModel) => configModel.model === openAiModel.id)[0];
+      return {
+        id: openAiModel.id,
+        streaming: configModelMatch?.streaming,
+        modelType: configModelMatch?.modelType,
+      };
+    });
 };
 
 /**
diff --git a/lib/user-interface/react/src/main.tsx b/lib/user-interface/react/src/main.tsx
index 6dbbcfa9..019e8e56 100644
--- a/lib/user-interface/react/src/main.tsx
+++ b/lib/user-interface/react/src/main.tsx
@@ -20,6 +20,7 @@ import './index.css';
 import AppConfigured from './components/app-configured';
 
 import '@cloudscape-design/global-styles/index.css';
+import { ModelTypes } from './components/types';
 
 declare global {
   // eslint-disable-next-line @typescript-eslint/consistent-type-definitions
@@ -36,6 +37,13 @@ declare global {
         backgroundColor: string;
         fontColor: string;
       };
+      MODELS: [
+        {
+          model: string;
+          streaming: boolean | null;
+          modelType: ModelTypes;
+        },
+      ];
     };
   }
 }
diff --git a/test/cdk/mocks/config.yaml b/test/cdk/mocks/config.yaml
index 23e7ad6f..80846cf7 100644
--- a/test/cdk/mocks/config.yaml
+++ b/test/cdk/mocks/config.yaml
@@ -237,4 +237,4 @@ dev:
   litellmConfig:
     litellm_settings:
       telemetry: false
-    model_list: [ ]
+    model_list:
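
Note on the `lisa_params` validation added in `lib/schema.ts`: the refinement accepts a boolean `streaming` flag only for `textgen` models and requires `streaming` to stay unset (defaulting to `null`) for `embedding` models. A minimal standalone sketch of that rule, using a plain `z.enum` in place of the repository's `z.nativeEnum(ModelType)`, behaves as follows:

```typescript
import { z } from 'zod';

// Standalone sketch of the lisa_params refinement; z.enum stands in for the
// repository's z.nativeEnum(ModelType).
const LisaParams = z
  .object({
    streaming: z.boolean().nullable().default(null),
    model_type: z.enum(['textgen', 'embedding']),
  })
  .refine(
    (data) =>
      (data.model_type === 'textgen' && typeof data.streaming === 'boolean') ||
      (data.model_type === 'embedding' && data.streaming === null),
    {
      message: "For 'textgen' models, 'streaming' must be true or false; for 'embedding' models it must not be set.",
      path: ['streaming'],
    },
  );

console.log(LisaParams.safeParse({ model_type: 'textgen', streaming: true }).success); // true
console.log(LisaParams.safeParse({ model_type: 'embedding' }).success); // true: streaming defaults to null
console.log(LisaParams.safeParse({ model_type: 'textgen' }).success); // false: streaming defaulted to null, not a boolean
```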
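
Related note on the `model_list` default: `example_config.yaml` and `test/cdk/mocks/config.yaml` now leave the key empty (`model_list:`), which YAML parses as `null` rather than an empty list. The schema therefore allows `null` via `.optional().nullable().default([])`, and `lib/user-interface/index.ts` still falls back to `[]` before concatenating, because a zod `.default()` only applies when the key is missing (`undefined`), not when it is `null`. A small sketch of that behavior, assuming `js-yaml` for parsing:

```typescript
import { z } from 'zod';
import * as yaml from 'js-yaml';

// Reduced schema mirroring the model_list declaration from lib/schema.ts.
const LitellmConfig = z.object({
  model_list: z.array(z.object({ model_name: z.string() })).optional().nullable().default([]),
});

// A bare "model_list:" key parses to null, not [] ...
const parsed = LitellmConfig.parse(yaml.load('model_list:\n'));
console.log(parsed.model_list); // null -- .default([]) only fires when the key is absent

// ... which is why the CDK stack still guards with an empty array before concatenating.
const litellmModels = parsed.model_list ? parsed.model_list : [];
console.log(litellmModels.length); // 0
```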