From 23fb773b643de56a288bbcf5cb9ec41060fb2d68 Mon Sep 17 00:00:00 2001 From: Evan Stohlmann Date: Thu, 5 Dec 2024 13:36:37 -0700 Subject: [PATCH] Update create model validation; Ensure Base image is set for LISA hosted models; --- .pre-commit-config.yaml | 2 +- lambda/models/handler/create_model_handler.py | 36 +++++++++++++++++++ .../create-model/CreateModelModal.tsx | 5 ++- .../shared/model/model-management.model.ts | 11 ++++++ .../react/src/shared/validation/index.ts | 13 +++++-- 5 files changed, 60 insertions(+), 7 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7a894344..a93384e5 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -78,7 +78,7 @@ repos: args: - --max-line-length=120 - --extend-immutable-calls=Query,fastapi.Depends,fastapi.params.Depends - - --ignore=B008,E203 # Ignore error for function calls in argument defaults + - --ignore=B008,E203, W503 # Ignore error for function calls in argument defaults exclude: ^(__init__.py$|.*\/__init__.py$) diff --git a/lambda/models/handler/create_model_handler.py b/lambda/models/handler/create_model_handler.py index 1ff95aa3..a9ffbc1d 100644 --- a/lambda/models/handler/create_model_handler.py +++ b/lambda/models/handler/create_model_handler.py @@ -35,6 +35,8 @@ def __call__(self, create_request: CreateModelRequest) -> CreateModelResponse: if table_item: raise ModelAlreadyExistsError(f"Model '{model_id}' already exists. Please select another name.") + self.validate(create_request) + self._stepfunctions.start_execution( stateMachineArn=os.environ["CREATE_SFN_ARN"], input=create_request.model_dump_json() ) @@ -46,3 +48,37 @@ def __call__(self, create_request: CreateModelRequest) -> CreateModelResponse: } ) return CreateModelResponse(model=lisa_model) + + @staticmethod + def validate(create_request: CreateModelRequest) -> None: + # The below check ensures that the model is LISA hosted + if ( + create_request.containerConfig is not None + and create_request.autoScalingConfig is not None + and create_request.loadBalancerConfig is not None + ): + if create_request.containerConfig.image.baseImage is None: + raise ValueError("Base image must be provided for LISA hosted model.") + + # Validate values relative to current ASG. All conflicting request values have been validated as part of the + # AutoScalingInstanceConfig model validations, so those are not duplicated here. + if create_request.autoScalingConfig is not None: + # Min capacity can't be greater than the deployed ASG's max capacity + if ( + create_request.autoScalingConfig.minCapacity is not None + and create_request.autoScalingConfig.maxCapacity is not None + and create_request.autoScalingConfig.minCapacity > create_request.autoScalingConfig.maxCapacity + ): + raise ValueError( + f"Min capacity cannot exceed ASG max of {create_request.autoScalingConfig.maxCapacity}." + ) + + # Max capacity can't be less than the deployed ASG's min capacity + if ( + create_request.autoScalingConfig.maxCapacity is not None + and create_request.autoScalingConfig.minCapacity is not None + and create_request.autoScalingConfig.maxCapacity < create_request.autoScalingConfig.minCapacity + ): + raise ValueError( + f"Max capacity cannot be less than ASG min of {create_request.autoScalingConfig.minCapacity}." + ) diff --git a/lib/user-interface/react/src/components/model-management/create-model/CreateModelModal.tsx b/lib/user-interface/react/src/components/model-management/create-model/CreateModelModal.tsx index 656e6ea3..75cc4f67 100644 --- a/lib/user-interface/react/src/components/model-management/create-model/CreateModelModal.tsx +++ b/lib/user-interface/react/src/components/model-management/create-model/CreateModelModal.tsx @@ -170,7 +170,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement { } } - const requiredFields = [['modelId', 'modelName'], [], [], [], []]; + const requiredFields = [['modelId', 'modelName'], ['containerConfig.image.baseImage'], [], [], []]; useEffect(() => { const parsedValue = _.mergeWith({}, initialForm, props.selectedItems[0], (a: IModelRequest, b: IModelRequest) => b === null ? a : undefined); @@ -318,8 +318,7 @@ export function CreateModelModal (props: CreateModelModalProps) : ReactElement { case 'next': case 'skip': { - touchFields(requiredFields[state.activeStepIndex]); - if (isValid) { + if (touchFields(requiredFields[state.activeStepIndex]) && isValid) { setState({ ...state, activeStepIndex: event.detail.requestedStepIndex, diff --git a/lib/user-interface/react/src/shared/model/model-management.model.ts b/lib/user-interface/react/src/shared/model/model-management.model.ts index 0721681c..92e7fd2e 100644 --- a/lib/user-interface/react/src/shared/model/model-management.model.ts +++ b/lib/user-interface/react/src/shared/model/model-management.model.ts @@ -230,5 +230,16 @@ export const ModelRequestSchema = z.object({ }); } } + + const baseImageValidator = z.string().min(1, {message: 'Required for LISA hosted models.'}); + const baseImageResult = baseImageValidator.safeParse(value.containerConfig.image.baseImage); + if (baseImageResult.success === false) { + for (const error of baseImageResult.error.errors) { + context.addIssue({ + ...error, + path: ['containerConfig', 'image', 'baseImage'] + }); + } + } } }); diff --git a/lib/user-interface/react/src/shared/validation/index.ts b/lib/user-interface/react/src/shared/validation/index.ts index 3968221f..a48b3174 100644 --- a/lib/user-interface/react/src/shared/validation/index.ts +++ b/lib/user-interface/react/src/shared/validation/index.ts @@ -112,7 +112,7 @@ export type SetFieldsFunction = ( ) => void; -export type TouchFieldsFunction = (fields: string[], method?: ValidationTouchActionMethod) => void; +export type TouchFieldsFunction = (fields: string[], method?: ValidationTouchActionMethod) => boolean; /** @@ -268,7 +268,7 @@ export const useValidationReducer = > return { state, errors, - isValid: parseResult.success, + isValid: Object.keys(errors).length === 0, setState: (newState: Partial, method: ValidationStateActionMethod = ModifyMethod.Default) => { setState({ type: ValidationReducerActionTypes.STATE, @@ -289,12 +289,19 @@ export const useValidationReducer = > touchFields: ( fields: string[], method: ValidationTouchActionMethod = ModifyMethod.Default - ) => { + ): boolean => { setState({ type: ValidationReducerActionTypes.TOUCH, method, fields, } as ValidationTouchAction); + const parseResult = formSchema.safeParse({...state.form, ...{touched: fields}}); + if (!parseResult.success) { + errors = issuesToErrors(parseResult.error.issues, fields.reduce((acc, key) => { + acc[key] = true; return acc; + }, {})); + } + return Object.keys(errors).length === 0; }, }; };