diff --git a/.github/workflows/code.release.branch.yml b/.github/workflows/code.release.branch.yml index eb267012..9d59f6d3 100644 --- a/.github/workflows/code.release.branch.yml +++ b/.github/workflows/code.release.branch.yml @@ -26,7 +26,7 @@ jobs: RELEASE_TAG=${{ github.event.inputs.release_tag }} git checkout -b release/${{ github.event.inputs.release_tag }} echo "$( jq --arg version ${RELEASE_TAG:1} '.version = $version' package.json )" > package.json - sed -E -i "" -e "s/version = \"[0-9\.].+\"/version = \"${RELEASE_TAG:1}\"/g" lisa-sdk/pyproject.toml + sed -E -i -e "s/version = \"[0-9\.].+\"/version = \"${RELEASE_TAG:1}\"/g" lisa-sdk/pyproject.toml echo ${RELEASE_TAG:1} > VERSION git commit -a -m "Updating version for release ${{ github.event.inputs.release_tag }}" git push origin release/${{ github.event.inputs.release_tag }} diff --git a/CHANGELOG.md b/CHANGELOG.md index 23552a01..e8763f24 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,25 @@ +# v3.0.1 +## Bug fixes +- Updated our Lambda admin validation to work for no-auth if user has the admin secret token. This applies to model management APIs. +- State machine for create model was not reporting failed status +- Delete state machine could not delete models that weren't stored in LiteLLM DB + +## Enhancements +- Added units to the create model wizard to help with clarity +- Increased default timeouts to 10 minutes to enable large documentation processing without errors +- Updated ALB and Target group names to be lower cased by default to prevent networking issues + +## Coming Soon +- 3.1.0 will expand support for model management. Administrators will be able to modify, activate, and deactivate models through the UI or APIs. The following release we will continue to ease deployment steps for customers through a new deployment wizard and updated documentation. + +## Acknowledgements +* @petermuller +* @estohlmann +* @dustins + +**Full Changelog**: https://github.com/awslabs/LISA/compare/v3.0.0...v3.0.1 + + # v3.0.0 ## Key Features ### Model Management Administration diff --git a/README.md b/README.md index aa31c398..cc0b9f77 100644 --- a/README.md +++ b/README.md @@ -659,6 +659,10 @@ curl -s -H "Authorization: Bearer " -X GET https:// LISA provides the `/models` endpoint for creating both ECS and LiteLLM-hosted models. Depending on the request payload, infrastructure will be created or bypassed (e.g., for LiteLLM-only models). +This API accepts the same model definition parameters that were accepted in the V2 model definitions within the config.yaml file with one notable difference: the `containerConfig.baseImage.path` field is +now a path relative to the `lib/serve/ecs-model` directory, instead of from its original path relative to the repository root. This means that if the path used to be `lib/serve/ecs-model/textgen/tgi`, then +it will now be `textgen/tgi` for the CreateModel API. For vLLM models, the `path` will be `vllm`, and for TEI, it will be `embedding/tei`. + #### Request Example: ``` diff --git a/VERSION b/VERSION index 4a36342f..cb2b00e4 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -3.0.0 +3.0.1 diff --git a/ecs_model_deployer/src/lib/ecsCluster.ts b/ecs_model_deployer/src/lib/ecsCluster.ts index 98395198..ecbb8b5a 100644 --- a/ecs_model_deployer/src/lib/ecsCluster.ts +++ b/ecs_model_deployer/src/lib/ecsCluster.ts @@ -109,6 +109,11 @@ export class ECSCluster extends Construct { ], }); + new CfnOutput(this, 'autoScalingGroup', { + key: 'autoScalingGroup', + value: autoScalingGroup.autoScalingGroupName, + }); + const environment = ecsConfig.environment; const volumes: Volume[] = []; const mountPoints: MountPoint[] = []; @@ -272,10 +277,11 @@ export class ECSCluster extends Construct { const loadBalancer = new ApplicationLoadBalancer(this, createCdkId([ecsConfig.identifier, 'ALB']), { deletionProtection: config.removalPolicy !== RemovalPolicy.DESTROY, internetFacing: false, - loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2), + loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2).toLowerCase(), dropInvalidHeaderFields: true, securityGroup, vpc, + idleTimeout: Duration.seconds(600) }); // Add listener @@ -294,7 +300,7 @@ export class ECSCluster extends Construct { // Add targets const loadBalancerHealthCheckConfig = ecsConfig.loadBalancerConfig.healthCheckConfig; const targetGroup = listener.addTargets(createCdkId([ecsConfig.identifier, 'TgtGrp']), { - targetGroupName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2), + targetGroupName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2).toLowerCase(), healthCheck: { path: loadBalancerHealthCheckConfig.path, interval: Duration.seconds(loadBalancerHealthCheckConfig.interval), diff --git a/lambda/authorizer/lambda_functions.py b/lambda/authorizer/lambda_functions.py index 541b5763..146d79b4 100644 --- a/lambda/authorizer/lambda_functions.py +++ b/lambda/authorizer/lambda_functions.py @@ -16,15 +16,20 @@ import logging import os import ssl +from functools import cache from typing import Any, Dict +import boto3 import create_env_variables # noqa: F401 import jwt import requests -from utilities.common_functions import authorization_wrapper, get_id_token +from botocore.exceptions import ClientError +from utilities.common_functions import authorization_wrapper, get_id_token, retry_config logger = logging.getLogger(__name__) +secrets_manager = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config) + @authorization_wrapper def lambda_handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: ignore [no-untyped-def] @@ -48,6 +53,11 @@ def lambda_handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: i deny_policy = generate_policy(effect="Deny", resource=event["methodArn"]) + if id_token in get_management_tokens(): + allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username="lisa-management-token") + logger.debug(f"Generated policy: {allow_policy}") + return allow_policy + if jwt_data := id_token_is_valid(id_token=id_token, client_id=client_id, authority=authority): is_admin_user = is_admin(jwt_data, admin_group, jwt_groups_property) allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username=jwt_data["sub"]) @@ -134,3 +144,22 @@ def is_admin(jwt_data: dict[str, Any], admin_group: str, jwt_groups_property: st else: return False return admin_group in current_node + + +@cache +def get_management_tokens() -> list[str]: + """Return secret management tokens if they exist.""" + secret_tokens: list[str] = [] + secret_id = os.environ.get("MANAGEMENT_KEY_NAME") + + try: + secret_tokens.append( + secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSCURRENT")["SecretString"] + ) + secret_tokens.append( + secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSPREVIOUS")["SecretString"] + ) + except ClientError as e: + logger.warn(f"Unable to fetch {secret_id}. {e.response['Error']['Code']}: {e.response['Error']['Message']}") + + return secret_tokens diff --git a/lambda/models/domain_objects.py b/lambda/models/domain_objects.py index 8d567dc9..1261a1d4 100644 --- a/lambda/models/domain_objects.py +++ b/lambda/models/domain_objects.py @@ -18,7 +18,7 @@ from typing import Annotated, List, Optional, Union from pydantic import BaseModel -from pydantic.functional_validators import AfterValidator +from pydantic.functional_validators import AfterValidator, field_validator from utilities.validators import validate_instance_type @@ -87,7 +87,7 @@ class LoadBalancerConfig(BaseModel): class AutoScalingConfig(BaseModel): - """Autoscaling configuration.""" + """Autoscaling configuration upon model creation.""" minCapacity: int maxCapacity: int @@ -96,6 +96,14 @@ class AutoScalingConfig(BaseModel): metricConfig: MetricConfig +class AutoScalingInstanceConfig(BaseModel): + """Autoscaling instance count configuration upon model update.""" + + minCapacity: Optional[int] = None + maxCapacity: Optional[int] = None + desiredCapacity: Optional[int] = None + + class ContainerHealthCheckConfig(BaseModel): """Health check configuration for a container.""" @@ -180,16 +188,24 @@ class GetModelResponse(ApiResponseBase): class UpdateModelRequest(BaseModel): """Request object when updating a model.""" - autoScalingConfig: Optional[AutoScalingConfig] = None - containerConfig: Optional[ContainerConfig] = None - inferenceContainer: Optional[InferenceContainer] = None - instanceType: Optional[Annotated[str, AfterValidator(validate_instance_type)]] = None - loadBalancerConfig: Optional[LoadBalancerConfig] = None - modelId: str - modelName: Optional[str] = None + autoScalingInstanceConfig: Optional[AutoScalingInstanceConfig] = None + enabled: Optional[bool] = None modelType: Optional[ModelType] = None streaming: Optional[bool] = None + @field_validator("autoScalingInstanceConfig") # type: ignore + @classmethod + def validate_autoscaling_instance_config(cls, config: AutoScalingInstanceConfig) -> AutoScalingInstanceConfig: + """Validate that the AutoScaling instance config has at least one positive value.""" + if not config: + raise ValueError("The autoScalingInstanceConfig must not be null if defined in request payload.") + config_fields = (config.minCapacity, config.maxCapacity, config.desiredCapacity) + if all((field is None for field in config_fields)): + raise ValueError("At least one option of autoScalingInstanceConfig must be defined.") + if any((isinstance(field, int) and field < 0 for field in config_fields)): + raise ValueError("All autoScalingInstanceConfig fields must be >= 0.") + return config + class UpdateModelResponse(ApiResponseBase): """Response object when updating a model.""" diff --git a/lambda/models/exception/__init__.py b/lambda/models/exception/__init__.py index 01a3b431..418d54fa 100644 --- a/lambda/models/exception/__init__.py +++ b/lambda/models/exception/__init__.py @@ -15,6 +15,9 @@ """Exception definitions for model management APIs.""" +# LiteLLM errors + + class ModelNotFoundError(LookupError): """Error to raise when a specified model cannot be found in the database.""" @@ -25,3 +28,24 @@ class ModelAlreadyExistsError(LookupError): """Error to raise when a specified model already exists in the database.""" pass + + +# State machine exceptions + + +class MaxPollsExceededException(Exception): + """Exception to indicate that polling for a state timed out.""" + + pass + + +class StackFailedToCreateException(Exception): + """Exception to indicate that the CDK for creating a model stack failed.""" + + pass + + +class UnexpectedCloudFormationStateException(Exception): + """Exception to indicate that the CloudFormation stack has transitioned to a non-healthy state.""" + + pass diff --git a/lambda/models/lambda_functions.py b/lambda/models/lambda_functions.py index dd621531..41003dfe 100644 --- a/lambda/models/lambda_functions.py +++ b/lambda/models/lambda_functions.py @@ -123,7 +123,7 @@ def _create_dummy_model(model_name: str, model_type: ModelType, model_status: Mo ), sharedMemorySize=2048, healthCheckConfig=ContainerHealthCheckConfig( - command=["CMD-SHELL", "exit 0"], Interval=10, StartPeriod=30, Timeout=5, Retries=5 + command=["CMD-SHELL", "exit 0"], interval=10, startPeriod=30, timeout=5, retries=5 ), environment={ "MAX_CONCURRENT_REQUESTS": "128", @@ -177,7 +177,7 @@ async def update_model( ) -> UpdateModelResponse: """Endpoint to update a model.""" # TODO add service to update model - model = _create_dummy_model("model_name", ModelType.TEXTGEN, ModelStatus.UPDATING) + model = _create_dummy_model(model_id, ModelType.TEXTGEN, ModelStatus.UPDATING) return UpdateModelResponse(model=model) diff --git a/lambda/models/state_machine/create_model.py b/lambda/models/state_machine/create_model.py index b28c8ed9..3cbaa47a 100644 --- a/lambda/models/state_machine/create_model.py +++ b/lambda/models/state_machine/create_model.py @@ -24,6 +24,11 @@ from botocore.config import Config from models.clients.litellm_client import LiteLLMClient from models.domain_objects import CreateModelRequest, ModelStatus +from models.exception import ( + MaxPollsExceededException, + StackFailedToCreateException, + UnexpectedCloudFormationStateException, +) from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config lambdaConfig = Config(connect_timeout=60, read_timeout=600, retries={"max_attempts": 1}) @@ -113,8 +118,13 @@ def handle_poll_docker_image_available(event: Dict[str, Any], context: Any) -> D output_dict["image_info"]["remaining_polls"] -= 1 if output_dict["image_info"]["remaining_polls"] <= 0: ec2Client.terminate_instances(InstanceIds=[event["image_info"]["instance_id"]]) - raise Exception( - "Maximum number of ECR poll attempts reached. Something went wrong building the docker image." + raise MaxPollsExceededException( + json.dumps( + { + "error": "Max number of ECR polls reached. Docker Image was not replicated successfully.", + "event": event, + } + ) ) return output_dict @@ -153,7 +163,17 @@ def camelize_object(o): # type: ignore[no-untyped-def] payload = response["Payload"].read() payload = json.loads(payload) - stack_name = payload.get("stackName") + stack_name = payload.get("stackName", None) + + if not stack_name: + raise StackFailedToCreateException( + json.dumps( + { + "error": "Failed to create Model CloudFormation Stack. Please validate model parameters are valid.", + "event": event, + } + ) + ) response = cfnClient.describe_stacks(StackName=stack_name) stack_arn = response["Stacks"][0]["StackId"] @@ -192,10 +212,24 @@ def handle_poll_create_stack(event: Dict[str, Any], context: Any) -> Dict[str, A output_dict["continue_polling_stack"] = True output_dict["remaining_polls_stack"] -= 1 if output_dict["remaining_polls_stack"] <= 0: - raise Exception("Maximum number of CloudFormation polls reached") + raise MaxPollsExceededException( + json.dumps( + { + "error": "Max number of CloudFormation polls reached.", + "event": event, + } + ) + ) return output_dict else: - raise Exception(f"Stack in unexpected state: {stackStatus}") + raise UnexpectedCloudFormationStateException( + json.dumps( + { + "error": f"Stack entered unexpected state: {stackStatus}", + "event": event, + } + ) + ) def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str, Any]: @@ -234,3 +268,38 @@ def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str ) return output_dict + + +def handle_failure(event: Dict[str, Any], context: Any) -> Dict[str, Any]: + """ + Handle failures from state machine. + + Possible causes of failures would be: + 1. Docker Image failed to replicate into ECR in expected amount of time + 2. CloudFormation Stack creation failed from parameter validation. + 3. CloudFormation Stack creation failed from taking too long to stand up. + + Expectation of this function is to terminate the EC2 instance if it is still running, and to set the model status + to Failed. Cleaning up the CloudFormation stack, if it still exists, will happen in the DeleteModel API. + """ + error_dict = json.loads( # error from SFN is json payload on top of json payload we add to the exception + json.loads(event["Cause"])["errorMessage"] + ) + error_reason = error_dict["error"] + original_event = error_dict["event"] + + # terminate EC2 instance if we have one recorded + if "image_info" in original_event and "instance_id" in original_event["image_info"]: + ec2Client.terminate_instances(InstanceIds=[original_event["image_info"]["instance_id"]]) + + # set model as Failed in DDB, so it shows as such in the UI. adds error reason as well. + model_table.update_item( + Key={"model_id": original_event["modelId"]}, + UpdateExpression="SET model_status = :ms, last_modified_date = :lm, failure_reason = :fr", + ExpressionAttributeValues={ + ":ms": ModelStatus.FAILED, + ":lm": int(datetime.utcnow().timestamp()), + ":fr": error_reason, + }, + ) + return event diff --git a/lambda/models/state_machine/delete_model.py b/lambda/models/state_machine/delete_model.py index 506fdd03..8ba002e8 100644 --- a/lambda/models/state_machine/delete_model.py +++ b/lambda/models/state_machine/delete_model.py @@ -62,7 +62,7 @@ def handle_set_model_to_deleting(event: Dict[str, Any], context: Any) -> Dict[st if not item: raise RuntimeError(f"Requested model '{model_id}' was not found in DynamoDB table.") output_dict[CFN_STACK_ARN] = item.get(CFN_STACK_ARN, None) - output_dict[LITELLM_ID] = item[LITELLM_ID] + output_dict[LITELLM_ID] = item.get(LITELLM_ID, None) ddb_table.update_item( Key=model_key, @@ -77,7 +77,8 @@ def handle_set_model_to_deleting(event: Dict[str, Any], context: Any) -> Dict[st def handle_delete_from_litellm(event: Dict[str, Any], context: Any) -> Dict[str, Any]: """Delete model reference from LiteLLM.""" - litellm_client.delete_model(identifier=event[LITELLM_ID]) + if event[LITELLM_ID]: # if non-null ID + litellm_client.delete_model(identifier=event[LITELLM_ID]) return event diff --git a/lambda/utilities/common_functions.py b/lambda/utilities/common_functions.py index 34111aa8..0acffc0c 100644 --- a/lambda/utilities/common_functions.py +++ b/lambda/utilities/common_functions.py @@ -279,14 +279,14 @@ def get_id_token(event: dict) -> str: auth_header = None if "authorization" in event["headers"]: - auth_header = event["headers"]["authorization"].split(" ") + auth_header = event["headers"]["authorization"] elif "Authorization" in event["headers"]: - auth_header = event["headers"]["Authorization"].split(" ") + auth_header = event["headers"]["Authorization"] else: raise ValueError("Missing authorization token.") - token = auth_header[1] - return str(token) + # remove bearer token prefix if present + return str(auth_header).removeprefix("Bearer ").removeprefix("bearer ").strip() @cache diff --git a/lib/api-base/authorizer.ts b/lib/api-base/authorizer.ts index a4eb9a05..358b8b6c 100644 --- a/lib/api-base/authorizer.ts +++ b/lib/api-base/authorizer.ts @@ -23,6 +23,8 @@ import { StringParameter } from 'aws-cdk-lib/aws-ssm'; import { Construct } from 'constructs'; import { BaseProps } from '../schema'; +import { createCdkId } from '../core/utils'; +import { Secret } from 'aws-cdk-lib/aws-secretsmanager'; /** * Properties for RestApiGateway Construct. @@ -67,6 +69,8 @@ export class CustomAuthorizer extends Construct { StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/layerVersion/authorizer`), ); + const managementKeySecretNameStringParameter = StringParameter.fromStringParameterName(this, createCdkId([id, 'managementKeyStringParameter']), `${config.deploymentPrefix}/managementKeySecretName`); + // Create Lambda authorizer const authorizerLambda = new Function(this, 'AuthorizerLambda', { runtime: config.lambdaConfig.pythonRuntime, @@ -82,12 +86,16 @@ export class CustomAuthorizer extends Construct { AUTHORITY: config.authConfig!.authority, ADMIN_GROUP: config.authConfig!.adminGroup, JWT_GROUPS_PROP: config.authConfig!.jwtGroupsProperty, + MANAGEMENT_KEY_NAME: managementKeySecretNameStringParameter.stringValue }, role: role, vpc: vpc, securityGroups: securityGroups, }); + const managementKeySecret = Secret.fromSecretNameV2(this, createCdkId([id, 'managementKey']), managementKeySecretNameStringParameter.stringValue); + managementKeySecret.grantRead(authorizerLambda); + // Update this.authorizer = new RequestAuthorizer(this, 'APIGWAuthorizer', { handler: authorizerLambda, diff --git a/lib/api-base/ecsCluster.ts b/lib/api-base/ecsCluster.ts index aa248552..105b9edb 100644 --- a/lib/api-base/ecsCluster.ts +++ b/lib/api-base/ecsCluster.ts @@ -262,10 +262,11 @@ export class ECSCluster extends Construct { const loadBalancer = new ApplicationLoadBalancer(this, createCdkId([ecsConfig.identifier, 'ALB']), { deletionProtection: config.removalPolicy !== RemovalPolicy.DESTROY, internetFacing: ecsConfig.internetFacing, - loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2), + loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2).toLowerCase(), dropInvalidHeaderFields: true, securityGroup, vpc, + idleTimeout: Duration.seconds(600) }); // Add listener @@ -286,7 +287,7 @@ export class ECSCluster extends Construct { // Add targets const loadBalancerHealthCheckConfig = ecsConfig.loadBalancerConfig.healthCheckConfig; const targetGroup = listener.addTargets(createCdkId([ecsConfig.identifier, 'TgtGrp']), { - targetGroupName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2), + targetGroupName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2).toLowerCase(), healthCheck: { path: loadBalancerHealthCheckConfig.path, interval: Duration.seconds(loadBalancerHealthCheckConfig.interval), diff --git a/lib/models/index.ts b/lib/models/index.ts index 01ab1e60..d6aced0f 100644 --- a/lib/models/index.ts +++ b/lib/models/index.ts @@ -24,7 +24,6 @@ import { Vpc } from '../networking/vpc'; import { ModelsApi } from './model-api'; import { BaseProps } from '../schema'; import { StringParameter } from 'aws-cdk-lib/aws-ssm'; -import { Secret } from 'aws-cdk-lib/aws-secretsmanager'; type LisaModelsApiStackProps = BaseProps & StackProps & { @@ -34,7 +33,6 @@ type LisaModelsApiStackProps = BaseProps & rootResourceId: string; securityGroups?: ISecurityGroup[]; vpc: Vpc; - managementKeySecret: Secret; }; /** @@ -49,7 +47,7 @@ export class LisaModelsApiStack extends Stack { constructor (scope: Construct, id: string, props: LisaModelsApiStackProps) { super(scope, id, props); - const { authorizer, lisaServeEndpointUrlPs, config, restApiId, rootResourceId, securityGroups, vpc, managementKeySecret } = props; + const { authorizer, lisaServeEndpointUrlPs, config, restApiId, rootResourceId, securityGroups, vpc } = props; // Add REST API Lambdas to APIGW new ModelsApi(this, 'ModelsApi', { @@ -60,7 +58,6 @@ export class LisaModelsApiStack extends Stack { rootResourceId, securityGroups, vpc, - managementKeySecret, }); } } diff --git a/lib/models/model-api.ts b/lib/models/model-api.ts index 415e1261..b9935c34 100644 --- a/lib/models/model-api.ts +++ b/lib/models/model-api.ts @@ -62,7 +62,6 @@ type ModelsApiProps = BaseProps & { rootResourceId: string; securityGroups?: ISecurityGroup[]; vpc: Vpc; - managementKeySecret: Secret; }; /** @@ -72,7 +71,7 @@ export class ModelsApi extends Construct { constructor (scope: Construct, id: string, props: ModelsApiProps) { super(scope, id); - const { authorizer, config, lambdaExecutionRole, lisaServeEndpointUrlPs, restApiId, rootResourceId, securityGroups, vpc, managementKeySecret } = props; + const { authorizer, config, lambdaExecutionRole, lisaServeEndpointUrlPs, restApiId, rootResourceId, securityGroups, vpc } = props; // Get common layer based on arn from SSM due to issues with cross stack references const commonLambdaLayer = LayerVersion.fromLayerVersionArn( @@ -124,6 +123,8 @@ export class ModelsApi extends Construct { mountS3DebUrl: config.mountS3DebUrl! }); + const managementKeyName = StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/managementKeySecretName`); + const stateMachinesLambdaRole = new Role(this, 'ModelsSfnLambdaRole', { assumedBy: new ServicePrincipal('lambda.amazonaws.com'), managedPolicies: [ @@ -197,15 +198,13 @@ export class ModelsApi extends Construct { actions: [ 'secretsmanager:GetSecretValue' ], - resources: [managementKeySecret.secretArn], + resources: [`${Secret.fromSecretNameV2(this, 'ManagementKeySecret', managementKeyName).secretArn}-??????`], // question marks required to resolve the ARN correctly }), ] }), } }); - const managementKeyName = StringParameter.valueForStringParameter(this, `${config.deploymentPrefix}/managementKeySecretName`); - const createModelStateMachine = new CreateModelStateMachine(this, 'CreateModelWorkflow', { config: config, modelTable: modelTable, diff --git a/lib/models/state-machine/create-model.ts b/lib/models/state-machine/create-model.ts index 1bf967b0..946553f4 100644 --- a/lib/models/state-machine/create-model.ts +++ b/lib/models/state-machine/create-model.ts @@ -18,6 +18,7 @@ import { Choice, Condition, DefinitionBody, + Fail, StateMachine, Succeed, Wait, @@ -117,7 +118,23 @@ export class CreateModelStateMachine extends Construct { layers: lambdaLayers, environment: environment, }), - outputPath: OUTPUT_PATH + outputPath: OUTPUT_PATH, + }); + + const handleFailureState = new LambdaInvoke(this, 'HandleFailure', { + lambdaFunction: new Function(this, 'HandleFailureFunc', { + runtime: config.lambdaConfig.pythonRuntime, + handler: 'models.state_machine.create_model.handle_failure', + code: Code.fromAsset(config.lambdaSourcePath), + timeout: LAMBDA_TIMEOUT, + memorySize: LAMBDA_MEMORY, + role: role, + vpc: vpc, + securityGroups: securityGroups, + layers: lambdaLayers, + environment: environment, + }), + outputPath: OUTPUT_PATH, }); const pollDockerImageChoice = new Choice(this, 'PollDockerImageChoice'); @@ -181,6 +198,7 @@ export class CreateModelStateMachine extends Construct { }); const successState = new Succeed(this, 'CreateSuccess'); + const failState = new Fail(this, 'CreateFailed'); // State Machine definition setModelToCreating.next(createModelInfraChoice); @@ -188,20 +206,36 @@ export class CreateModelStateMachine extends Construct { .when(Condition.booleanEquals('$.create_infra', true), startCopyDockerImage) .otherwise(addModelToLitellm); + // poll ECR image copy status loop startCopyDockerImage.next(pollDockerImageAvailable); pollDockerImageAvailable.next(pollDockerImageChoice); + pollDockerImageAvailable.addCatch(handleFailureState, { // fail if exception thrown from code + errors: ['MaxPollsExceededException'], + }); pollDockerImageChoice .when(Condition.booleanEquals('$.continue_polling_docker', true), waitBeforePollingDockerImage) .otherwise(startCreateStack); waitBeforePollingDockerImage.next(pollDockerImageAvailable); + // poll CloudFormation stack status loop startCreateStack.next(pollCreateStack); + startCreateStack.addCatch(handleFailureState, { // fail if CDK failed to create model stack + errors: ['StackFailedToCreateException'] + }); pollCreateStack.next(pollCreateStackChoice); + pollCreateStack.addCatch(handleFailureState, { // fail if model failed or failed to create in time + errors: [ + 'MaxPollsExceededException', + 'UnexpectedCloudFormationStateException', + ], + }); pollCreateStackChoice .when(Condition.booleanEquals('$.continue_polling_stack', true), waitBeforePollingCreateStack) .otherwise(addModelToLitellm); waitBeforePollingCreateStack.next(pollCreateStack); + // terminal states + handleFailureState.next(failState); addModelToLitellm.next(successState); const stateMachine = new StateMachine(this, 'CreateModelSM', { diff --git a/lib/serve/index.ts b/lib/serve/index.ts index 11f85e4c..8116ad67 100644 --- a/lib/serve/index.ts +++ b/lib/serve/index.ts @@ -46,7 +46,6 @@ export class LisaServeApplicationStack extends Stack { public readonly restApi: FastApiContainer; public readonly modelsPs: StringParameter; public readonly endpointUrl: StringParameter; - public readonly managementKeySecret: Secret; /** * @param {Construct} scope - The parent or owner of the construct. @@ -85,7 +84,7 @@ export class LisaServeApplicationStack extends Stack { vpc: vpc.vpc, }); - this.managementKeySecret = new Secret(this, createCdkId([id, 'managementKeySecret']), { + const managementKeySecret = new Secret(this, createCdkId([id, 'managementKeySecret']), { secretName: `lisa_management_key_secret-${Date.now()}`, // pragma: allowlist secret` description: 'This is a secret created with AWS CDK', generateSecretString: { @@ -132,7 +131,7 @@ export class LisaServeApplicationStack extends Stack { const managementKeySecretNameStringParameter = new StringParameter(this, createCdkId(['ManagementKeySecretName']), { parameterName: `${config.deploymentPrefix}/managementKeySecretName`, - stringValue: this.managementKeySecret.secretName, + stringValue: managementKeySecret.secretName, }); restApi.container.addEnvironment('MANAGEMENT_KEY_NAME', managementKeySecretNameStringParameter.stringValue); diff --git a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py index 73e4fe8b..b98aaacf 100644 --- a/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py +++ b/lib/serve/rest-api/src/api/endpoints/v2/litellm_passthrough.py @@ -140,7 +140,7 @@ async def litellm_passthrough(request: Request, api_path: str) -> Response: def refresh_management_tokens() -> list[str]: - """Return DDB entry for token if it exists.""" + """Return secret management tokens if they exist.""" secret_tokens = [] try: diff --git a/lib/serve/rest-api/src/auth.py b/lib/serve/rest-api/src/auth.py index f507b62a..bdb32e83 100644 --- a/lib/serve/rest-api/src/auth.py +++ b/lib/serve/rest-api/src/auth.py @@ -198,7 +198,7 @@ def __init__(self) -> None: self._last_run = 0 def _refreshTokens(self) -> None: - """Return DDB entry for token if it exists.""" + """Refresh secret management tokens.""" current_time = int(time()) if current_time - (self._last_run or 0) > 3600: secret_tokens = [] diff --git a/lib/serve/rest-api/src/entrypoint.sh b/lib/serve/rest-api/src/entrypoint.sh index a898ccfa..8e8c0bb6 100644 --- a/lib/serve/rest-api/src/entrypoint.sh +++ b/lib/serve/rest-api/src/entrypoint.sh @@ -20,4 +20,4 @@ litellm -c litellm_config.yaml & echo "Starting Gunicorn with $THREADS workers..." # Start Gunicorn with Uvicorn workers. -exec gunicorn -k uvicorn.workers.UvicornWorker -w "$THREADS" -b "$HOST:$PORT" "src.main:app" +exec gunicorn -k uvicorn.workers.UvicornWorker -t 600 -w "$THREADS" -b "$HOST:$PORT" "src.main:app" diff --git a/lib/serve/rest-api/src/utils/generate_litellm_config.py b/lib/serve/rest-api/src/utils/generate_litellm_config.py index a417c172..19b346f3 100644 --- a/lib/serve/rest-api/src/utils/generate_litellm_config.py +++ b/lib/serve/rest-api/src/utils/generate_litellm_config.py @@ -52,6 +52,7 @@ def generate_config(filepath: str) -> None: config_models = config_contents["model_list"] or [] # ensure config_models is a list and not None config_models.extend(litellm_model_params) config_contents["model_list"] = config_models + config_contents["litellm_settings"] = {"request_timeout": 600} # Get database connection info db_param_response = ssm_client.get_parameter(Name=os.environ["LITELLM_DB_INFO_PS_NAME"]) diff --git a/lib/stages.ts b/lib/stages.ts index 46bc53e6..f8ab0b35 100644 --- a/lib/stages.ts +++ b/lib/stages.ts @@ -162,7 +162,6 @@ export class LisaServeApplicationStage extends Stage { rootResourceId: apiBaseStack.rootResourceId, stackName: createCdkId([config.deploymentName, config.appName, 'models', config.deploymentStage]), vpc: networkingStack.vpc, - managementKeySecret: serveStack.managementKeySecret, }); modelsApiDeploymentStack.addDependency(serveStack); apiDeploymentStack.addDependency(modelsApiDeploymentStack); diff --git a/lib/user-interface/react/src/components/model-management/create-model/AutoScalingConfig.tsx b/lib/user-interface/react/src/components/model-management/create-model/AutoScalingConfig.tsx index a37e5853..3d2b99b4 100644 --- a/lib/user-interface/react/src/components/model-management/create-model/AutoScalingConfig.tsx +++ b/lib/user-interface/react/src/components/model-management/create-model/AutoScalingConfig.tsx @@ -20,7 +20,7 @@ import FormField from '@cloudscape-design/components/form-field'; import Input from '@cloudscape-design/components/input'; import { IAutoScalingConfig } from '../../../shared/model/model-management.model'; -import { Header, SpaceBetween } from '@cloudscape-design/components'; +import { Grid, Header, SpaceBetween } from '@cloudscape-design/components'; import Container from '@cloudscape-design/components/container'; export function AutoScalingConfig (props: FormProps) : ReactElement { @@ -32,24 +32,36 @@ export function AutoScalingConfig (props: FormProps) : React } > - props.touchFields(['autoScalingConfig.minCapacity'])} onChange={({ detail }) => { - props.setFields({ 'autoScalingConfig.minCapacity': Number(detail.value) }); - }}/> + + props.touchFields(['autoScalingConfig.minCapacity'])} onChange={({ detail }) => { + props.setFields({ 'autoScalingConfig.minCapacity': Number(detail.value) }); + }}/> + instances + - props.touchFields(['autoScalingConfig.maxCapacity'])} onChange={({ detail }) => { - props.setFields({ 'autoScalingConfig.maxCapacity': Number(detail.value) }); - }}/> + + props.touchFields(['autoScalingConfig.maxCapacity'])} onChange={({ detail }) => { + props.setFields({ 'autoScalingConfig.maxCapacity': Number(detail.value) }); + }}/> + instances + - props.touchFields(['autoScalingConfig.cooldown'])} onChange={({ detail }) => { - props.setFields({ 'autoScalingConfig.Cooldown': Number(detail.value) }); - }}/> + + props.touchFields(['autoScalingConfig.cooldown'])} onChange={({ detail }) => { + props.setFields({ 'autoScalingConfig.Cooldown': Number(detail.value) }); + }}/> + seconds + - props.touchFields(['autoScalingConfig.defaultInstanceWarmup'])} onChange={({ detail }) => { - props.setFields({ 'autoScalingConfig.defaultInstanceWarmup': Number(detail.value) }); - }}/> + + props.touchFields(['autoScalingConfig.defaultInstanceWarmup'])} onChange={({ detail }) => { + props.setFields({ 'autoScalingConfig.defaultInstanceWarmup': Number(detail.value) }); + }}/> + seconds + ) : React }}/> - props.touchFields(['autoScalingConfig.metricConfig.duration'])} onChange={({ detail }) => { - props.setFields({ 'autoScalingConfig.metricConfig.duration': Number(detail.value) }); - }}/> + + props.touchFields(['autoScalingConfig.metricConfig.duration'])} onChange={({ detail }) => { + props.setFields({ 'autoScalingConfig.metricConfig.duration': Number(detail.value) }); + }}/> + seconds + - props.touchFields(['autoScalingConfig.metricConfig.estimatedInstanceWarmup'])} onChange={({ detail }) => { - props.setFields({ 'autoScalingConfig.metricConfig.estimatedInstanceWarmup': Number(detail.value) }); - }}/> + + props.touchFields(['autoScalingConfig.metricConfig.estimatedInstanceWarmup'])} onChange={({ detail }) => { + props.setFields({ 'autoScalingConfig.metricConfig.estimatedInstanceWarmup': Number(detail.value) }); + }}/> + seconds + diff --git a/lib/user-interface/react/src/components/model-management/create-model/ContainerConfig.tsx b/lib/user-interface/react/src/components/model-management/create-model/ContainerConfig.tsx index 6d33335b..454ed0c5 100644 --- a/lib/user-interface/react/src/components/model-management/create-model/ContainerConfig.tsx +++ b/lib/user-interface/react/src/components/model-management/create-model/ContainerConfig.tsx @@ -33,9 +33,12 @@ export function ContainerConfig (props: FormProps) : ReactElem > - props.touchFields(['containerConfig.sharedMemorySize'])} onChange={({ detail }) => { - props.setFields({ 'containerConfig.sharedMemorySize': Number(detail.value) }); - }}/> + + props.touchFields(['containerConfig.sharedMemorySize'])} onChange={({ detail }) => { + props.setFields({ 'containerConfig.sharedMemorySize': Number(detail.value) }); + }}/> + MiB + props.touchFields(['containerConfig.baseImage.baseImage'])} onChange={({ detail }) => { @@ -92,19 +95,28 @@ export function ContainerConfig (props: FormProps) : ReactElem - props.touchFields(['containerConfig.healthCheckConfig.interval'])} onChange={({ detail }) => { - props.setFields({ 'containerConfig.healthCheckConfig.interval': Number(detail.value) }); - }}/> + + props.touchFields(['containerConfig.healthCheckConfig.interval'])} onChange={({ detail }) => { + props.setFields({ 'containerConfig.healthCheckConfig.interval': Number(detail.value) }); + }}/> + seconds + - props.touchFields(['containerConfig.healthCheckConfig.startPeriod'])} onChange={({ detail }) => { - props.setFields({ 'containerConfig.healthCheckConfig.startPeriod': Number(detail.value) }); - }}/> + + props.touchFields(['containerConfig.healthCheckConfig.startPeriod'])} onChange={({ detail }) => { + props.setFields({ 'containerConfig.healthCheckConfig.startPeriod': Number(detail.value) }); + }}/> + seconds + - props.touchFields(['containerConfig.healthCheckConfig.timeout'])} onChange={({ detail }) => { - props.setFields({ 'containerConfig.healthCheckConfig.timeout': Number(detail.value) }); - }}/> + + props.touchFields(['containerConfig.healthCheckConfig.timeout'])} onChange={({ detail }) => { + props.setFields({ 'containerConfig.healthCheckConfig.timeout': Number(detail.value) }); + }}/> + seconds + props.touchFields(['containerConfig.healthCheckConfig.retries'])} onChange={({ detail }) => { diff --git a/lib/user-interface/react/src/components/model-management/create-model/LoadBalancerConfig.tsx b/lib/user-interface/react/src/components/model-management/create-model/LoadBalancerConfig.tsx index a0e762c6..920ccf6b 100644 --- a/lib/user-interface/react/src/components/model-management/create-model/LoadBalancerConfig.tsx +++ b/lib/user-interface/react/src/components/model-management/create-model/LoadBalancerConfig.tsx @@ -19,7 +19,7 @@ import { FormProps} from '../../../shared/form/form-props'; import FormField from '@cloudscape-design/components/form-field'; import Input from '@cloudscape-design/components/input'; import { ILoadBalancerConfig } from '../../../shared/model/model-management.model'; -import { Header } from '@cloudscape-design/components'; +import { Grid, Header } from '@cloudscape-design/components'; import Container from '@cloudscape-design/components/container'; export function LoadBalancerConfig (props: FormProps) : ReactElement { @@ -36,14 +36,20 @@ export function LoadBalancerConfig (props: FormProps) : Rea }}/> - props.touchFields(['loadBalancerConfig.healthCheckConfig.interval'])} onChange={({ detail }) => { - props.setFields({ 'loadBalancerConfig.healthCheckConfig.interval': Number(detail.value) }); - }}/> + + props.touchFields(['loadBalancerConfig.healthCheckConfig.interval'])} onChange={({ detail }) => { + props.setFields({ 'loadBalancerConfig.healthCheckConfig.interval': Number(detail.value) }); + }}/> + seconds + - props.touchFields(['loadBalancerConfig.healthCheckConfig.timeout'])} onChange={({ detail }) => { - props.setFields({ 'loadBalancerConfig.healthCheckConfig.timeout': Number(detail.value) }); - }}/> + + props.touchFields(['loadBalancerConfig.healthCheckConfig.timeout'])} onChange={({ detail }) => { + props.setFields({ 'loadBalancerConfig.healthCheckConfig.timeout': Number(detail.value) }); + }}/> + seconds + props.touchFields(['loadBalancerConfig.healthCheckConfig.healthyThresholdCount'])} onChange={({ detail }) => { diff --git a/lisa-sdk/pyproject.toml b/lisa-sdk/pyproject.toml index f3b8fac6..35527a08 100644 --- a/lisa-sdk/pyproject.toml +++ b/lisa-sdk/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "lisapy" -version = "3.0.0" +version = "3.0.1" description = "A simple SDK to help you interact with LISA. LISA is an LLM hosting solution for AWS dedicated clouds or ADCs." authors = ["Steve Goley "] readme = "README.md" diff --git a/package-lock.json b/package-lock.json index d45888de..ca82589e 100644 --- a/package-lock.json +++ b/package-lock.json @@ -1,12 +1,12 @@ { "name": "lisa", - "version": "3.0.0", + "version": "3.0.1", "lockfileVersion": 3, "requires": true, "packages": { "": { "name": "lisa", - "version": "3.0.0", + "version": "3.0.1", "license": "Apache-2.0", "dependencies": { "aws-cdk-lib": "2.125.0", diff --git a/package.json b/package.json index 90a8ae28..14157efa 100644 --- a/package.json +++ b/package.json @@ -1,6 +1,6 @@ { "name": "lisa", - "version": "3.0.0", + "version": "3.0.1", "bin": { "lisa": "bin/lisa.js" }, diff --git a/test/cdk/stacks/core-api-base.test.ts b/test/cdk/stacks/core-api-base.test.ts index 438d66df..e14dddde 100644 --- a/test/cdk/stacks/core-api-base.test.ts +++ b/test/cdk/stacks/core-api-base.test.ts @@ -112,6 +112,6 @@ describe.each(regions)('API Core Nag Pack Tests | Region Test: %s', (awsRegion) test('NIST800.53r5 CDK NAG Errors', () => { const errors = Annotations.fromStack(stack).findError('*', Match.stringLikeRegexp('NIST.*')); - expect(errors.length).toBe(6); + expect(errors.length).toBe(7); }); });