Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Moving ALB behind APIGW to allow service code to call APIs without auth #66

Merged
merged 15 commits into from
Sep 5, 2024
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ repos:
hooks:
- id: codespell
entry: codespell
args: ['--skip=*.git*,*cdk.out*,*venv*,*mypy_cache*,*package-lock*,*node_modules*,*dist/*,*poetry.lock*,*coverage*', "-L=xdescribe"]
args: ['--skip=*.git*,*cdk.out*,*venv*,*mypy_cache*,*package-lock*,*node_modules*,*dist/*,*poetry.lock*,*coverage*,*models/*', "-L=xdescribe"]
pass_filenames: false

- repo: https://github.com/pycqa/isort
Expand Down
61 changes: 47 additions & 14 deletions lib/api-base/ecsCluster.ts
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,13 @@
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
*/

// ECS Cluster Construct.
import { Duration, RemovalPolicy } from 'aws-cdk-lib';
import { BlockDeviceVolume, GroupMetrics, Monitoring } from 'aws-cdk-lib/aws-autoscaling';
import { Metric, Stats } from 'aws-cdk-lib/aws-cloudwatch';
import { InstanceType, SecurityGroup, IVpc } from 'aws-cdk-lib/aws-ec2';
import { InstanceType, IVpc, SecurityGroup } from 'aws-cdk-lib/aws-ec2';
import { Repository } from 'aws-cdk-lib/aws-ecr';
import {
AmiHardwareType,
Expand All @@ -37,14 +37,20 @@ import {
Protocol,
Volume,
} from 'aws-cdk-lib/aws-ecs';
import { ApplicationLoadBalancer, BaseApplicationListenerProps } from 'aws-cdk-lib/aws-elasticloadbalancingv2';
import {
ApplicationLoadBalancer,
ApplicationProtocol,
BaseApplicationListenerProps,
NetworkLoadBalancer,
NetworkTargetGroup
} from 'aws-cdk-lib/aws-elasticloadbalancingv2';
import { IRole, ManagedPolicy, Role } from 'aws-cdk-lib/aws-iam';
import { StringParameter } from 'aws-cdk-lib/aws-ssm';
import { Construct } from 'constructs';

import { createCdkId } from '../core/utils';
import { BaseProps, Ec2Metadata, EcsSourceType } from '../schema';
import { ECSConfig } from '../schema';
import { BaseProps, Ec2Metadata, ECSConfig, EcsSourceType } from '../schema';
import { AlbTarget } from 'aws-cdk-lib/aws-elasticloadbalancingv2-targets';

/**
* Properties for the ECSCluster Construct.
Expand All @@ -57,6 +63,7 @@ type ECSClusterProps = {
ecsConfig: ECSConfig;
securityGroup: SecurityGroup;
vpc: IVpc;
addNlb?: boolean;
} & BaseProps;

/**
Expand All @@ -72,6 +79,10 @@ export class ECSCluster extends Construct {
/** Endpoint URL of application load balancer for the cluster. */
public readonly endpointUrl: string;

public readonly alb: ApplicationLoadBalancer;

public readonly nlb: NetworkLoadBalancer;

/**
* @param {Construct} scope - The parent or owner of the construct.
* @param {string} id - The unique identifier for the construct within its scope.
Expand Down Expand Up @@ -259,25 +270,47 @@ export class ECSCluster extends Construct {
service.node.addDependency(autoScalingGroup);

// Create application load balancer
const loadBalancer = new ApplicationLoadBalancer(this, createCdkId([ecsConfig.identifier, 'ALB']), {
this.alb = new ApplicationLoadBalancer(this, createCdkId([ecsConfig.identifier, 'ALB']), {
deletionProtection: config.removalPolicy !== RemovalPolicy.DESTROY,
internetFacing: ecsConfig.internetFacing,
loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier], 32, 2),
loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier, 'ALB'], 32, 2),
dropInvalidHeaderFields: true,
securityGroup,
vpc,
});

if (props.addNlb) {
this.nlb = new NetworkLoadBalancer(this, createCdkId([ecsConfig.identifier, 'NLB']), {
deletionProtection: config.removalPolicy !== RemovalPolicy.DESTROY,
crossZoneEnabled: true,
internetFacing: ecsConfig.internetFacing,
loadBalancerName: createCdkId([config.deploymentName, ecsConfig.identifier, 'NLB'], 32, 2),
securityGroups: [securityGroup],
vpc,
});

const nlbListener = this.nlb.addListener('Listener', { port: 80 });

const albTargetGroup = new NetworkTargetGroup(this, 'ALB-Target-Group', {
port: 80,
vpc: vpc,
targets: [new AlbTarget(this.alb, 80)],
healthCheck: {
path: '/health'
}
});

nlbListener.addTargetGroups('ALB-Target-Group', albTargetGroup);
}

// Add listener
const listenerProps: BaseApplicationListenerProps = {
port: ecsConfig.loadBalancerConfig.sslCertIamArn ? 443 : 80,
port: 80,
open: ecsConfig.internetFacing,
certificates: ecsConfig.loadBalancerConfig.sslCertIamArn
? [{ certificateArn: ecsConfig.loadBalancerConfig.sslCertIamArn }]
: undefined,
protocol: ApplicationProtocol.HTTP
};

const listener = loadBalancer.addListener(
const listener = this.alb.addListener(
createCdkId([ecsConfig.identifier, 'ApplicationListener']),
listenerProps,
);
Expand Down Expand Up @@ -305,7 +338,7 @@ export class ECSCluster extends Construct {
namespace: 'AWS/ApplicationELB',
dimensionsMap: {
TargetGroup: targetGroup.targetGroupFullName,
LoadBalancer: loadBalancer.loadBalancerFullName,
LoadBalancer: this.alb.loadBalancerFullName,
},
statistic: Stats.SAMPLE_COUNT,
period: Duration.seconds(ecsConfig.autoScalingConfig.metricConfig.duration),
Expand All @@ -321,7 +354,7 @@ export class ECSCluster extends Construct {
const domain =
ecsConfig.loadBalancerConfig.domainName !== null
? ecsConfig.loadBalancerConfig.domainName
: loadBalancer.loadBalancerDnsName;
: this.alb.loadBalancerDnsName;
const endpoint = `${protocol}://${domain}`;
this.endpointUrl = endpoint;

Expand Down
67 changes: 56 additions & 11 deletions lib/api-base/fastApiContainer.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,15 @@ import { dump as yamlDump } from 'js-yaml';

import { ECSCluster } from './ecsCluster';
import { BaseProps, Ec2Metadata, EcsSourceType, FastApiContainerConfig } from '../schema';
import {
ConnectionType,
Cors,
IAuthorizer,
Integration,
IntegrationType,
RestApi,
VpcLink
} from 'aws-cdk-lib/aws-apigateway';

// This is the amount of memory to buffer (or subtract off) from the total instance memory, if we don't include this,
// the container can have a hard time finding available RAM resources to start and the tasks will fail deployment
Expand All @@ -36,6 +45,9 @@ const CONTAINER_MEMORY_BUFFER = 1024 * 2;
* @property {SecurityGroup} securityGroups - The security groups of the application.
*/
type FastApiContainerProps = {
authorizer: IAuthorizer;
restApiId: string;
rootResourceId: string;
apiName: string;
resourcePath: string;
securityGroup: SecurityGroup;
Expand Down Expand Up @@ -82,18 +94,9 @@ export class FastApiContainer extends Construct {
AWS_REGION_NAME: config.region, // for supporting SageMaker endpoints in LiteLLM
THREADS: Ec2Metadata.get(taskConfig.instanceType).vCpus.toString(),
LITELLM_KEY: config.litellmConfig.general_settings.master_key,
USE_AUTH: 'false',
};

if (config.restApiConfig.internetFacing) {
environment.USE_AUTH = 'true';
environment.AUTHORITY = config.authConfig!.authority;
environment.CLIENT_ID = config.authConfig!.clientId;
environment.ADMIN_GROUP = config.authConfig!.adminGroup;
environment.JWT_GROUPS_PROP = config.authConfig!.jwtGroupsProperty;
} else {
environment.USE_AUTH = 'false';
}

if (tokenTable) {
environment.TOKEN_TABLE_NAME = tokenTable.tableName;
}
Expand All @@ -109,16 +112,58 @@ export class FastApiContainer extends Construct {
environment,
identifier: props.apiName,
instanceType: taskConfig.instanceType,
internetFacing: config.restApiConfig.internetFacing,
internetFacing: false,
loadBalancerConfig: taskConfig.loadBalancerConfig,
},
securityGroup,
vpc,
addNlb: true
});
estohlmann marked this conversation as resolved.
Show resolved Hide resolved

const nlbVpcLink = new VpcLink(this, 'nlb-vpc-link', {
targets: [apiCluster.nlb]
});

// get the rest api
const restApi = RestApi.fromRestApiAttributes(this, 'RestApi', {
restApiId: props.restApiId,
rootResourceId: props.rootResourceId,
});

const integration = new Integration({
type: IntegrationType.HTTP_PROXY,
integrationHttpMethod: 'ANY',
options: {
connectionType: ConnectionType.VPC_LINK,
vpcLink: nlbVpcLink,
requestParameters: {
'integration.request.path.proxy': 'method.request.path.proxy'
},
},
uri: `${apiCluster.endpointUrl}/{proxy}`,
});

// create the proxy
const resource = restApi.root.addResource('llm').addProxy({
defaultIntegration: integration,
anyMethod: true,
defaultMethodOptions: {
authorizer: props.authorizer,
requestParameters: {
'method.request.path.proxy': true
}
}
});

resource.addCorsPreflight({
allowOrigins: Cors.ALL_ORIGINS,
allowHeaders: ['*'],
});

if (tokenTable) {
tokenTable.grantReadData(apiCluster.taskRole);
}

this.endpoint = apiCluster.endpointUrl;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I missed this detail: looking at the cluster definition, this is already the ALB address.


// Update
Expand Down
8 changes: 1 addition & 7 deletions lib/networking/vpc/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -118,13 +118,7 @@ export class Vpc extends Construct {
// All HTTP VPC traffic -> ECS model ALB
ecsModelAlbSg.addIngressRule(Peer.ipv4(vpc.vpcCidrBlock), Port.tcp(80), 'Allow VPC traffic on port 80');

if (config.restApiConfig.loadBalancerConfig.sslCertIamArn) {
// All HTTPS IPV4 traffic -> REST API ALB
restApiAlbSg.addIngressRule(Peer.anyIpv4(), Port.tcp(443), 'Allow any traffic on port 443');
} else {
// All HTTP VPC traffic -> REST API ALB
restApiAlbSg.addIngressRule(Peer.ipv4(vpc.vpcCidrBlock), Port.tcp(80), 'Allow VPC traffic on port 80');
}
restApiAlbSg.addIngressRule(Peer.ipv4(vpc.vpcCidrBlock), Port.tcp(80), 'Allow VPC traffic on port 80');

// Update
this.vpc = vpc;
Expand Down
7 changes: 7 additions & 0 deletions lib/serve/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,15 @@ import { FastApiContainer } from '../api-base/fastApiContainer';
import { createCdkId } from '../core/utils';
import { Vpc } from '../networking/vpc';
import { BaseProps } from '../schema';
import { IAuthorizer } from 'aws-cdk-lib/aws-apigateway';

const HERE = path.resolve(__dirname);

/**
 * Stack-specific properties for the serve stack, combined with the shared BaseProps.
 *
 * Besides the VPC construct, this carries the API Gateway pieces (authorizer,
 * REST API id and root resource id) that the stack forwards to FastApiContainer.
 */
type CustomLisaStackProps = {
// Vpc construct imported from '../networking/vpc'.
vpc: Vpc;
// API Gateway authorizer, forwarded to FastApiContainer.
authorizer: IAuthorizer;
// Id of the existing REST API, forwarded to FastApiContainer.
restApiId: string;
// Root resource id of that REST API, forwarded to FastApiContainer.
rootResourceId: string;
} & BaseProps;
/** The custom properties above combined with the CDK StackProps. */
type LisaStackProps = CustomLisaStackProps & StackProps;

Expand Down Expand Up @@ -73,6 +77,9 @@ export class LisaServeApplicationStack extends Stack {

// Create REST API
const restApi = new FastApiContainer(this, 'RestApi', {
authorizer: props.authorizer,
restApiId: props.restApiId,
rootResourceId: props.rootResourceId,
apiName: 'REST',
config: config,
resourcePath: path.join(HERE, 'rest-api'),
Expand Down
22 changes: 13 additions & 9 deletions lib/stages.ts
Original file line number Diff line number Diff line change
Expand Up @@ -128,24 +128,28 @@ export class LisaServeApplicationStage extends Stage {
});
stacks.push(coreStack);

const apiBaseStack = new LisaApiBaseStack(this, 'LisaApiBase', {
...baseStackProps,
stackName: createCdkId([config.deploymentName, config.appName, 'API']),
description: `LISA-API: ${config.deploymentName}-${config.deploymentStage}`,
vpc: networkingStack.vpc.vpc,
});
apiBaseStack.addDependency(coreStack);
stacks.push(apiBaseStack);

const serveStack = new LisaServeApplicationStack(this, 'LisaServe', {
...baseStackProps,
authorizer: apiBaseStack.authorizer,
restApiId: apiBaseStack.restApiId,
rootResourceId: apiBaseStack.rootResourceId,
description: `LISA-serve: ${config.deploymentName}-${config.deploymentStage}`,
stackName: createCdkId([config.deploymentName, config.appName, 'serve', config.deploymentStage]),
vpc: networkingStack.vpc,
});
stacks.push(serveStack);

serveStack.addDependency(iamStack);

const apiBaseStack = new LisaApiBaseStack(this, 'LisaApiBase', {
...baseStackProps,
stackName: createCdkId([config.deploymentName, config.appName, 'API']),
description: `LISA-API: ${config.deploymentName}-${config.deploymentStage}`,
vpc: networkingStack.vpc.vpc,
});
apiBaseStack.addDependency(coreStack);
stacks.push(apiBaseStack);
serveStack.addDependency(apiBaseStack);

const apiDeploymentStack = new LisaApiDeploymentStack(this, 'LisaApiDeployment', {
...baseStackProps,
Expand Down
8 changes: 1 addition & 7 deletions lib/user-interface/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,8 @@ import { ManagedPolicy, Role, ServicePrincipal } from 'aws-cdk-lib/aws-iam';
import { Architecture, Runtime } from 'aws-cdk-lib/aws-lambda';
import { BlockPublicAccess, Bucket, BucketEncryption } from 'aws-cdk-lib/aws-s3';
import { BucketDeployment, Source } from 'aws-cdk-lib/aws-s3-deployment';
import { StringParameter } from 'aws-cdk-lib/aws-ssm';
import { Construct } from 'constructs';

import { createCdkId } from '../core/utils';
import { BaseProps } from '../schema';

/**
Expand Down Expand Up @@ -187,11 +185,7 @@ export class UserInterfaceStack extends Stack {
ADMIN_GROUP: config.authConfig!.adminGroup,
JWT_GROUPS_PROP: config.authConfig!.jwtGroupsProperty,
CUSTOM_SCOPES: config.authConfig!.additionalScopes,
RESTAPI_URI: StringParameter.fromStringParameterName(
estohlmann marked this conversation as resolved.
Show resolved Hide resolved
this,
createCdkId(['LisaRestApiUri', 'StringParameter']),
`${config.deploymentPrefix}/lisaServeRestApiUri`,
).stringValue,
RESTAPI_ID: config.apiGatewayConfig?.domainName ? '/llm' : `/${config.deploymentStage}/llm`,
RESTAPI_VERSION: config.restApiConfig.apiVersion,
RAG_ENABLED: config.deployRag,
SYSTEM_BANNER: {
Expand Down
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"@aws-cdk/aws-lambda-python-alpha": "2.125.0-alpha.0",
"@aws-sdk/client-iam": "^3.490.0",
"@cdklabs/cdk-enterprise-iac": "^0.0.400",
"@stylistic/eslint-plugin": "^2.6.4",
"@stylistic/eslint-plugin": "^2.7.2",
"@types/jest": "^29.5.12",
"@types/js-yaml": "^4.0.5",
"@types/node": "20.5.3",
Expand Down
Loading