Split textgen and embedding models in UI #19

Merged · 5 commits · Jun 5, 2024
README.md (18 additions, 1 deletion)

@@ -415,7 +415,24 @@ window.env = {
// Alternatively you can set this to be your REST api elb endpoint
RESTAPI_URI: 'http://localhost:8080/',
RESTAPI_VERSION: 'v2',
SESSION_REST_API_URI: '<API GW session endpoint>'
SESSION_REST_API_URI: '<API GW session endpoint>',
"MODELS": [
{
"model": "streaming-textgen-model",
"streaming": true,
"modelType": "textgen"
},
{
"model": "non-streaming-textgen-model",
"streaming": false,
"modelType": "textgen"
},
{
"model": "embedding-model",
"streaming": null,
"modelType": "embedding"
}
]
}
```

example_config.yaml (12 additions, 1 deletion)

@@ -237,7 +237,18 @@ dev:
# Anything within this config is copied to a configuration for starting LiteLLM in the REST API container.
# It is suggested to put an "ignored" API key so that calls to locally hosted models don't fail on OpenAI calls
# from LiteLLM.
# We added `lisa_params` to carry additional metadata for the Chat UI. Specify whether the model is a textgen or
# embedding model; for textgen models, also specify whether streaming is supported. For embedding models, omit the
# `streaming` parameter entirely. Within the model list, `lisa_params` is an object in each model definition that
# contains the `model_type` and `streaming` fields. A commented example is provided below.
litellmConfig:
litellm_settings:
telemetry: false # Don't try to send telemetry to LiteLLM servers.
model_list: [] # Add any of your existing (not LISA-hosted) models here.
model_list: # Add any of your existing (not LISA-hosted) models here.
# - model_name: mymodel
# litellm_params:
# model: openai/myprovider/mymodel
# api_key: ignored
# lisa_params:
# model_type: textgen
# streaming: true
lib/schema.ts (24 additions, 1 deletion)

@@ -677,6 +677,10 @@ const ApiGatewayConfigSchema = z
/**
* Configuration for models inside the LiteLLM Config
* See https://litellm.vercel.app/docs/proxy/configs#all-settings for more details.
*
* The `lisa_params` are custom to the LISA installation and attach model metadata so that models can be referenced
* correctly within the Chat UI. LiteLLM ignores these parameters because it does not recognize them, and their
* presence will not cause it to fail to initialize.
*/
const LiteLLMModel = z.object({
model_name: z.string(),
@@ -686,6 +690,25 @@ const LiteLLMModel = z.object({
api_key: z.string().optional(),
aws_region_name: z.string().optional(),
}),
lisa_params: z
.object({
streaming: z.boolean().nullable().default(null),
model_type: z.nativeEnum(ModelType),
})
.refine(
(data) => {
// 'textgen' type must have boolean streaming, 'embedding' type must have null streaming
const isValidForTextgen = data.model_type === 'textgen' && typeof data.streaming === 'boolean';
const isValidForEmbedding = data.model_type === 'embedding' && data.streaming === null;

return isValidForTextgen || isValidForEmbedding;
},
{
message: `For 'textgen' models, 'streaming' must be true or false.
For 'embedding' models, 'streaming' must not be set.`,
path: ['streaming'],
},
),
model_info: z
.object({
id: z.string().optional(),
@@ -704,7 +727,7 @@ const LiteLLMModel = z.object({
*/
const LiteLLMConfig = z.object({
environment_variables: z.map(z.string(), z.string()).optional(),
model_list: z.array(LiteLLMModel).optional(),
model_list: z.array(LiteLLMModel).optional().nullable().default([]),
litellm_settings: z.object({
// ALL (https://github.com/BerriAI/litellm/blob/main/litellm/__init__.py)
telemetry: z.boolean().default(false).optional(),
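
For reference, a minimal standalone sketch of how the `lisa_params` refinement above behaves. It assumes `zod` is installed; `ModelType` here is a simplified stand-in for the enum defined elsewhere in `lib/schema.ts`:

```typescript
import { z } from 'zod';

// Simplified stand-in for the ModelType enum defined elsewhere in lib/schema.ts.
enum ModelType {
  TEXTGEN = 'textgen',
  EMBEDDING = 'embedding',
}

const LisaParams = z
  .object({
    streaming: z.boolean().nullable().default(null),
    model_type: z.nativeEnum(ModelType),
  })
  .refine(
    (data) =>
      (data.model_type === ModelType.TEXTGEN && typeof data.streaming === 'boolean') ||
      (data.model_type === ModelType.EMBEDDING && data.streaming === null),
    { message: "textgen models need a boolean 'streaming'; embedding models must leave it unset", path: ['streaming'] },
  );

// Accepted: textgen models declare streaming explicitly.
console.log(LisaParams.safeParse({ model_type: 'textgen', streaming: true }).success); // true
// Rejected: embedding models may not declare streaming.
console.log(LisaParams.safeParse({ model_type: 'embedding', streaming: false }).success); // false
// Accepted: omitting streaming defaults it to null, which is valid for embedding models.
console.log(LisaParams.safeParse({ model_type: 'embedding' }).success); // true
```
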
lib/user-interface/index.ts (19 additions, 0 deletions)

@@ -161,6 +161,24 @@ export class UserInterfaceStack extends Stack {
},
);

const ecsModels = config.ecsModels.map((modelConfig) => {
return {
model: modelConfig.modelId,
streaming: modelConfig.streaming,
modelType: modelConfig.modelType,
};
});
const litellmModels = config.litellmConfig.model_list ? config.litellmConfig.model_list : [];
const modelsList = ecsModels.concat(
litellmModels.map((model) => {
return {
model: model.model_name,
streaming: model.lisa_params.streaming,
modelType: model.lisa_params.model_type,
};
}),
);

// Website bucket deployment
// Copy auth and LISA-Serve info to UI deployment bucket
const appEnvConfig = {
@@ -179,6 +197,7 @@
fontColor: config.systemBanner?.fontColor,
},
API_BASE_URL: config.apiGatewayConfig?.domainName ? '/' : `/${config.deploymentStage}/`,
MODELS: modelsList,
};

const appEnvSource = Source.data('env.js', `window.env = ${JSON.stringify(appEnvConfig)}`);
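
To illustrate what the stack writes into `window.env.MODELS`, here is a hedged sketch of the merge above as a pure function; the interfaces are simplified stand-ins for the shapes validated by the zod schemas in `lib/schema.ts`:

```typescript
// Simplified stand-ins for the config shapes validated in lib/schema.ts.
interface EcsModelConfig { modelId: string; streaming?: boolean; modelType: string }
interface LiteLLMModelConfig { model_name: string; lisa_params: { streaming: boolean | null; model_type: string } }
interface UiModelEntry { model: string; streaming: boolean | null | undefined; modelType: string }

function buildModelsList(
  ecsModels: EcsModelConfig[],
  litellmModelList: LiteLLMModelConfig[] | null | undefined,
): UiModelEntry[] {
  // ECS-hosted models already carry their type and streaming flag directly.
  const ecsEntries = ecsModels.map((m) => ({ model: m.modelId, streaming: m.streaming, modelType: m.modelType }));
  // LiteLLM-proxied models carry the same metadata under lisa_params; the list may be empty or null.
  const litellmEntries = (litellmModelList ?? []).map((m) => ({
    model: m.model_name,
    streaming: m.lisa_params.streaming,
    modelType: m.lisa_params.model_type,
  }));
  return ecsEntries.concat(litellmEntries);
}

// Example: one ECS textgen model plus one LiteLLM-proxied embedding model.
const models = buildModelsList(
  [{ modelId: 'streaming-textgen-model', streaming: true, modelType: 'textgen' }],
  [{ model_name: 'embedding-model', lisa_params: { streaming: null, model_type: 'embedding' } }],
);
console.log(JSON.stringify(models, null, 2));
```
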
lib/user-interface/react/src/components/chatbot/Chat.tsx (6 additions, 4 deletions)

@@ -136,10 +136,13 @@ export default function Chat({ sessionId }) {
useEffect(() => {
if (selectedModelOption) {
const model = models.filter((model) => model.id === selectedModelOption.value)[0];
setModelCanStream(true);
if (!model.streaming && model.streaming !== undefined && streamingEnabled) {
setStreamingEnabled(false);
}
setModelCanStream(model.streaming || model.streaming === undefined);
setSelectedModel(model);
}
}, [selectedModelOption, streamingEnabled]);
}, [models, selectedModelOption, streamingEnabled]);

useEffect(() => {
setModelsOptions(models.map((model) => ({ label: model.id, value: model.id })));
@@ -463,8 +466,7 @@ export default function Chat({ sessionId }) {

const describeTextGenModels = useCallback(async () => {
setIsLoadingModels(true);
const resp = await describeModels(auth.user?.id_token);
setModels(resp.data);
setModels(await describeModels(auth.user?.id_token, 'textgen'));
setIsLoadingModels(false);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);
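
The effect above replaces the old unconditional `setModelCanStream(true)` with a per-model decision. Here is a hedged sketch of that decision as a standalone helper; the function name and return shape are illustrative, not part of the PR:

```typescript
interface UiModel { id: string; streaming?: boolean | null }

function resolveStreaming(model: UiModel, streamingEnabled: boolean): { canStream: boolean; enabled: boolean } {
  // Models without a streaming flag (e.g., ones not declared in window.env.MODELS) are assumed able to stream.
  const canStream = Boolean(model.streaming) || model.streaming === undefined;
  // If the selected model cannot stream, the streaming toggle is forced off.
  const enabled = canStream ? streamingEnabled : false;
  return { canStream, enabled };
}

console.log(resolveStreaming({ id: 'streaming-textgen-model', streaming: true }, true)); // { canStream: true, enabled: true }
console.log(resolveStreaming({ id: 'non-streaming-textgen-model', streaming: false }, true)); // { canStream: false, enabled: false }
console.log(resolveStreaming({ id: 'some-external-model' }, true)); // { canStream: true, enabled: true }
```
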
@@ -47,8 +47,8 @@ export default function RagControls({ auth, isRunning, setUseRag, setRagConfig }
setIsLoadingEmbeddingModels(true);
setIsLoadingRepositories(true);

describeModels(auth.user?.id_token).then((resp) => {
setEmbeddingModels(resp.data);
describeModels(auth.user?.id_token, 'embedding').then((resp) => {
setEmbeddingModels(resp);
setIsLoadingEmbeddingModels(false);
});

lib/user-interface/react/src/components/types.tsx (10 additions, 1 deletion)

@@ -110,6 +110,15 @@ export interface Repository {
* Interface for model
*/
export interface Model {
id: string;
modelType: ModelTypes;
streaming?: boolean;
}

/**
* Interface for OpenAIModel that is used for OpenAI Model Interactions
*/
export interface OpenAIModel {
id: string;
object: string;
created: number;
@@ -120,7 +129,7 @@ export interface Model {
* Interface for the response body received when describing a model
*/
export interface DescribeModelsResponseBody {
data: Model[];
data: OpenAIModel[];
}

lib/user-interface/react/src/components/utils.ts (21 additions, 4 deletions)

@@ -16,12 +16,13 @@

import {
LisaChatSession,
DescribeModelsResponseBody,
LisaChatMessageFields,
PutSessionRequestBody,
LisaChatMessage,
Repository,
ModelTypes,
Model,
DescribeModelsResponseBody,
} from './types';

const stripTrailingSlash = (str) => {
@@ -167,12 +168,28 @@ export const deleteUserSessions = async (idToken: string) => {

/**
* Describes all models of a given type which are available to a user
* @param idToken the user's ID token from authenticating
* @param modelType model type we are requesting
* @returns the models of the requested type that are available to the user
*/
export const describeModels = async (idToken: string): Promise<DescribeModelsResponseBody> => {
export const describeModels = async (idToken: string, modelType: ModelTypes): Promise<Model[]> => {
const resp = await sendAuthenticatedRequest(`${RESTAPI_URI}/${RESTAPI_VERSION}/serve/models`, 'GET', idToken);
return await resp.json();
const modelResponse = (await resp.json()) as DescribeModelsResponseBody;

return modelResponse.data
.filter((openAiModel) => {
const configModelMatch = window.env.MODELS.filter((configModel) => configModel.model === openAiModel.id)[0];
if (!configModelMatch || configModelMatch.modelType === modelType) {
return true;
}
})
.map((openAiModel) => {
const configModelMatch = window.env.MODELS.filter((configModel) => configModel.model === openAiModel.id)[0];
return {
id: openAiModel.id,
streaming: configModelMatch?.streaming,
modelType: configModelMatch?.modelType,
};
});
};

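The updated `describeModels` filters the OpenAI-style `/serve/models` listing against `window.env.MODELS` by model type. Below is a self-contained sketch of that filtering, using mock data in place of `window.env.MODELS` and the REST response; the `ModelTypes` union is an assumed simplification:

```typescript
type ModelTypes = 'textgen' | 'embedding'; // assumed simplification of the ModelTypes type

interface ConfigModel { model: string; streaming: boolean | null; modelType: ModelTypes }
interface OpenAIModel { id: string }

function filterByType(openAiModels: OpenAIModel[], configModels: ConfigModel[], modelType: ModelTypes) {
  return openAiModels
    .filter((m) => {
      const match = configModels.find((c) => c.model === m.id);
      // Models that are not declared in the UI config pass the filter for any requested type.
      return !match || match.modelType === modelType;
    })
    .map((m) => {
      const match = configModels.find((c) => c.model === m.id);
      return { id: m.id, streaming: match?.streaming, modelType: match?.modelType };
    });
}

const config: ConfigModel[] = [
  { model: 'streaming-textgen-model', streaming: true, modelType: 'textgen' },
  { model: 'embedding-model', streaming: null, modelType: 'embedding' },
];
const served: OpenAIModel[] = [{ id: 'streaming-textgen-model' }, { id: 'embedding-model' }];

console.log(filterByType(served, config, 'textgen'));   // keeps only the textgen entry
console.log(filterByType(served, config, 'embedding')); // keeps only the embedding entry
```
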
lib/user-interface/react/src/main.tsx (8 additions, 0 deletions)

@@ -20,6 +20,7 @@ import './index.css';
import AppConfigured from './components/app-configured';

import '@cloudscape-design/global-styles/index.css';
import { ModelTypes } from './components/types';

declare global {
// eslint-disable-next-line @typescript-eslint/consistent-type-definitions
@@ -36,6 +37,13 @@ declare global {
backgroundColor: string;
fontColor: string;
};
MODELS: [
{
model: string;
streaming: boolean | null;
modelType: ModelTypes;
},
];
};
}
}
test/cdk/mocks/config.yaml (1 addition, 1 deletion)

@@ -237,4 +237,4 @@ dev:
litellmConfig:
litellm_settings:
telemetry: false
model_list: [ ]
model_list: