Skip to content

Commit

Permalink
Allow Security Group overrides
Browse files Browse the repository at this point in the history
  • Loading branch information
bedanley authored Nov 26, 2024
1 parent 7d8447b commit b923389
Show file tree
Hide file tree
Showing 37 changed files with 693 additions and 259 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ cdk.context.json
.venv
.DS_Store
*.iml
*.code-workspace

# Coverage Statistic Folders
coverage
Expand Down
37 changes: 10 additions & 27 deletions ecs_model_deployer/src/lib/schema.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
*/

// Models for schema validation.
import * as cdk from 'aws-cdk-lib';
import * as ec2 from 'aws-cdk-lib/aws-ec2';
import { AmiHardwareType } from 'aws-cdk-lib/aws-ecs';
import { z } from 'zod';
import { SecurityGroupConfigSchema } from '../../../lib/schema';

const VERSION: string = '2.0.1';

Expand Down Expand Up @@ -64,19 +64,6 @@ export type RegisteredModel = {
streaming?: boolean;
};

/**
* Custom security groups for application.
*
* @property {ec2.SecurityGroup} ecsModelAlbSg - ECS model application load balancer security group.
* @property {ec2.SecurityGroup} restApiAlbSg - REST API application load balancer security group.
* @property {ec2.SecurityGroup} lambdaSecurityGroup - Lambda security group.
*/
export type SecurityGroups = {
ecsModelAlbSg: ec2.SecurityGroup;
restApiAlbSg: ec2.SecurityGroup;
lambdaSecurityGroup: ec2.SecurityGroup;
};

/**
* Metadata for a specific EC2 instance type.
*
Expand Down Expand Up @@ -336,7 +323,7 @@ const ImageRegistryAsset = z.object({
*
* @property {string} baseImage - Base image for the container.
* @property {Record<string, string>} [environment={}] - Environment variables for the container.
* @property {ContainerHealthCheckConfig} [healthCheckConfig={}] - Health check configuration for the container.
* @property {ContainerHealthCheckConfigSchema} [healthCheckConfig={}] - Health check configuration for the container.
* @property {number} [sharedMemorySize=0] - The value for the size of the /dev/shm volume.
*/
const ContainerConfigSchema = z.object({
Expand Down Expand Up @@ -380,7 +367,7 @@ const HealthCheckConfigSchema = z.object({
* Configuration schema for the load balancer.
*
* @property {string} [sslCertIamArn=null] - SSL certificate IAM ARN for load balancer.
* @property {HealthCheckConfig} healthCheckConfig - Health check configuration for the load balancer.
* @property {HealthCheckConfigSchema} healthCheckConfig - Health check configuration for the load balancer.
* @property {string} domainName - Domain name to use instead of the load balancer's default DNS name.
*/
const LoadBalancerConfigSchema = z.object({
Expand Down Expand Up @@ -414,7 +401,7 @@ const MetricConfigSchema = z.object({
* @property {number} [cooldown=420] - Cool down period in seconds between scaling activities.
* @property {number} [defaultInstanceWarmup=180] - Default warm-up time in seconds until a newly launched instance can
send metrics to CloudWatch.
* @property {MetricConfig} metricConfig - Metric configuration for auto scaling.
* @property {MetricConfigSchema} metricConfig - Metric configuration for auto scaling.
*/
const AutoScalingConfigSchema = z.object({
blockDeviceVolumeSize: z.number().min(30).default(30),
Expand All @@ -432,7 +419,7 @@ const AutoScalingConfigSchema = z.object({
* @property {AutoScalingConfigSchema} autoScalingConfig - Configuration for auto scaling settings.
* @property {Record<string,string>} buildArgs - Optional build args to be applied when creating the
* task container if containerConfig.image.type is ASSET
* @property {ContainerConfig} containerConfig - Configuration for the container.
* @property {ContainerConfigSchema} containerConfig - Configuration for the container.
* @property {number} [containerMemoryBuffer=2048] - This is the amount of memory to buffer (or subtract off)
* from the total instance memory, if we don't include this,
* the container can have a hard time finding available RAM
Expand All @@ -441,7 +428,7 @@ const AutoScalingConfigSchema = z.object({
* @property {identifier} modelType - Unique identifier for the cluster which will be used when naming resources
* @property {string} instanceType - EC2 instance type for running the model.
* @property {boolean} [internetFacing=false] - Whether or not the cluster will be configured as internet facing
* @property {LoadBalancerConfig} loadBalancerConfig - Configuration for load balancer settings.
* @property {LoadBalancerConfigSchema} loadBalancerConfig - Configuration for load balancer settings.
*/
const EcsBaseConfigSchema = z.object({
amiHardwareType: z.nativeEnum(AmiHardwareType),
Expand Down Expand Up @@ -477,9 +464,9 @@ export type ECSConfig = EcsBaseConfig;
* @property {string} modelType - Type of model.
* @property {string} instanceType - EC2 instance type for running the model.
* @property {string} inferenceContainer - Prebuilt inference container for serving model.
* @property {ContainerConfig} containerConfig - Configuration for the container.
* @property {ContainerConfigSchema} containerConfig - Configuration for the container.
* @property {AutoScalingConfigSchema} autoScalingConfig - Configuration for auto scaling settings.
* @property {LoadBalancerConfig} loadBalancerConfig - Configuration for load balancer settings.
* @property {LoadBalancerConfigSchema} loadBalancerConfig - Configuration for load balancer settings.
* @property {string} [localModelCode='/opt/model-code'] - Path in container for local model code.
* @property {string} [modelHosting='ecs'] - Model hosting.
*/
Expand Down Expand Up @@ -562,19 +549,14 @@ const PypiConfigSchema = z.object({
* @property {string} deploymentStage - Deployment stage for the application.
* @property {string} removalPolicy - Removal policy for resources (destroy or retain).
* @property {boolean} [runCdkNag=false] - Whether to run CDK Nag checks.
* @property {string} [lambdaSourcePath='./lambda'] - Path to Lambda source code dir.
* @property {string} s3BucketModels - S3 bucket for models.
* @property {string} mountS3DebUrl - URL for S3-mounted Debian package.
* @property {string[]} [accountNumbersEcr=null] - List of AWS account numbers for ECR repositories.
* @property {boolean} [deployRag=false] - Whether to deploy RAG stacks.
* @property {boolean} [deployChat=true] - Whether to deploy chat stacks.
* @property {boolean} [deployUi=true] - Whether to deploy UI stacks.
* @property {string} logLevel - Log level for application.
* @property {AuthConfigSchema} authConfig - Authorization configuration.
* @property {RagRepositoryConfigSchema} ragRepositoryConfig - Rag Repository configuration.
* @property {RagFileProcessingConfigSchema} ragFileProcessingConfig - Rag file processing configuration.
* @property {EcsModelConfigSchema[]} ecsModels - Array of ECS model configurations.
* @property {ApiGatewayConfigSchema} apiGatewayConfig - API Gateway Endpoint configuration.
* @property {string} [nvmeHostMountPath='/nvme'] - Host path for NVMe drives.
* @property {string} [nvmeContainerMountPath='/nvme'] - Container path for NVMe drives.
* @property {Array<{ Key: string, Value: string }>} [tags=null] - Array of key-value pairs for tagging.
Expand All @@ -591,6 +573,7 @@ const RawConfigSchema = z
vpcId: z.string().optional(),
deploymentStage: z.string(),
removalPolicy: z.union([z.literal('destroy'), z.literal('retain')]).transform((value) => REMOVAL_POLICIES[value]),
securityGroupConfig: SecurityGroupConfigSchema.optional(),
s3BucketModels: z.string(),
mountS3DebUrl: z.string().optional(),
pypiConfig: PypiConfigSchema.optional().default({
Expand Down
7 changes: 7 additions & 0 deletions example_config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,13 @@ s3BucketModels: hf-models-gaiic
# subnets:
# - subnetId:
# ipv4CidrBlock:
# securityGroupConfig: # If securityGroupConfig is provided, all security groups must be overridden. Vector stores SGs are optional based on deployment preferences.
# modelSecurityGroupId: sg-0123456789abcdef
# restAlbSecurityGroupId: sg-0123456789abcdef
# lambdaSecurityGroupId: sg-0123456789abcdef
# liteLlmDbSecurityGroupId: sg-0123456789abcdef
# openSearchSecurityGroupId: sg-0123456789abcdef #Optional
# pgVectorSecurityGroupId: sg-0123456789abcdef #Optional
# The following configuration will allow for using a custom domain for the chat user interface.
# If this option is specified, the API Gateway invocation URL will NOT work on its own as the application URL.
# Users must use the custom domain for the user interface to work if this option is populated.
Expand Down
8 changes: 4 additions & 4 deletions lib/api-base/authorizer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ import { Queue } from 'aws-cdk-lib/aws-sqs';
*/
type AuthorizerProps = {
role?: IRole;
vpc?: Vpc;
securityGroups?: ISecurityGroup[];
vpc: Vpc;
securityGroups: ISecurityGroup[];
} & BaseProps;

/**
Expand Down Expand Up @@ -98,9 +98,9 @@ export class CustomAuthorizer extends Construct {
},
reservedConcurrentExecutions: 20,
role: role,
vpc: vpc?.vpc,
vpc: vpc.vpc,
securityGroups: securityGroups,
vpcSubnets: vpc?.subnetSelection
vpcSubnets: vpc.subnetSelection
});

const managementKeySecret = Secret.fromSecretNameV2(this, createCdkId([id, 'managementKey']), managementKeySecretNameStringParameter.stringValue);
Expand Down
16 changes: 7 additions & 9 deletions lib/api-base/ecsCluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import { Duration, RemovalPolicy } from 'aws-cdk-lib';
import { BlockDeviceVolume, GroupMetrics, Monitoring } from 'aws-cdk-lib/aws-autoscaling';
import { Metric, Stats } from 'aws-cdk-lib/aws-cloudwatch';
import { InstanceType, SecurityGroup } from 'aws-cdk-lib/aws-ec2';
import { InstanceType, ISecurityGroup } from 'aws-cdk-lib/aws-ec2';
import { Repository } from 'aws-cdk-lib/aws-ecr';
import {
AmiHardwareType,
Expand All @@ -43,20 +43,19 @@ import { StringParameter } from 'aws-cdk-lib/aws-ssm';
import { Construct } from 'constructs';

import { createCdkId } from '../core/utils';
import { BaseProps, Ec2Metadata, EcsSourceType } from '../schema';
import { ECSConfig } from '../schema';
import { BaseProps, Ec2Metadata, ECSConfig, EcsSourceType } from '../schema';
import { Vpc } from '../networking/vpc';

/**
* Properties for the ECSCluster Construct.
*
* @property {IVpc} vpc - The virtual private cloud (VPC).
* @property {SecurityGroups} securityGroups - The security group that the ECS cluster should use.
* @property {ECSConfig} ecsConfig - The configuration for the cluster.
* @property {ISecurityGroup} securityGroup - The security group that the ECS cluster should use.
* @property {Vpc} vpc - The virtual private cloud (VPC).
*/
type ECSClusterProps = {
ecsConfig: ECSConfig;
securityGroup: SecurityGroup;
securityGroup: ISecurityGroup;
vpc: Vpc;
} & BaseProps;

Expand Down Expand Up @@ -89,7 +88,7 @@ export class ECSCluster extends Construct {
containerInsights: !config.region.includes('iso'),
});

// Create auto scaling group
// Create auto-scaling group
const autoScalingGroup = cluster.addCapacity(createCdkId(['ASG']), {
vpcSubnets: vpc.subnetSelection,
instanceType: new InstanceType(ecsConfig.instanceType),
Expand Down Expand Up @@ -326,8 +325,7 @@ export class ECSCluster extends Construct {
ecsConfig.loadBalancerConfig.domainName !== null
? ecsConfig.loadBalancerConfig.domainName
: loadBalancer.loadBalancerDnsName;
const endpoint = `${protocol}://${domain}`;
this.endpointUrl = endpoint;
this.endpointUrl = `${protocol}://${domain}`;

// Update
this.container = container;
Expand Down
10 changes: 5 additions & 5 deletions lib/api-base/fastApiContainer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@

import { CfnOutput } from 'aws-cdk-lib';
import { ITable } from 'aws-cdk-lib/aws-dynamodb';
import { SecurityGroup } from 'aws-cdk-lib/aws-ec2';
import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2';
import { AmiHardwareType, ContainerDefinition } from 'aws-cdk-lib/aws-ecs';
import { IRole } from 'aws-cdk-lib/aws-iam';
import { Construct } from 'constructs';
Expand All @@ -33,13 +33,13 @@ const CONTAINER_MEMORY_BUFFER = 1024 * 2;
/**
* Properties for FastApiContainer Construct.
*
* @property {IVpc} vpc - The virtual private cloud (VPC).
* @property {SecurityGroup} securityGroups - The security groups of the application.
* @property {Vpc} vpc - The virtual private cloud (VPC).
* @property {ISecurityGroup} securityGroup - The security groups of the application.
*/
type FastApiContainerProps = {
apiName: string;
resourcePath: string;
securityGroup: SecurityGroup;
securityGroup: ISecurityGroup;
tokenTable: ITable | undefined;
vpc: Vpc;
} & BaseProps;
Expand All @@ -60,7 +60,7 @@ export class FastApiContainer extends Construct {
/**
* @param {Construct} scope - The parent or owner of the construct.
* @param {string} id - The unique identifier for the construct within its scope.
* @param {RestApiProps} props - The properties of the construct.
* @param {FastApiContainerProps} props - The properties of the construct.
*/
constructor (scope: Construct, id: string, props: FastApiContainerProps) {
super(scope, id);
Expand Down
16 changes: 8 additions & 8 deletions lib/api-base/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
*/

/*
* Your use of this service is governed by the terms of the AWS Customer Agreement
Expand All @@ -28,15 +28,15 @@ import * as cdk from 'aws-cdk-lib';
import { Duration } from 'aws-cdk-lib';
import {
AuthorizationType,
Cors,
IAuthorizer,
IResource,
LambdaIntegration,
IRestApi,
Cors,
LambdaIntegration,
} from 'aws-cdk-lib/aws-apigateway';
import { ISecurityGroup } from 'aws-cdk-lib/aws-ec2';
import { IRole } from 'aws-cdk-lib/aws-iam';
import { Code, Function, Runtime, ILayerVersion, IFunction, CfnPermission } from 'aws-cdk-lib/aws-lambda';
import { CfnPermission, Code, Function, IFunction, ILayerVersion, Runtime } from 'aws-cdk-lib/aws-lambda';
import { Construct } from 'constructs';
import { Vpc } from '../networking/vpc';
import { Queue } from 'aws-cdk-lib/aws-sqs';
Expand Down Expand Up @@ -82,9 +82,9 @@ export function registerAPIEndpoint (
layers: ILayerVersion[],
funcDef: PythonLambdaFunction,
pythonRuntime: Runtime,
vpc: Vpc,
securityGroups: ISecurityGroup[],
role?: IRole,
vpc?: Vpc,
securityGroups?: ISecurityGroup[],
): IFunction {
const functionId = `${
funcDef.id ||
Expand Down Expand Up @@ -124,9 +124,9 @@ export function registerAPIEndpoint (
layers,
reservedConcurrentExecutions: 20,
role,
vpc: vpc?.vpc,
vpc: vpc.vpc,
securityGroups,
vpcSubnets: vpc?.subnetSelection,
vpcSubnets: vpc.subnetSelection,
});
}

Expand Down
6 changes: 3 additions & 3 deletions lib/chat/api/configuration.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ type ConfigurationApiProps = {
authorizer: IAuthorizer;
restApiId: string;
rootResourceId: string;
securityGroups?: ISecurityGroup[];
vpc?: Vpc;
securityGroups: ISecurityGroup[];
vpc: Vpc;
} & BaseProps;

/**
Expand Down Expand Up @@ -157,9 +157,9 @@ export class ConfigurationApi extends Construct {
[commonLambdaLayer],
f,
Runtime.PYTHON_3_10,
lambdaRole,
vpc,
securityGroups,
lambdaRole,
);
if (f.method === 'POST' || f.method === 'PUT') {
configTable.grantWriteData(lambdaFunction);
Expand Down
6 changes: 3 additions & 3 deletions lib/chat/api/session.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ type SessionApiProps = {
authorizer: IAuthorizer;
restApiId: string;
rootResourceId: string;
securityGroups?: ISecurityGroup[];
vpc?: Vpc;
securityGroups: ISecurityGroup[];
vpc: Vpc;
} & BaseProps;

/**
Expand Down Expand Up @@ -157,9 +157,9 @@ export class SessionApi extends Construct {
[commonLambdaLayer],
f,
Runtime.PYTHON_3_10,
lambdaRole,
vpc,
securityGroups,
lambdaRole,
);
if (f.method === 'POST' || f.method === 'PUT') {
sessionTable.grantWriteData(lambdaFunction);
Expand Down
4 changes: 2 additions & 2 deletions lib/chat/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ type CustomLisaChatStackProps = {
authorizer: IAuthorizer;
restApiId: string;
rootResourceId: string;
securityGroups?: ISecurityGroup[];
vpc?: Vpc;
securityGroups: ISecurityGroup[];
vpc: Vpc;
} & BaseProps;
type LisaChatStackProps = CustomLisaChatStackProps & StackProps;

Expand Down
3 changes: 2 additions & 1 deletion lib/core/api_base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import { Vpc } from '../networking/vpc';
type LisaApiBaseStackProps = {
vpc: Vpc;
} & BaseProps &
StackProps;
StackProps;

export class LisaApiBaseStack extends Stack {
public readonly restApi: RestApi;
Expand Down Expand Up @@ -61,6 +61,7 @@ export class LisaApiBaseStack extends Stack {
// Create the authorizer Lambda for APIGW
const authorizer = new CustomAuthorizer(this, 'LisaApiAuthorizer', {
config: config,
securityGroups: [vpc.securityGroups.lambdaSg],
vpc,
});

Expand Down
Loading

0 comments on commit b923389

Please sign in to comment.