Skip to content

Commit

Permalink
Populate model dropdown from /models endpoint
Browse files Browse the repository at this point in the history
  • Loading branch information
dustins authored and petermuller committed Sep 23, 2024
1 parent 2e34474 commit dbb575c
Show file tree
Hide file tree
Showing 4 changed files with 59 additions and 67 deletions.
8 changes: 5 additions & 3 deletions lambda/authorizer/lambda_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,9 +64,11 @@ def lambda_handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: i
allow_policy["context"] = {"username": jwt_data["sub"]}

if requested_resource.startswith("/models") and not is_admin_user:
username = jwt_data.get("sub", "user")
logger.info(f"Deny access to {username} due to non-admin accessing /models api.")
return deny_policy
# non-admin users can still list models
if event["path"].rstrip("/") != "/models":
username = jwt_data.get("sub", "user")
logger.info(f"Deny access to {username} due to non-admin accessing /models api.")
return deny_policy

logger.debug(f"Generated policy: {allow_policy}")
logger.info(f"REST API authorization handler completed with 'Allow' for resource {event['methodArn']}")
Expand Down
71 changes: 31 additions & 40 deletions lib/user-interface/react/src/components/chatbot/Chat.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
limitations under the License.
*/

import { useState, useRef, useCallback, useEffect } from 'react';
import { useState, useRef, useCallback, useEffect, useMemo } from 'react';
import { useAuth } from 'react-oidc-context';
import Form from '@cloudscape-design/components/form';
import Button from '@cloudscape-design/components/button';
Expand All @@ -41,11 +41,10 @@ import { SelectProps } from '@cloudscape-design/components/select';
import StatusIndicator from '@cloudscape-design/components/status-indicator';

import Message from './Message';
import { LisaChatMessage, LisaChatSession, Model, ModelConfig, LisaChatMessageMetadata } from '../types';
import { LisaChatMessage, LisaChatSession, ModelConfig, LisaChatMessageMetadata } from '../types';
import {
getSession,
putSession,
describeModels,
isModelInterfaceHealthy,
RESTAPI_URI,
formatDocumentsAsString,
Expand All @@ -61,6 +60,8 @@ import { BufferWindowMemory } from 'langchain/memory';
import RagControls, { RagConfig } from './RagOptions';
import { ContextUploadModal, RagUploadModal } from './FileUploadModals';
import { ChatOpenAI } from '@langchain/openai';
import { useGetAllModelsQuery } from '../../shared/reducers/model-management.reducer';
import { IModel, ModelType } from '../../shared/model/model-management.model';

export default function Chat ({ sessionId }) {
const [userPrompt, setUserPrompt] = useState('');
Expand All @@ -75,11 +76,15 @@ export default function Chat ({ sessionId }) {
${humanPrefix}: {input}
${aiPrefix}:`,
);
const [models, setModels] = useState<Model[]>([]);
const [modelsOptions, setModelsOptions] = useState<SelectProps.Options>([]);
const [modelConfig, setModelConfig] = useState<ModelConfig | undefined>(undefined);
const [selectedModel, setSelectedModel] = useState<Model | undefined>(undefined);
const [selectedModelOption, setSelectedModelOption] = useState<SelectProps.Option | undefined>(undefined);

const { data: allModels, isFetching: isFetchingModels } = useGetAllModelsQuery(undefined, {selectFromResult: (state) => ({
isFetching: state.isFetching,
data: (state.data || []).filter((model) => model.modelType === ModelType.textgen),
})});
const modelsOptions = useMemo(() => allModels.map((model) => ({ label: model.modelId, value: model.modelId })), [allModels]);
const [modelConfig, setModelConfig] = useState<ModelConfig>();
const [selectedModel, setSelectedModel] = useState<IModel>();
const [selectedModelOption, setSelectedModelOption] = useState<SelectProps.Option>();
const [session, setSession] = useState<LisaChatSession>({
history: [],
sessionId: '',
Expand All @@ -89,11 +94,9 @@ export default function Chat ({ sessionId }) {
const [streamingEnabled, setStreamingEnabled] = useState(false);
const [chatHistoryBufferSize, setChatHistoryBufferSize] = useState<number>(3);
const [ragTopK, setRagTopK] = useState<number>(3);
const [modelCanStream, setModelCanStream] = useState(false);
const [isStreaming, setIsStreaming] = useState(false);
const [isConnected, setIsConnected] = useState(false);
const [isRunning, setIsRunning] = useState(false);
const [isLoadingModels, setIsLoadingModels] = useState(false);
const [metadata, setMetadata] = useState<LisaChatMessageMetadata>({});
const [showMetadata, setShowMetadata] = useState(false);
const [internalSessionId, setInternalSessionId] = useState<string | null>(null);
Expand Down Expand Up @@ -128,26 +131,10 @@ export default function Chat ({ sessionId }) {
});

useEffect(() => {
describeTextGenModels();
isBackendHealthy().then((flag) => setIsConnected(flag));
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);

useEffect(() => {
if (selectedModelOption) {
const model = models.filter((model) => model.id === selectedModelOption.value)[0];
if (!model.streaming && model.streaming !== undefined && streamingEnabled) {
setStreamingEnabled(false);
}
setModelCanStream(model.streaming || model.streaming === undefined);
setSelectedModel(model);
}
}, [models, selectedModelOption, streamingEnabled]);

useEffect(() => {
setModelsOptions(models.map((model) => ({ label: model.id, value: model.id })));
}, [models]);

useEffect(() => {
if (!isRunning && session.history.length) {
if (session.history.at(-1).type === 'ai' && !auth.isLoading) {
Expand Down Expand Up @@ -242,7 +229,7 @@ export default function Chat ({ sessionId }) {
}
const prompt = await PromptTemplate.fromTemplate(promptTemplate).format(promptValues);
const metadata: LisaChatMessageMetadata = {
modelName: selectedModel.id,
modelName: selectedModel.modelId,
modelKwargs: modelConfig,
userId: auth.user.profile.sub,
messages: prompt,
Expand Down Expand Up @@ -280,7 +267,7 @@ export default function Chat ({ sessionId }) {

const createOpenAiClient = (streaming: boolean) => {
return new ChatOpenAI({
modelName: selectedModel?.id,
modelName: selectedModel?.modelId,
openAIApiKey: auth.user?.id_token,
configuration: {
baseURL: `${RESTAPI_URI}/${RESTAPI_VERSION}/serve`,
Expand Down Expand Up @@ -350,7 +337,7 @@ export default function Chat ({ sessionId }) {
idToken: auth.user?.id_token,
repositoryId: ragConfig.repositoryId,
repositoryType: ragConfig.repositoryType,
modelName: ragConfig.embeddingModel.id,
modelName: ragConfig.embeddingModel.modelId,
topK: ragTopK,
});

Expand Down Expand Up @@ -458,13 +445,6 @@ export default function Chat ({ sessionId }) {
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [userPrompt, metadata, streamingEnabled]);

const describeTextGenModels = useCallback(async () => {
setIsLoadingModels(true);
setModels(await describeModels(auth.user?.id_token, 'textgen'));
setIsLoadingModels(false);
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);

return (
<>
<ModelKwargsEditor
Expand Down Expand Up @@ -589,7 +569,7 @@ export default function Chat ({ sessionId }) {
<div className='flex mb-2 justify-end mt-3'>
<div>
<Button
disabled={!models.length || isRunning || !selectedModel || userPrompt === ''}
disabled={!allModels.length || isRunning || !selectedModel || userPrompt === ''}
onClick={handleSendGenerateRequest}
iconAlign='right'
iconName='angle-right-double'
Expand Down Expand Up @@ -647,20 +627,31 @@ export default function Chat ({ sessionId }) {
>
<Select
disabled={isRunning}
statusType={isLoadingModels ? 'loading' : 'finished'}
statusType={isFetchingModels ? 'loading' : 'finished'}
loadingText='Loading models (might take few seconds)...'
placeholder='Select a model'
empty={<div className='text-gray-500'>No models available.</div>}
filteringType='auto'
selectedOption={selectedModelOption}
onChange={({ detail }) => setSelectedModelOption(detail.selectedOption)}
onChange={({ detail: { selectedOption } }) => {
setSelectedModelOption(selectedOption);

const model = allModels.find((model) => model.modelId === selectedOption.value);
if (model) {
if (!model.streaming && streamingEnabled) {
setStreamingEnabled(false);
}

setSelectedModel(model);
}
}}
options={modelsOptions}
/>
<div style={{ paddingTop: 4 }}>
<Toggle
onChange={({ detail }) => setStreamingEnabled(detail.checked)}
checked={streamingEnabled}
disabled={!modelCanStream || isRunning}
disabled={!selectedModel?.streaming || isRunning}
>
Streaming
</Toggle>
Expand Down
45 changes: 22 additions & 23 deletions lib/user-interface/react/src/components/chatbot/RagOptions.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,14 @@
*/

import { Button, Grid, Select, SelectProps, SpaceBetween } from '@cloudscape-design/components';
import { useEffect, useState } from 'react';
import { Model } from '../types';
import { describeModels, listRagRepositories } from '../utils';
import { useEffect, useMemo, useState } from 'react';
import { listRagRepositories } from '../utils';
import { AuthContextProps } from 'react-oidc-context';
import { useGetAllModelsQuery } from '../../shared/reducers/model-management.reducer';
import { IModel, ModelType } from '../../shared/model/model-management.model';

export type RagConfig = {
embeddingModel: Model;
embeddingModel: IModel;
repositoryId: string;
repositoryType: string;
};
Expand All @@ -34,24 +35,22 @@ type RagControlProps = {
};

export default function RagControls ({ auth, isRunning, setUseRag, setRagConfig }: RagControlProps) {
const [embeddingModels, setEmbeddingModels] = useState<Model[]>([]);
const [embeddingOptions, setEmbeddingOptions] = useState<SelectProps.Options>([]);
const [isLoadingEmbeddingModels, setIsLoadingEmbeddingModels] = useState(false);
const [isLoadingRepositories, setIsLoadingRepositories] = useState(false);
const [repositoryOptions, setRepositoryOptions] = useState<SelectProps.Options>([]);
const [selectedEmbeddingOption, setSelectedEmbeddingOption] = useState<SelectProps.Option | undefined>(undefined);
const [selectedRepositoryOption, setSelectedRepositoryOption] = useState<SelectProps.Option | undefined>(undefined);
const [repositoryMap, setRepositoryMap] = useState(new Map());
const { data: allModels, isFetching: isFetchingModels } = useGetAllModelsQuery(undefined, {selectFromResult: (state) => ({
isFetching: state.isFetching,
data: (state.data || []).filter((model) => model.modelType === ModelType.embedding),
})});
const embeddingOptions = useMemo(() => {
return allModels?.map((model) => ({value: model.modelId})) || [];
}, [allModels]);

useEffect(() => {
setIsLoadingEmbeddingModels(true);
setIsLoadingRepositories(true);

describeModels(auth.user?.id_token, 'embedding').then((resp) => {
setEmbeddingModels(resp);
setIsLoadingEmbeddingModels(false);
});

listRagRepositories(auth.user?.id_token).then((repositories) => {
setRepositoryOptions(
repositories.map((repo) => {
Expand All @@ -68,10 +67,6 @@ export default function RagControls ({ auth, isRunning, setUseRag, setRagConfig
// eslint-disable-next-line react-hooks/exhaustive-deps
}, []);

useEffect(() => {
setEmbeddingOptions(embeddingModels.map((model) => ({ label: model.id, value: model.id })));
}, [embeddingModels]);

useEffect(() => {
setUseRag(!!selectedEmbeddingOption && !!selectedRepositoryOption);
// setUseRag is never going to change as it's just a setState function
Expand Down Expand Up @@ -120,19 +115,23 @@ export default function RagControls ({ auth, isRunning, setUseRag, setRagConfig
Clear
</Button>
<Select
disabled={isRunning}
statusType={isLoadingEmbeddingModels ? 'loading' : 'finished'}
disabled={!selectedRepositoryOption || isRunning}
statusType={isFetchingModels ? 'loading' : 'finished'}
loadingText='Loading embedding models (might take few seconds)...'
placeholder='Select an embedding model'
empty={<div className='text-gray-500'>No embedding models available.</div>}
filteringType='auto'
selectedOption={selectedEmbeddingOption}
onChange={({ detail }) => {
setSelectedEmbeddingOption(detail.selectedOption);
setRagConfig((config) => ({
...config,
embeddingModel: embeddingModels.filter((model) => model.id === detail.selectedOption.value)[0],
}));

const model = allModels.find((model) => model.modelId === detail.selectedOption.value);
if (model) {
setRagConfig((config) => ({
...config,
embeddingModel: model,
}));
}
}}
options={embeddingOptions}
/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ export const modelManagementApi = createApi({
reducerPath: 'models',
baseQuery: lisaBaseQuery(),
endpoints: (builder) => ({
getAllModels: builder.query<IModelListResponse, void>({
getAllModels: builder.query<IModelListResponse['models'], void>({
query: () => ({
url: '/models',
}),
Expand Down

0 comments on commit dbb575c

Please sign in to comment.