Split textgen and embedding models in UI (#19)
* Updating to pull Models and supplemental info from configs vs API

* Updating README env.js example to include new models array

* Updates to call the describe model API; this allows us to pull in models and check that everything is present

* Add streaming and model type metadata to LiteLLM models

* Add parameter validation to LiteLLM LISA params

---------

Co-authored-by: Peter Muller <[email protected]>
estohlmann and petermuller authored Jun 5, 2024
1 parent b7a4be5 commit bb0d1de
Showing 10 changed files with 121 additions and 15 deletions.
19 changes: 18 additions & 1 deletion README.md
@@ -415,7 +415,24 @@ window.env = {
// Alternatively you can set this to be your REST api elb endpoint
RESTAPI_URI: 'http://localhost:8080/',
RESTAPI_VERSION: 'v2',
SESSION_REST_API_URI: '<API GW session endpoint>'
SESSION_REST_API_URI: '<API GW session endpoint>',
"MODELS": [
{
"model": "streaming-textgen-model",
"streaming": true,
"modelType": "textgen"
},
{
"model": "non-streaming-textgen-model",
"streaming": false,
"modelType": "textgen"
},
{
"model": "embedding-model",
"streaming": null,
"modelType": "embedding"
}
]
}
```

13 changes: 12 additions & 1 deletion example_config.yaml
@@ -237,7 +237,18 @@ dev:
# Anything within this config is copied to a configuration for starting LiteLLM in the REST API container.
# It is suggested to put an "ignored" API key so that LiteLLM's OpenAI-style calls to locally hosted models
# don't fail for lack of a real key.
# We added `lisa_params` to supply additional metadata for interaction with the Chat UI. Specify whether the model
# is a textgen or embedding model; if it is textgen, also specify whether it supports streaming, and if it is an
# embedding model, omit the `streaming` parameter. When defining the model list, `lisa_params` is an object within
# each model definition that contains the `model_type` and `streaming` fields. A commented example is provided below.
litellmConfig:
litellm_settings:
telemetry: false # Don't try to send telemetry to LiteLLM servers.
model_list: [] # Add any of your existing (not LISA-hosted) models here.
model_list: # Add any of your existing (not LISA-hosted) models here.
# - model_name: mymodel
# litellm_params:
# model: openai/myprovider/mymodel
# api_key: ignored
# lisa_params:
# model_type: textgen
# streaming: true
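# An embedding model entry looks the same but omits `streaming`; the model name below is a hypothetical placeholder.
# - model_name: my-embedding-model
# litellm_params:
# model: openai/myprovider/my-embedding-model
# api_key: ignored
# lisa_params:
# model_type: embedding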
25 changes: 24 additions & 1 deletion lib/schema.ts
@@ -677,6 +677,10 @@ const ApiGatewayConfigSchema = z
/**
* Configuration for models inside the LiteLLM Config
* See https://litellm.vercel.app/docs/proxy/configs#all-settings for more details.
*
* The `lisa_params` are custom to the LISA installation and add model metadata so that models can be referenced
* correctly within the Chat UI. LiteLLM ignores these parameters because it does not look for them, so their
* presence will not cause initialization to fail.
*/
const LiteLLMModel = z.object({
model_name: z.string(),
@@ -686,6 +690,25 @@ const LiteLLMModel = z.object({
api_key: z.string().optional(),
aws_region_name: z.string().optional(),
}),
lisa_params: z
.object({
streaming: z.boolean().nullable().default(null),
model_type: z.nativeEnum(ModelType),
})
.refine(
(data) => {
// 'textgen' type must have boolean streaming, 'embedding' type must have null streaming
const isValidForTextgen = data.model_type === 'textgen' && typeof data.streaming === 'boolean';
const isValidForEmbedding = data.model_type === 'embedding' && data.streaming === null;

return isValidForTextgen || isValidForEmbedding;
},
{
message: `For 'textgen' models, 'streaming' must be true or false.
For 'embedding' models, 'streaming' must not be set.`,
path: ['streaming'],
},
),
model_info: z
.object({
id: z.string().optional(),
@@ -704,7 +727,7 @@ const LiteLLMModel = z.object({
*/
const LiteLLMConfig = z.object({
environment_variables: z.map(z.string(), z.string()).optional(),
model_list: z.array(LiteLLMModel).optional(),
model_list: z.array(LiteLLMModel).optional().nullable().default([]),
litellm_settings: z.object({
// ALL (https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py)
telemetry: z.boolean().default(false).optional(),
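To make the refinement's behavior concrete, here is a minimal standalone sketch; the `ModelType` enum and the `LisaParams` name are re-declared for illustration rather than imported from `lib/schema.ts`, and it assumes `zod` is available:

```typescript
// Standalone sketch of the lisa_params validation; ModelType and LisaParams
// are illustrative re-declarations, not the exact exports of lib/schema.ts.
import { z } from 'zod';

enum ModelType {
  TEXTGEN = 'textgen',
  EMBEDDING = 'embedding',
}

const LisaParams = z
  .object({
    streaming: z.boolean().nullable().default(null),
    model_type: z.nativeEnum(ModelType),
  })
  .refine(
    (data) =>
      (data.model_type === ModelType.TEXTGEN && typeof data.streaming === 'boolean') ||
      (data.model_type === ModelType.EMBEDDING && data.streaming === null),
    {
      message: "For 'textgen' models, 'streaming' must be true or false. For 'embedding' models, 'streaming' must not be set.",
      path: ['streaming'],
    },
  );

// Valid: textgen models declare streaming explicitly.
console.log(LisaParams.safeParse({ model_type: ModelType.TEXTGEN, streaming: false }).success); // true

// Valid: embedding models omit streaming, which defaults to null.
console.log(LisaParams.safeParse({ model_type: ModelType.EMBEDDING }).success); // true

// Invalid: an embedding model that declares streaming fails the refinement.
console.log(LisaParams.safeParse({ model_type: ModelType.EMBEDDING, streaming: true }).success); // false
```

Because `streaming` defaults to `null`, embedding entries in the LiteLLM config can simply omit the field and still pass validation.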
19 changes: 19 additions & 0 deletions lib/user-interface/index.ts
@@ -161,6 +161,24 @@ export class UserInterfaceStack extends Stack {
},
);

const ecsModels = config.ecsModels.map((modelConfig) => {
return {
model: modelConfig.modelId,
streaming: modelConfig.streaming,
modelType: modelConfig.modelType,
};
});
const litellmModels = config.litellmConfig.model_list ? config.litellmConfig.model_list : [];
const modelsList = ecsModels.concat(
litellmModels.map((model) => {
return {
model: model.model_name,
streaming: model.lisa_params.streaming,
modelType: model.lisa_params.model_type,
};
}),
);

// Website bucket deployment
// Copy auth and LISA-Serve info to UI deployment bucket
const appEnvConfig = {
@@ -179,6 +197,7 @@
fontColor: config.systemBanner?.fontColor,
},
API_BASE_URL: config.apiGatewayConfig?.domainName ? '/' : `/${config.deploymentStage}/`,
MODELS: modelsList,
};

const appEnvSource = Source.data('env.js', `window.env = ${JSON.stringify(appEnvConfig)}`);
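As a rough sketch of the list this merge produces for `env.js` (the sample entries below are made up; real values come from `config.ecsModels` and `litellmConfig.model_list`):

```typescript
// Illustrative sketch of assembling the MODELS array that is written to env.js;
// the sample model entries are hypothetical.
type ModelTypes = 'textgen' | 'embedding';

interface UiModelEntry {
  model: string;
  streaming: boolean | null;
  modelType: ModelTypes;
}

const ecsModels = [{ modelId: 'my-textgen-model', streaming: true, modelType: 'textgen' as ModelTypes }];
const litellmModelList = [
  { model_name: 'my-embedding-model', lisa_params: { streaming: null, model_type: 'embedding' as ModelTypes } },
];

const modelsList: UiModelEntry[] = [
  ...ecsModels.map((m): UiModelEntry => ({ model: m.modelId, streaming: m.streaming, modelType: m.modelType })),
  ...litellmModelList.map(
    (m): UiModelEntry => ({ model: m.model_name, streaming: m.lisa_params.streaming, modelType: m.lisa_params.model_type }),
  ),
];

// Serialized into env.js as part of window.env, matching the README example above.
console.log(`window.env = ${JSON.stringify({ MODELS: modelsList }, null, 2)}`);
```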
10 changes: 6 additions & 4 deletions lib/user-interface/react/src/components/chatbot/Chat.tsx
@@ -136,10 +136,13 @@ export default function Chat({ sessionId }) {
useEffect(() => {
if (selectedModelOption) {
const model = models.filter((model) => model.id === selectedModelOption.value)[0];
setModelCanStream(true);
if (!model.streaming && model.streaming !== undefined && streamingEnabled) {
setStreamingEnabled(false);
}
setModelCanStream(model.streaming || model.streaming === undefined);
setSelectedModel(model);
}
}, [selectedModelOption, streamingEnabled]);
}, [models, selectedModelOption, streamingEnabled]);

useEffect(() => {
setModelsOptions(models.map((model) => ({ label: model.id, value: model.id })));
@@ -463,8 +466,7 @@ export default function Chat({ sessionId }) {

const describeTextGenModels = useCallback(async () => {
setIsLoadingModels(true);
const resp = await describeModels(auth.user?.id_token);
setModels(resp.data);
setModels(await describeModels(auth.user?.id_token, 'textgen'));
setIsLoadingModels(false);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
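The streaming checks above are terse, so here is a standalone sketch of the same decision; `canStream` and `UiModel` are illustrative names, not identifiers from the codebase:

```typescript
// Sketch of the streaming-capability check used in the effect above;
// canStream and UiModel are illustrative names only.
interface UiModel {
  id: string;
  streaming?: boolean;
}

// A model can stream if it explicitly opts in, or if no streaming metadata was
// found for it (streaming === undefined), in which case streaming stays allowed.
const canStream = (model: UiModel): boolean => model.streaming || model.streaming === undefined;

console.log(canStream({ id: 'streaming-textgen-model', streaming: true })); // true
console.log(canStream({ id: 'non-streaming-textgen-model', streaming: false })); // false
console.log(canStream({ id: 'model-without-metadata' })); // true
```

The effect also switches `streamingEnabled` off whenever the selected model cannot stream, so the toggle never stays on for a non-streaming model.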
@@ -47,8 +47,8 @@ export default function RagControls({ auth, isRunning, setUseRag, setRagConfig }
setIsLoadingEmbeddingModels(true);
setIsLoadingRepositories(true);

describeModels(auth.user?.id_token).then((resp) => {
setEmbeddingModels(resp.data);
describeModels(auth.user?.id_token, 'embedding').then((resp) => {
setEmbeddingModels(resp);
setIsLoadingEmbeddingModels(false);
});

11 changes: 10 additions & 1 deletion lib/user-interface/react/src/components/types.tsx
@@ -110,6 +110,15 @@ export interface Repository {
* Interface for model
*/
export interface Model {
id: string;
modelType: ModelTypes;
streaming?: boolean;
}

/**
* Interface for OpenAIModel that is used for OpenAI Model Interactions
*/
export interface OpenAIModel {
id: string;
object: string;
created: number;
@@ -120,7 +129,7 @@ export interface Model {
* Interface for the response body received when describing a model
*/
export interface DescribeModelsResponseBody {
data: Model[];
data: OpenAIModel[];
}

/**
25 changes: 21 additions & 4 deletions lib/user-interface/react/src/components/utils.ts
@@ -16,12 +16,13 @@

import {
LisaChatSession,
DescribeModelsResponseBody,
LisaChatMessageFields,
PutSessionRequestBody,
LisaChatMessage,
Repository,
ModelTypes,
Model,
DescribeModelsResponseBody,
} from './types';

const stripTrailingSlash = (str) => {
@@ -167,12 +168,28 @@ export const deleteUserSessions = async (idToken: string) => {

/**
* Describes all models of a given type which are available to a user
* @param idToken the user's ID token from authenticating
* @param modelType model type we are requesting
* @returns the models of the requested type that are available to the user
*/
export const describeModels = async (idToken: string): Promise<DescribeModelsResponseBody> => {
export const describeModels = async (idToken: string, modelType: ModelTypes): Promise<Model[]> => {
const resp = await sendAuthenticatedRequest(`${RESTAPI_URI}/${RESTAPI_VERSION}/serve/models`, 'GET', idToken);
return await resp.json();
const modelResponse = (await resp.json()) as DescribeModelsResponseBody;

return modelResponse.data
.filter((openAiModel) => {
const configModelMatch = window.env.MODELS.filter((configModel) => configModel.model === openAiModel.id)[0];
if (!configModelMatch || configModelMatch.modelType === modelType) {
return true;
}
})
.map((openAiModel) => {
const configModelMatch = window.env.MODELS.filter((configModel) => configModel.model === openAiModel.id)[0];
return {
id: openAiModel.id,
streaming: configModelMatch?.streaming,
modelType: configModelMatch?.modelType,
};
});
};

/**
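Here is a self-contained sketch of the narrowing that `describeModels` now performs against the `window.env.MODELS` metadata; all names and sample data are re-declared for illustration:

```typescript
// Self-contained sketch of filtering OpenAI-style model IDs by the model type
// metadata from env.js; names and sample data are illustrative only.
type ModelTypes = 'textgen' | 'embedding';

interface ConfigModel {
  model: string;
  streaming: boolean | null;
  modelType: ModelTypes;
}

interface UiModel {
  id: string;
  streaming?: boolean | null;
  modelType?: ModelTypes;
}

const configModels: ConfigModel[] = [
  { model: 'streaming-textgen-model', streaming: true, modelType: 'textgen' },
  { model: 'embedding-model', streaming: null, modelType: 'embedding' },
];

function filterModels(openAiModelIds: string[], modelType: ModelTypes): UiModel[] {
  return openAiModelIds
    .filter((id) => {
      const match = configModels.find((c) => c.model === id);
      // Models without a metadata entry are kept so they remain selectable in the UI.
      return !match || match.modelType === modelType;
    })
    .map((id) => {
      const match = configModels.find((c) => c.model === id);
      return { id, streaming: match?.streaming, modelType: match?.modelType };
    });
}

console.log(filterModels(['streaming-textgen-model', 'embedding-model'], 'textgen'));
// => [ { id: 'streaming-textgen-model', streaming: true, modelType: 'textgen' } ]
console.log(filterModels(['streaming-textgen-model', 'embedding-model'], 'embedding'));
// => [ { id: 'embedding-model', streaming: null, modelType: 'embedding' } ]
```

Chat.tsx requests `'textgen'` models and RagControls requests `'embedding'` models, which is how the UI split is applied.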
8 changes: 8 additions & 0 deletions lib/user-interface/react/src/main.tsx
@@ -20,6 +20,7 @@ import './index.css';
import AppConfigured from './components/app-configured';

import '@cloudscape-design/global-styles/index.css';
import { ModelTypes } from './components/types';

declare global {
// eslint-disable-next-line @typescript-eslint/consistent-type-definitions
@@ -36,6 +37,13 @@ declare global {
backgroundColor: string;
fontColor: string;
};
MODELS: [
{
model: string;
streaming: boolean | null;
modelType: ModelTypes;
},
];
};
}
}
2 changes: 1 addition & 1 deletion test/cdk/mocks/config.yaml
@@ -237,4 +237,4 @@ dev:
litellmConfig:
litellm_settings:
telemetry: false
model_list: [ ]
model_list:
