Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Populate model dropdown from /models endpoint #110

Merged
merged 1 commit into from
Sep 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions lambda/authorizer/lambda_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ def lambda_handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: i
allow_policy["context"] = {"username": jwt_data["sub"]}

if requested_resource.startswith("/models") and not is_admin_user:
username = jwt_data.get("sub", "user")
logger.info(f"Deny access to {username} due to non-admin accessing /models api.")
return deny_policy
# non-admin users can still list models
if event["path"].rstrip("/") != "/models":
username = jwt_data.get("sub", "user")
logger.info(f"Deny access to {username} due to non-admin accessing /models api.")
return deny_policy

logger.debug(f"Generated policy: {allow_policy}")
logger.info(f"REST API authorization handler completed with 'Allow' for resource {event['methodArn']}")
Expand Down
71 changes: 31 additions & 40 deletions lib/user-interface/react/src/components/chatbot/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
*/

import { useState, useRef, useCallback, useEffect } from 'react';
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
import { useAuth } from 'react-oidc-context';
import Form from '@cloudscape-design/components/form';
import Button from '@cloudscape-design/components/button';
Expand All @@ -41,11 +41,10 @@ import { SelectProps } from '@cloudscape-design/components/select';
import StatusIndicator from '@cloudscape-design/components/status-indicator';

import Message from './Message';
import { LisaChatMessage, LisaChatSession, Model, ModelConfig, LisaChatMessageMetadata } from '../types';
import { LisaChatMessage, LisaChatSession, ModelConfig, LisaChatMessageMetadata } from '../types';
import {
getSession,
putSession,
describeModels,
isModelInterfaceHealthy,
RESTAPI_URI,
formatDocumentsAsString,
Expand All @@ -61,6 +60,8 @@ import { BufferWindowMemory } from 'langchain/memory';
import RagControls, { RagConfig } from './RagOptions';
import { ContextUploadModal, RagUploadModal } from './FileUploadModals';
import { ChatOpenAI } from '@langchain/openai';
import { useGetAllModelsQuery } from '../../shared/reducers/model-management.reducer';
import { IModel, ModelType } from '../../shared/model/model-management.model';

export default function Chat ({ sessionId }) {
const [userPrompt, setUserPrompt] = useState('');
Expand All @@ -75,11 +76,15 @@ export default function Chat ({ sessionId }) {
${humanPrefix}: {input}
${aiPrefix}:`,
);
const [models, setModels] = useState<Model[]>([]);
const [modelsOptions, setModelsOptions] = useState<SelectProps.Options>([]);
const [modelConfig, setModelConfig] = useState<ModelConfig | undefined>(undefined);
const [selectedModel, setSelectedModel] = useState<Model | undefined>(undefined);
const [selectedModelOption, setSelectedModelOption] = useState<SelectProps.Option | undefined>(undefined);

const { data: allModels, isFetching: isFetchingModels } = useGetAllModelsQuery(undefined, {selectFromResult: (state) => ({
isFetching: state.isFetching,
data: (state.data || []).filter((model) => model.modelType === ModelType.textgen),
})});
const modelsOptions = useMemo(() => allModels.map((model) => ({ label: model.modelId, value: model.modelId })), [allModels]);
const [modelConfig, setModelConfig] = useState<ModelConfig>();
const [selectedModel, setSelectedModel] = useState<IModel>();
const [selectedModelOption, setSelectedModelOption] = useState<SelectProps.Option>();
const [session, setSession] = useState<LisaChatSession>({
history: [],
sessionId: '',
Expand All @@ -89,11 +94,9 @@ export default function Chat ({ sessionId }) {
const [streamingEnabled, setStreamingEnabled] = useState(false);
const [chatHistoryBufferSize, setChatHistoryBufferSize] = useState<number>(3);
const [ragTopK, setRagTopK] = useState<number>(3);
const [modelCanStream, setModelCanStream] = useState(false);
const [isStreaming, setIsStreaming] = useState(false);
const [isConnected, setIsConnected] = useState(false);
const [isRunning, setIsRunning] = useState(false);
const [isLoadingModels, setIsLoadingModels] = useState(false);
const [metadata, setMetadata] = useState<LisaChatMessageMetadata>({});
const [showMetadata, setShowMetadata] = useState(false);
const [internalSessionId, setInternalSessionId] = useState<string | null>(null);
Expand Down Expand Up @@ -128,26 +131,10 @@ export default function Chat ({ sessionId }) {
});

useEffect(() => {
describeTextGenModels();
isBackendHealthy().then((flag) => setIsConnected(flag));
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);

useEffect(() => {
if (selectedModelOption) {
const model = models.filter((model) => model.id === selectedModelOption.value)[0];
if (!model.streaming && model.streaming !== undefined && streamingEnabled) {
setStreamingEnabled(false);
}
setModelCanStream(model.streaming || model.streaming === undefined);
setSelectedModel(model);
}
}, [models, selectedModelOption, streamingEnabled]);

useEffect(() => {
setModelsOptions(models.map((model) => ({ label: model.id, value: model.id })));
}, [models]);

useEffect(() => {
if (!isRunning && session.history.length) {
if (session.history.at(-1).type === 'ai' && !auth.isLoading) {
Expand Down Expand Up @@ -242,7 +229,7 @@ export default function Chat ({ sessionId }) {
}
const prompt = await PromptTemplate.fromTemplate(promptTemplate).format(promptValues);
const metadata: LisaChatMessageMetadata = {
modelName: selectedModel.id,
modelName: selectedModel.modelId,
modelKwargs: modelConfig,
userId: auth.user.profile.sub,
messages: prompt,
Expand Down Expand Up @@ -280,7 +267,7 @@ export default function Chat ({ sessionId }) {

const createOpenAiClient = (streaming: boolean) => {
return new ChatOpenAI({
modelName: selectedModel?.id,
modelName: selectedModel?.modelId,
openAIApiKey: auth.user?.id_token,
configuration: {
baseURL: `${RESTAPI_URI}/${RESTAPI_VERSION}/serve`,
Expand Down Expand Up @@ -350,7 +337,7 @@ export default function Chat ({ sessionId }) {
idToken: auth.user?.id_token,
repositoryId: ragConfig.repositoryId,
repositoryType: ragConfig.repositoryType,
modelName: ragConfig.embeddingModel.id,
modelName: ragConfig.embeddingModel.modelId,
topK: ragTopK,
});

Expand Down Expand Up @@ -458,13 +445,6 @@ export default function Chat ({ sessionId }) {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [userPrompt, metadata, streamingEnabled]);

const describeTextGenModels = useCallback(async () => {
setIsLoadingModels(true);
setModels(await describeModels(auth.user?.id_token, 'textgen'));
setIsLoadingModels(false);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);

return (
<>
<ModelKwargsEditor
Expand Down Expand Up @@ -589,7 +569,7 @@ export default function Chat ({ sessionId }) {
<div className='flex mb-2 justify-end mt-3'>
<div>
<Button
disabled={!models.length || isRunning || !selectedModel || userPrompt === ''}
disabled={!allModels.length || isRunning || !selectedModel || userPrompt === ''}
onClick={handleSendGenerateRequest}
iconAlign='right'
iconName='angle-right-double'
Expand Down Expand Up @@ -647,20 +627,31 @@ export default function Chat ({ sessionId }) {
>
<Select
disabled={isRunning}
statusType={isLoadingModels ? 'loading' : 'finished'}
statusType={isFetchingModels ? 'loading' : 'finished'}
loadingText='Loading models (might take few seconds)...'
placeholder='Select a model'
empty={<div className='text-gray-500'>No models available.</div>}
filteringType='auto'
selectedOption={selectedModelOption}
onChange={({ detail }) => setSelectedModelOption(detail.selectedOption)}
onChange={({ detail: { selectedOption } }) => {
setSelectedModelOption(selectedOption);

const model = allModels.find((model) => model.modelId === selectedOption.value);
if (model) {
if (!model.streaming && streamingEnabled) {
setStreamingEnabled(false);
}

setSelectedModel(model);
}
}}
options={modelsOptions}
/>
<div style={{ paddingTop: 4 }}>
<Toggle
onChange={({ detail }) => setStreamingEnabled(detail.checked)}
checked={streamingEnabled}
disabled={!modelCanStream || isRunning}
disabled={!selectedModel?.streaming || isRunning}
>
Streaming
</Toggle>
Expand Down
45 changes: 22 additions & 23 deletions lib/user-interface/react/src/components/chatbot/RagOptions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
*/

import { Button, Grid, Select, SelectProps, SpaceBetween } from '@cloudscape-design/components';
import { useEffect, useState } from 'react';
import { Model } from '../types';
import { describeModels, listRagRepositories } from '../utils';
import { useEffect, useMemo, useState } from 'react';
import { listRagRepositories } from '../utils';
import { AuthContextProps } from 'react-oidc-context';
import { useGetAllModelsQuery } from '../../shared/reducers/model-management.reducer';
import { IModel, ModelType } from '../../shared/model/model-management.model';

export type RagConfig = {
embeddingModel: Model;
embeddingModel: IModel;
repositoryId: string;
repositoryType: string;
};
Expand All @@ -34,24 +35,22 @@ type RagControlProps = {
};

export default function RagControls ({ auth, isRunning, setUseRag, setRagConfig }: RagControlProps) {
const [embeddingModels, setEmbeddingModels] = useState<Model[]>([]);
const [embeddingOptions, setEmbeddingOptions] = useState<SelectProps.Options>([]);
const [isLoadingEmbeddingModels, setIsLoadingEmbeddingModels] = useState(false);
const [isLoadingRepositories, setIsLoadingRepositories] = useState(false);
const [repositoryOptions, setRepositoryOptions] = useState<SelectProps.Options>([]);
const [selectedEmbeddingOption, setSelectedEmbeddingOption] = useState<SelectProps.Option | undefined>(undefined);
const [selectedRepositoryOption, setSelectedRepositoryOption] = useState<SelectProps.Option | undefined>(undefined);
const [repositoryMap, setRepositoryMap] = useState(new Map());
const { data: allModels, isFetching: isFetchingModels } = useGetAllModelsQuery(undefined, {selectFromResult: (state) => ({
isFetching: state.isFetching,
data: (state.data || []).filter((model) => model.modelType === ModelType.embedding),
})});
const embeddingOptions = useMemo(() => {
return allModels?.map((model) => ({value: model.modelId})) || [];
}, [allModels]);

useEffect(() => {
setIsLoadingEmbeddingModels(true);
setIsLoadingRepositories(true);

describeModels(auth.user?.id_token, 'embedding').then((resp) => {
setEmbeddingModels(resp);
setIsLoadingEmbeddingModels(false);
});

listRagRepositories(auth.user?.id_token).then((repositories) => {
setRepositoryOptions(
repositories.map((repo) => {
Expand All @@ -68,10 +67,6 @@ export default function RagControls ({ auth, isRunning, setUseRag, setRagConfig
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);

useEffect(() => {
setEmbeddingOptions(embeddingModels.map((model) => ({ label: model.id, value: model.id })));
}, [embeddingModels]);

useEffect(() => {
setUseRag(!!selectedEmbeddingOption && !!selectedRepositoryOption);
// setUseRag is never going to change as it's just a setState function
Expand Down Expand Up @@ -120,19 +115,23 @@ export default function RagControls ({ auth, isRunning, setUseRag, setRagConfig
Clear
</Button>
<Select
disabled={isRunning}
statusType={isLoadingEmbeddingModels ? 'loading' : 'finished'}
disabled={!selectedRepositoryOption || isRunning}
statusType={isFetchingModels ? 'loading' : 'finished'}
loadingText='Loading embedding models (might take few seconds)...'
placeholder='Select an embedding model'
empty={<div className='text-gray-500'>No embedding models available.</div>}
filteringType='auto'
selectedOption={selectedEmbeddingOption}
onChange={({ detail }) => {
setSelectedEmbeddingOption(detail.selectedOption);
setRagConfig((config) => ({
...config,
embeddingModel: embeddingModels.filter((model) => model.id === detail.selectedOption.value)[0],
}));

const model = allModels.find((model) => model.modelId === detail.selectedOption.value);
if (model) {
setRagConfig((config) => ({
...config,
embeddingModel: model,
}));
}
}}
options={embeddingOptions}
/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export const modelManagementApi = createApi({
reducerPath: 'models',
baseQuery: lisaBaseQuery(),
endpoints: (builder) => ({
getAllModels: builder.query<IModelListResponse, void>({
getAllModels: builder.query<IModelListResponse['models'], void>({
query: () => ({
url: '/models',
}),
Expand Down
Loading