Release v3.0.1 into Main #109

Merged: 28 commits, Sep 20, 2024

Commits
4631fdb
Fix: remove import/export dependency between stacks on management key
petermuller Sep 13, 2024
807dfd6
Fix GetSecretValue scoped permissions
petermuller Sep 13, 2024
e2ae9ef
Remove import/export dependency between stacks on management key
estohlmann Sep 16, 2024
4cbe453
Fix: add error catching to Create API, allow model Deletes when LiteL…
petermuller Sep 18, 2024
ca19446
awslabs/fix/failed-creates
estohlmann Sep 18, 2024
be1c835
Update API model and API stub for UpdateModel
petermuller Sep 18, 2024
9d5a1c0
Update API model and API stub for UpdateModel
estohlmann Sep 18, 2024
3688a27
Add enabled field to UpdateModel for start and stop operations
petermuller Sep 18, 2024
45ac5ef
Add enabled field to UpdateModel for start and stop operations
estohlmann Sep 18, 2024
5cc2830
Added Output for ASG name to the model CDK infra
dustins Sep 19, 2024
24f0583
added units to model create form
dustins Sep 19, 2024
017d129
Added units for create Model form fields
estohlmann Sep 19, 2024
657bf76
Merge branch 'develop' into asg-output-20240919
estohlmann Sep 19, 2024
d5be99a
Added Output for ASG name to the model CDK infra
estohlmann Sep 19, 2024
0a89b72
Updating default timeouts to 600 seconds
estohlmann Sep 19, 2024
3819e7d
Updating default timeouts to 600 seconds
estohlmann Sep 19, 2024
c388965
Updating default timeouts to 600 seconds
estohlmann Sep 19, 2024
12a9335
Updating default timeouts to 600 seconds
estohlmann Sep 19, 2024
5cb23fb
lowercasing alb and target group names
estohlmann Sep 19, 2024
3374e38
Merge remote-tracking branch 'origin/fix/adding-increased-timeouts' i…
estohlmann Sep 19, 2024
d9d1cab
adding increased timeouts and lower casing alb/target group names
estohlmann Sep 19, 2024
fc61d05
Adds new authentication method for model management API (#99)
dustins Sep 19, 2024
e6671a3
Add functionality change between v2 and v3
petermuller Sep 19, 2024
78f825e
Add functionality change between v2 and v3
estohlmann Sep 19, 2024
0500a57
fix sed command
estohlmann Sep 20, 2024
a9c9b03
fix sed command in github actions
estohlmann Sep 20, 2024
1999a40
Updating version for release v3.0.1
petermuller Sep 20, 2024
ba4fd4f
change log updates
estohlmann Sep 20, 2024
2 changes: 1 addition & 1 deletion .github/workflows/code.release.branch.yml
@@ -26,7 +26,7 @@ jobs:
RELEASE_TAG=${{ github.event.inputs.release_tag }}
git checkout -b release/${{ github.event.inputs.release_tag }}
echo "$( jq --arg version ${RELEASE_TAG:1} '.version = $version' package.json )" > package.json
-sed -E -i "" -e "s/version = \"[0-9\.].+\"/version = \"${RELEASE_TAG:1}\"/g" lisa-sdk/pyproject.toml
+sed -E -i -e "s/version = \"[0-9\.].+\"/version = \"${RELEASE_TAG:1}\"/g" lisa-sdk/pyproject.toml
echo ${RELEASE_TAG:1} > VERSION
git commit -a -m "Updating version for release ${{ github.event.inputs.release_tag }}"
git push origin release/${{ github.event.inputs.release_tag }}
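
Note on the hunk above: the only change is the in-place flag. BSD/macOS `sed` takes the backup suffix as a separate argument after `-i` (hence `-i ""`), while GNU `sed`, which Linux runners typically provide, expects any suffix attached directly to `-i`. As a rough, portable illustration of what the substitution does, here is a minimal Python sketch. The function name `bump_version` and the sample TOML text are assumptions made for this example and are not part of the workflow.

```python
import re


def bump_version(pyproject_text: str, release_tag: str) -> str:
    """Rewrite the `version = "..."` line roughly the way the workflow's sed call does."""
    # ${RELEASE_TAG:1} in the workflow drops the first character of the tag (the leading "v").
    version = release_tag[1:]
    # Same pattern as the sed expression: a digit or dot, then the rest up to the last quote on the line.
    return re.sub(r'version = "[0-9.].+"', lambda _m: f'version = "{version}"', pyproject_text)


# Hypothetical file contents, for illustration only.
sample = 'name = "lisa-sdk"\nversion = "3.0.0"\n'
updated = bump_version(sample, "v3.0.1")
print(updated)
```

Only the `version` line is rewritten; other keys are left untouched because the pattern is anchored on the literal `version = "` prefix.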
22 changes: 22 additions & 0 deletions CHANGELOG.md
@@ -1,3 +1,25 @@
# v3.0.1
## Bug fixes
- Updated our Lambda admin validation to work for no-auth if user has the admin secret token. This applies to model management APIs.
- State machine for create model was not reporting failed status
- Delete state machine could not delete models that weren't stored in LiteLLM DB

## Enhancements
- Added units to the create model wizard to help with clarity
- Increased default timeouts to 10 minutes to enable large documentation processing without errors
- Updated ALB and Target group names to be lower cased by default to prevent networking issues

## Coming Soon
- 3.1.0 will expand support for model management. Administrators will be able to modify, activate, and deactivate models through the UI or APIs. The following release we will continue to ease deployment steps for customers through a new deployment wizard and updated documentation.

## Acknowledgements
* @petermuller
* @estohlmann
* @dustins

**Full Changelog**: https://github.com/awslabs/LISA/compare/v3.0.0...v3.0.1


# v3.0.0
## Key Features
### Model Management Administration
4 changes: 4 additions & 0 deletions README.md
@@ -659,6 +659,10 @@ curl -s -H "Authorization: Bearer <admin_token>" -X GET https://<apigw_endpoint>

LISA provides the `/models` endpoint for creating both ECS and LiteLLM-hosted models. Depending on the request payload, infrastructure will be created or bypassed (e.g., for LiteLLM-only models).

This API accepts the same model definition parameters that were accepted in the V2 model definitions within the config.yaml file with one notable difference: the `containerConfig.baseImage.path` field is
now a path relative to the `lib/serve/ecs-model` directory, instead of from its original path relative to the repository root. This means that if the path used to be `lib/serve/ecs-model/textgen/tgi`, then
it will now be `textgen/tgi` for the CreateModel API. For vLLM models, the `path` will be `vllm`, and for TEI, it will be `embedding/tei`.
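
To make the path change above concrete, here is a minimal sketch of converting a V2-style path into the form described for the CreateModel API. The helper name `to_create_model_path` and the constant `ECS_MODEL_ROOT` are hypothetical and exist only for this illustration; LISA itself may not ship such a helper.

```python
# Assumed prefix, taken from the directory named in the paragraph above.
ECS_MODEL_ROOT = "lib/serve/ecs-model/"


def to_create_model_path(v2_path: str) -> str:
    """Convert a V2 config.yaml baseImage path to a path relative to lib/serve/ecs-model."""
    # Paths that are already relative to lib/serve/ecs-model are returned unchanged.
    return v2_path.removeprefix(ECS_MODEL_ROOT)


print(to_create_model_path("lib/serve/ecs-model/textgen/tgi"))
```

`str.removeprefix` requires Python 3.9 or later.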

#### Request Example:

```
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
-3.0.0
+3.0.1
10 changes: 8 additions & 2 deletions ecs_model_deployer/src/lib/ecsCluster.ts
@@ -109,6 +109,11 @@ export class ECSCluster extends Construct {
],
});

new CfnOutput(this, 'autoScalingGroup', {
key: 'autoScalingGroup',
value: autoScalingGroup.autoScalingGroupName,
});

const environment = ecsConfig.environment;
const volumes: Volume[] = [];
const mountPoints: MountPoint[] = [];
@@ -272,10 +277,11 @@ export class ECSCluster extends Construct {
const loadBalancer = new ApplicationLoadBalancer(this, createCdkId([ecsConfig.identifier, 'ALB']), {
deletionProtection: config.removalPolicy !== RemovalPolicy.DESTROY,
internetFacing: false,
-loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2),
+loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2).toLowerCase(),
dropInvalidHeaderFields: true,
securityGroup,
vpc,
idleTimeout: Duration.seconds(600)
});

// Add listener
@@ -294,7 +300,7 @@ export class ECSCluster extends Construct {
// Add targets
const loadBalancerHealthCheckConfig = ecsConfig.loadBalancerConfig.healthCheckConfig;
const targetGroup = listener.addTargets(createCdkId([ecsConfig.identifier, 'TgtGrp']), {
-targetGroupName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2),
+targetGroupName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2).toLowerCase(),
healthCheck: {
path: loadBalancerHealthCheckConfig.path,
interval: Duration.seconds(loadBalancerHealthCheckConfig.interval),
31 changes: 30 additions & 1 deletion lambda/authorizer/lambda_functions.py
@@ -16,15 +16,20 @@
import logging
import os
import ssl
from functools import cache
from typing import Any, Dict

import boto3
import create_env_variables # noqa: F401
import jwt
import requests
-from utilities.common_functions import authorization_wrapper, get_id_token
+from botocore.exceptions import ClientError
+from utilities.common_functions import authorization_wrapper, get_id_token, retry_config

logger = logging.getLogger(__name__)

secrets_manager = boto3.client("secretsmanager", region_name=os.environ["AWS_REGION"], config=retry_config)


@authorization_wrapper
def lambda_handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: ignore [no-untyped-def]
@@ -48,6 +53,11 @@ def lambda_handler(event: Dict[str, Any], context) -> Dict[str, Any]: # type: i

deny_policy = generate_policy(effect="Deny", resource=event["methodArn"])

if id_token in get_management_tokens():
allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username="lisa-management-token")
logger.debug(f"Generated policy: {allow_policy}")
return allow_policy

if jwt_data := id_token_is_valid(id_token=id_token, client_id=client_id, authority=authority):
is_admin_user = is_admin(jwt_data, admin_group, jwt_groups_property)
allow_policy = generate_policy(effect="Allow", resource=event["methodArn"], username=jwt_data["sub"])
@@ -134,3 +144,22 @@ def is_admin(jwt_data: dict[str, Any], admin_group: str, jwt_groups_property: st
else:
return False
return admin_group in current_node


@cache
def get_management_tokens() -> list[str]:
"""Return secret management tokens if they exist."""
secret_tokens: list[str] = []
secret_id = os.environ.get("MANAGEMENT_KEY_NAME")

try:
secret_tokens.append(
secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSCURRENT")["SecretString"]
)
secret_tokens.append(
secrets_manager.get_secret_value(SecretId=secret_id, VersionStage="AWSPREVIOUS")["SecretString"]
)
except ClientError as e:
logger.warn(f"Unable to fetch {secret_id}. {e.response['Error']['Code']}: {e.response['Error']['Message']}")

return secret_tokens
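
Two behaviours of `get_management_tokens` may be worth keeping in mind when reviewing: because both lookups share one `try` block, a failure on `AWSPREVIOUS` still leaves the `AWSCURRENT` token in the returned list, and because of `@cache` the result is memoized for as long as the process lives. The sketch below reproduces both with a stand-in client. `FakeSecretsClient`, the secret name `mgmt-key`, and the use of `LookupError` in place of botocore's `ClientError` are assumptions made so the example runs without AWS access.

```python
from functools import cache


class FakeSecretsClient:
    """Hypothetical stand-in for the boto3 Secrets Manager client (not a real API)."""

    def __init__(self, stages: dict[str, str]) -> None:
        self.stages = stages
        self.calls = 0

    def get_secret_value(self, SecretId: str, VersionStage: str) -> dict[str, str]:
        self.calls += 1
        if VersionStage not in self.stages:
            # Stands in for botocore's ClientError in this sketch.
            raise LookupError(f"no {VersionStage} version for {SecretId}")
        return {"SecretString": self.stages[VersionStage]}


# Only a current version exists, e.g. a secret that has never been rotated.
client = FakeSecretsClient({"AWSCURRENT": "current-token"})


@cache
def get_management_tokens_sketch() -> tuple[str, ...]:
    tokens: list[str] = []
    try:
        for stage in ("AWSCURRENT", "AWSPREVIOUS"):
            tokens.append(client.get_secret_value(SecretId="mgmt-key", VersionStage=stage)["SecretString"])
    except LookupError as e:
        print(f"Unable to fetch all versions: {e}")
    return tuple(tokens)  # immutable, so callers cannot alter the cached value


first = get_management_tokens_sketch()
second = get_management_tokens_sketch()  # served from the cache, no new client calls
print(first, client.calls)
```

The sketch returns a tuple rather than a list so that the cached value cannot be mutated by a caller; that is a design choice of this example, not something the PR does.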
34 changes: 25 additions & 9 deletions lambda/models/domain_objects.py
@@ -18,7 +18,7 @@
from typing import Annotated, List, Optional, Union

from pydantic import BaseModel
-from pydantic.functional_validators import AfterValidator
+from pydantic.functional_validators import AfterValidator, field_validator
from utilities.validators import validate_instance_type


@@ -87,7 +87,7 @@ class LoadBalancerConfig(BaseModel):


class AutoScalingConfig(BaseModel):
-"""Autoscaling configuration."""
+"""Autoscaling configuration upon model creation."""

minCapacity: int
maxCapacity: int
@@ -96,6 +96,14 @@ class AutoScalingConfig(BaseModel):
metricConfig: MetricConfig


class AutoScalingInstanceConfig(BaseModel):
"""Autoscaling instance count configuration upon model update."""

minCapacity: Optional[int] = None
maxCapacity: Optional[int] = None
desiredCapacity: Optional[int] = None


class ContainerHealthCheckConfig(BaseModel):
"""Health check configuration for a container."""

@@ -180,16 +188,24 @@ class GetModelResponse(ApiResponseBase):
class UpdateModelRequest(BaseModel):
"""Request object when updating a model."""

-autoScalingConfig: Optional[AutoScalingConfig] = None
-containerConfig: Optional[ContainerConfig] = None
-inferenceContainer: Optional[InferenceContainer] = None
-instanceType: Optional[Annotated[str, AfterValidator(validate_instance_type)]] = None
-loadBalancerConfig: Optional[LoadBalancerConfig] = None
-modelId: str
-modelName: Optional[str] = None
+autoScalingInstanceConfig: Optional[AutoScalingInstanceConfig] = None
+enabled: Optional[bool] = None
modelType: Optional[ModelType] = None
streaming: Optional[bool] = None

@field_validator("autoScalingInstanceConfig") # type: ignore
@classmethod
def validate_autoscaling_instance_config(cls, config: AutoScalingInstanceConfig) -> AutoScalingInstanceConfig:
"""Validate that the AutoScaling instance config has at least one positive value."""
if not config:
raise ValueError("The autoScalingInstanceConfig must not be null if defined in request payload.")
config_fields = (config.minCapacity, config.maxCapacity, config.desiredCapacity)
if all((field is None for field in config_fields)):
raise ValueError("At least one option of autoScalingInstanceConfig must be defined.")
if any((isinstance(field, int) and field < 0 for field in config_fields)):
raise ValueError("All autoScalingInstanceConfig fields must be >= 0.")
return config
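
The three checks in the validator above can be exercised without pydantic. The following is a minimal stdlib sketch under that assumption: `InstanceConfigSketch`, `check_instance_config`, and `is_accepted` are hypothetical names that mirror the rules, not the PR's classes. Note that although the docstring says "positive", the check as written admits zero.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class InstanceConfigSketch:
    """Hypothetical stdlib mirror of AutoScalingInstanceConfig, for illustration only."""

    minCapacity: Optional[int] = None
    maxCapacity: Optional[int] = None
    desiredCapacity: Optional[int] = None


def check_instance_config(config: Optional[InstanceConfigSketch]) -> InstanceConfigSketch:
    """Apply the same three checks as validate_autoscaling_instance_config."""
    if not config:
        raise ValueError("The autoScalingInstanceConfig must not be null if defined in request payload.")
    fields = (config.minCapacity, config.maxCapacity, config.desiredCapacity)
    if all(f is None for f in fields):
        raise ValueError("At least one option of autoScalingInstanceConfig must be defined.")
    if any(isinstance(f, int) and f < 0 for f in fields):
        raise ValueError("All autoScalingInstanceConfig fields must be >= 0.")
    return config


def is_accepted(config: Optional[InstanceConfigSketch]) -> bool:
    """Return True when the config passes all checks, False when any check raises."""
    try:
        check_instance_config(config)
        return True
    except ValueError:
        return False


print(is_accepted(InstanceConfigSketch(desiredCapacity=0)))  # zero passes the >= 0 check
print(is_accepted(InstanceConfigSketch()))                   # nothing set
print(is_accepted(InstanceConfigSketch(minCapacity=-1)))     # negative value
```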


class UpdateModelResponse(ApiResponseBase):
"""Response object when updating a model."""
24 changes: 24 additions & 0 deletions lambda/models/exception/__init__.py
@@ -15,6 +15,9 @@
"""Exception definitions for model management APIs."""


# LiteLLM errors


class ModelNotFoundError(LookupError):
"""Error to raise when a specified model cannot be found in the database."""

@@ -25,3 +28,24 @@ class ModelAlreadyExistsError(LookupError):
"""Error to raise when a specified model already exists in the database."""

pass


# State machine exceptions


class MaxPollsExceededException(Exception):
"""Exception to indicate that polling for a state timed out."""

pass


class StackFailedToCreateException(Exception):
"""Exception to indicate that the CDK for creating a model stack failed."""

pass


class UnexpectedCloudFormationStateException(Exception):
"""Exception to indicate that the CloudFormation stack has transitioned to a non-healthy state."""

pass
4 changes: 2 additions & 2 deletions lambda/models/lambda_functions.py
@@ -123,7 +123,7 @@ def _create_dummy_model(model_name: str, model_type: ModelType, model_status: Mo
),
sharedMemorySize=2048,
healthCheckConfig=ContainerHealthCheckConfig(
-command=["CMD-SHELL", "exit 0"], Interval=10, StartPeriod=30, Timeout=5, Retries=5
+command=["CMD-SHELL", "exit 0"], interval=10, startPeriod=30, timeout=5, retries=5
),
environment={
"MAX_CONCURRENT_REQUESTS": "128",
@@ -177,7 +177,7 @@ async def update_model(
) -> UpdateModelResponse:
"""Endpoint to update a model."""
# TODO add service to update model
-model = _create_dummy_model("model_name", ModelType.TEXTGEN, ModelStatus.UPDATING)
+model = _create_dummy_model(model_id, ModelType.TEXTGEN, ModelStatus.UPDATING)
return UpdateModelResponse(model=model)


79 changes: 74 additions & 5 deletions lambda/models/state_machine/create_model.py
@@ -24,6 +24,11 @@
from botocore.config import Config
from models.clients.litellm_client import LiteLLMClient
from models.domain_objects import CreateModelRequest, ModelStatus
from models.exception import (
MaxPollsExceededException,
StackFailedToCreateException,
UnexpectedCloudFormationStateException,
)
from utilities.common_functions import get_cert_path, get_rest_api_container_endpoint, retry_config

lambdaConfig = Config(connect_timeout=60, read_timeout=600, retries={"max_attempts": 1})
@@ -113,8 +118,13 @@ def handle_poll_docker_image_available(event: Dict[str, Any], context: Any) -> D
output_dict["image_info"]["remaining_polls"] -= 1
if output_dict["image_info"]["remaining_polls"] <= 0:
ec2Client.terminate_instances(InstanceIds=[event["image_info"]["instance_id"]])
-raise Exception(
-"Maximum number of ECR poll attempts reached. Something went wrong building the docker image."
+raise MaxPollsExceededException(
+json.dumps(
+{
+"error": "Max number of ECR polls reached. Docker Image was not replicated successfully.",
+"event": event,
+}
+)
)
return output_dict

@@ -153,7 +163,17 @@ def camelize_object(o): # type: ignore[no-untyped-def]

payload = response["Payload"].read()
payload = json.loads(payload)
-stack_name = payload.get("stackName")
+stack_name = payload.get("stackName", None)

if not stack_name:
raise StackFailedToCreateException(
json.dumps(
{
"error": "Failed to create Model CloudFormation Stack. Please validate model parameters are valid.",
"event": event,
}
)
)

response = cfnClient.describe_stacks(StackName=stack_name)
stack_arn = response["Stacks"][0]["StackId"]
@@ -192,10 +212,24 @@ def handle_poll_create_stack(event: Dict[str, Any], context: Any) -> Dict[str, A
output_dict["continue_polling_stack"] = True
output_dict["remaining_polls_stack"] -= 1
if output_dict["remaining_polls_stack"] <= 0:
-raise Exception("Maximum number of CloudFormation polls reached")
+raise MaxPollsExceededException(
+json.dumps(
+{
+"error": "Max number of CloudFormation polls reached.",
+"event": event,
+}
+)
+)
return output_dict
else:
-raise Exception(f"Stack in unexpected state: {stackStatus}")
+raise UnexpectedCloudFormationStateException(
+json.dumps(
+{
+"error": f"Stack entered unexpected state: {stackStatus}",
+"event": event,
+}
+)
+)


def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
@@ -234,3 +268,38 @@ def handle_add_model_to_litellm(event: Dict[str, Any], context: Any) -> Dict[str
)

return output_dict


def handle_failure(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
"""
Handle failures from state machine.

Possible causes of failures would be:
1. Docker Image failed to replicate into ECR in expected amount of time
2. CloudFormation Stack creation failed from parameter validation.
3. CloudFormation Stack creation failed from taking too long to stand up.

Expectation of this function is to terminate the EC2 instance if it is still running, and to set the model status
to Failed. Cleaning up the CloudFormation stack, if it still exists, will happen in the DeleteModel API.
"""
error_dict = json.loads( # error from SFN is json payload on top of json payload we add to the exception
json.loads(event["Cause"])["errorMessage"]
)
error_reason = error_dict["error"]
original_event = error_dict["event"]

# terminate EC2 instance if we have one recorded
if "image_info" in original_event and "instance_id" in original_event["image_info"]:
ec2Client.terminate_instances(InstanceIds=[original_event["image_info"]["instance_id"]])

# set model as Failed in DDB, so it shows as such in the UI. adds error reason as well.
model_table.update_item(
Key={"model_id": original_event["modelId"]},
UpdateExpression="SET model_status = :ms, last_modified_date = :lm, failure_reason = :fr",
ExpressionAttributeValues={
":ms": ModelStatus.FAILED,
":lm": int(datetime.utcnow().timestamp()),
":fr": error_reason,
},
)
return event
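
The nested `json.loads` in `handle_failure` follows from how the new exceptions are raised: each handler puts a JSON document in the exception message, and the failure input carries that message inside a second JSON string under `Cause`. The sketch below walks through the round trip. The shape of `failure_event` is an assumption based on the comment in the code above, and the model ID and instance ID are made-up values.

```python
import json

# What a handler raises: the exception message is itself a JSON document.
# The model ID and instance ID here are hypothetical.
original_event = {"modelId": "my-model", "image_info": {"instance_id": "i-0123456789abcdef0"}}
error_message = json.dumps({"error": "Max number of CloudFormation polls reached.", "event": original_event})

# Assumed shape of the failure input handed to handle_failure:
# "Cause" is a JSON string whose "errorMessage" field carries the exception message.
failure_event = {
    "Error": "MaxPollsExceededException",
    "Cause": json.dumps({"errorMessage": error_message, "errorType": "MaxPollsExceededException"}),
}

# The same two-step decode used in handle_failure.
error_dict = json.loads(json.loads(failure_event["Cause"])["errorMessage"])
print(error_dict["error"])
print(error_dict["event"]["image_info"]["instance_id"])
```

Carrying the original event inside the exception is what lets the failure handler find the EC2 instance and the model ID without any extra state.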
5 changes: 3 additions & 2 deletions lambda/models/state_machine/delete_model.py
@@ -62,7 +62,7 @@ def handle_set_model_to_deleting(event: Dict[str, Any], context: Any) -> Dict[st
if not item:
raise RuntimeError(f"Requested model '{model_id}' was not found in DynamoDB table.")
output_dict[CFN_STACK_ARN] = item.get(CFN_STACK_ARN, None)
-output_dict[LITELLM_ID] = item[LITELLM_ID]
+output_dict[LITELLM_ID] = item.get(LITELLM_ID, None)

ddb_table.update_item(
Key=model_key,
@@ -77,7 +77,8 @@ def handle_set_model_to_deleting(event: Dict[str, Any], context: Any) -> Dict[st

def handle_delete_from_litellm(event: Dict[str, Any], context: Any) -> Dict[str, Any]:
"""Delete model reference from LiteLLM."""
-litellm_client.delete_model(identifier=event[LITELLM_ID])
+if event[LITELLM_ID]: # if non-null ID
+litellm_client.delete_model(identifier=event[LITELLM_ID])
return event
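
The two changes in this file work together: `.get()` tolerates a DynamoDB item with no LiteLLM ID, and the guard skips the LiteLLM call when the ID is null. Here is a minimal sketch of that flow using plain dictionaries. The key name `litellm_id`, the `deleted` list standing in for the LiteLLM client, and both function names are assumptions made for this example.

```python
from typing import Any, Dict, List, Optional

LITELLM_ID = "litellm_id"  # assumed key name; the real constant is defined elsewhere in the module

deleted: List[str] = []  # records what a hypothetical LiteLLM client would have been asked to delete


def set_model_to_deleting(item: Dict[str, Any]) -> Dict[str, Optional[str]]:
    # .get() yields None for a model that was never registered in LiteLLM,
    # where item[LITELLM_ID] would raise KeyError.
    return {LITELLM_ID: item.get(LITELLM_ID, None)}


def delete_from_litellm(event: Dict[str, Optional[str]]) -> Dict[str, Optional[str]]:
    litellm_id = event[LITELLM_ID]
    if litellm_id:  # skip the LiteLLM call when there is no ID
        deleted.append(litellm_id)
    return event


# A model whose creation failed before the LiteLLM step has no ID recorded.
delete_from_litellm(set_model_to_deleting({"model_id": "failed-model"}))
# A fully created model still has its LiteLLM entry removed.
delete_from_litellm(set_model_to_deleting({"model_id": "ok-model", LITELLM_ID: "abc-123"}))
print(deleted)
```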

