diff --git a/upload-ami/.envrc b/upload-ami/.envrc new file mode 100644 index 0000000..ae7c0f2 --- /dev/null +++ b/upload-ami/.envrc @@ -0,0 +1,2 @@ +#!/bin/sh +use flake .#upload-ami diff --git a/upload-ami/default.nix b/upload-ami/default.nix index 95635a1..c4f16ae 100644 --- a/upload-ami/default.nix +++ b/upload-ami/default.nix @@ -3,19 +3,49 @@ , lib }: -let pyproject = builtins.fromTOML (builtins.readFile ./pyproject.toml); +let + pyproject = builtins.fromTOML (builtins.readFile ./pyproject.toml); + # str -> { name: str, extras: [str] } + parseDependency = dep: + let + parts = lib.splitString "[" dep; + name = lib.head parts; + extras = lib.optionals (lib.length parts > 1) + (lib.splitString "," (lib.removeSuffix "]" (builtins.elemAt parts 1))); + in + { name = name; extras = extras; }; + + # { name: str, extras: [str] } -> [package] + resolvePackages = dep: + let + inherit (parseDependency dep) name extras; + package = python3Packages.${name}; + optionalPackages = lib.flatten (map (name: package.optional-dependencies.${name}) extras); + in + [ package ] ++ optionalPackages; + + in buildPythonApplication { pname = pyproject.project.name; version = pyproject.project.version; src = ./.; pyproject = true; - nativeBuildInputs = - map (name: python3Packages.${name}) pyproject.build-system.requires; + map (name: python3Packages.${name}) pyproject.build-system.requires ++ [ + python3Packages.mypy + python3Packages.black + ]; + + propagatedBuildInputs = lib.flatten (map resolvePackages pyproject.project.dependencies); - propagatedBuildInputs = - map (name: python3Packages.${name}) pyproject.project.dependencies; + checkPhase = '' + mypy src + black --check src + ''; passthru.pyproject = pyproject; + passthru.parseDependency = parseDependency; + passthru.resolvePackages = resolvePackages; + } diff --git a/upload-ami/pyproject.toml b/upload-ami/pyproject.toml index 78ca3be..c483e56 100644 --- a/upload-ami/pyproject.toml +++ b/upload-ami/pyproject.toml @@ -6,9 +6,8 @@ 
name = "upload-ami" version = "0.1.0" dependencies = [ "boto3", - "botocore", - "mypy-boto3-ec2", - "mypy-boto3-s3", + "boto3-stubs[ec2,s3,sts,account,service-quotas]", + "botocore-stubs", ] [project.scripts] upload-ami = "upload_ami.upload_ami:main" @@ -18,3 +17,5 @@ disable-image-block-public-access = "upload_ami.disable_image_block_public_acces enable-regions = "upload_ami.enable_regions:main" request-public-ami-quota-increase = "upload_ami.request_public_ami_quota_increase:main" describe-images = "upload_ami.describe_images:main" +[tool.mypy] +strict=true diff --git a/upload-ami/src/upload_ami/describe_images.py b/upload-ami/src/upload_ami/describe_images.py index 9d90c99..cd28e5c 100644 --- a/upload-ami/src/upload_ami/describe_images.py +++ b/upload-ami/src/upload_ami/describe_images.py @@ -2,14 +2,18 @@ import boto3 import json -def main(): +from mypy_boto3_ec2 import EC2Client + + +def main() -> None: logging.basicConfig(level=logging.INFO) - ec2 = boto3.client("ec2") + ec2: EC2Client = boto3.client("ec2") regions = ec2.describe_regions()["Regions"] images = {} for region in regions: + assert "RegionName" in region ec2r = boto3.client("ec2", region_name=region["RegionName"]) result = ec2r.describe_images( @@ -17,8 +21,9 @@ def main(): ExecutableUsers=["all"], ) images[region["RegionName"]] = result - + print(json.dumps(images, indent=2)) + if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/upload-ami/src/upload_ami/disable_image_block_public_access.py b/upload-ami/src/upload_ami/disable_image_block_public_access.py index 3d99a3e..e8295de 100644 --- a/upload-ami/src/upload_ami/disable_image_block_public_access.py +++ b/upload-ami/src/upload_ami/disable_image_block_public_access.py @@ -2,25 +2,35 @@ import logging import time from concurrent.futures import ThreadPoolExecutor +from mypy_boto3_ec2.type_defs import RegionTypeDef -def main(): + +def main() -> None: logging.basicConfig(level=logging.INFO) ec2 = boto3.client("ec2") 
regions = ec2.describe_regions()["Regions"] - def disable_image_block_public_access(region): - logging.info("disabling image block public access in %s. Can take some minutes to apply", region["RegionName"]) + def disable_image_block_public_access(region: RegionTypeDef) -> None: + assert "RegionName" in region + logging.info( + "disabling image block public access in %s. Can take some minutes to apply", + region["RegionName"], + ) ec2 = boto3.client("ec2", region_name=region["RegionName"]) ec2.disable_image_block_public_access() while True: - state = ec2.get_image_block_public_access_state()["ImageBlockPublicAccessState"] + state = ec2.get_image_block_public_access_state()[ + "ImageBlockPublicAccessState" + ] if state == "unblocked": break - logging.info("waiting for image block public access state %s to be unblocked in %s", state, region["RegionName"]) + logging.info( + "waiting for image block public access state %s to be unblocked in %s", + state, + region["RegionName"], + ) time.sleep(30) - + with ThreadPoolExecutor(max_workers=len(regions)) as executor: executor.map(disable_image_block_public_access, regions) - - diff --git a/upload-ami/src/upload_ami/enable_regions.py b/upload-ami/src/upload_ami/enable_regions.py index 27d64e9..7badf22 100644 --- a/upload-ami/src/upload_ami/enable_regions.py +++ b/upload-ami/src/upload_ami/enable_regions.py @@ -1,19 +1,21 @@ import boto3 +from mypy_boto3_account import AccountClient import logging -def main(): +def main() -> None: """ Enable all regions that are disabled Due to rate limiting, you might need to run this multiple times. 
""" logging.basicConfig(level=logging.INFO) - account = boto3.client("account") + account: AccountClient = boto3.client("account") pages = account.get_paginator("list_regions").paginate( RegionOptStatusContains=["DISABLED"] ) for page in pages: for region in page["Regions"]: + assert "RegionName" in region logging.info(f"enabling region {region['RegionName']}") account.enable_region(RegionName=region["RegionName"]) diff --git a/upload-ami/src/upload_ami/nuke.py b/upload-ami/src/upload_ami/nuke.py index 1d0e464..2fd53ea 100644 --- a/upload-ami/src/upload_ami/nuke.py +++ b/upload-ami/src/upload_ami/nuke.py @@ -1,24 +1,33 @@ import logging import boto3 +from mypy_boto3_ec2 import EC2Client -def main(): +def main() -> None: logging.basicConfig(level=logging.INFO) - ec2 = boto3.client("ec2", region_name="us-east-1") + ec2: EC2Client = boto3.client("ec2", region_name="us-east-1") regions = ec2.describe_regions()["Regions"] for region in regions: + assert "RegionName" in region ec2r = boto3.client("ec2", region_name=region["RegionName"]) logging.info(f"Nuking {region['RegionName']}") snapshots = ec2r.describe_snapshots(OwnerIds=["self"]) for snapshot in snapshots["Snapshots"]: + assert "SnapshotId" in snapshot images = ec2r.describe_images( Owners=["self"], - Filters=[{"Name": "block-device-mapping.snapshot-id", "Values": [snapshot["SnapshotId"]]}], + Filters=[ + { + "Name": "block-device-mapping.snapshot-id", + "Values": [snapshot["SnapshotId"]], + } + ], ) for image in images["Images"]: + assert "ImageId" in image logging.info(f"Deregistering {image['ImageId']}") ec2r.deregister_image(ImageId=image["ImageId"]) diff --git a/upload-ami/src/upload_ami/request_public_ami_quota_increase.py b/upload-ami/src/upload_ami/request_public_ami_quota_increase.py index 4be79c7..457bdc9 100644 --- a/upload-ami/src/upload_ami/request_public_ami_quota_increase.py +++ b/upload-ami/src/upload_ami/request_public_ami_quota_increase.py @@ -1,33 +1,54 @@ +from ast import List +from typing 
import Iterator import boto3 import logging +from mypy_boto3_ec2 import EC2Client +from mypy_boto3_service_quotas import ServiceQuotasClient +from mypy_boto3_service_quotas.type_defs import ( + ListServiceQuotasResponseTypeDef, + ServiceQuotaTypeDef, +) -def get_public_ami_service_quota(servicequotas): - return next(servicequotas - .get_paginator('list_service_quotas') - .paginate(ServiceCode="ec2") - .search("Quotas[?QuotaName=='Public AMIs']")) +def get_public_ami_service_quota( + servicequotas: ServiceQuotasClient, +) -> ServiceQuotaTypeDef: + paginator = servicequotas.get_paginator("list_service_quotas") + searched: Iterator[ServiceQuotaTypeDef] = paginator.paginate( + ServiceCode="ec2" + ).search("Quotas[?QuotaName=='Public AMIs']") + return next(searched) -def main(): + +def main() -> None: import argparse + parser = argparse.ArgumentParser() parser.add_argument("--desired-value", type=int, default=1000) args = parser.parse_args() logging.basicConfig(level=logging.INFO) - ec2 = boto3.client("ec2") + ec2: EC2Client = boto3.client("ec2") regions = ec2.describe_regions()["Regions"] for region in regions: - servicequotas = boto3.client( - "service-quotas", region_name=region["RegionName"]) + assert "RegionName" in region + servicequotas: ServiceQuotasClient = boto3.client( + "service-quotas", region_name=region["RegionName"] + ) service_quota = get_public_ami_service_quota(servicequotas) - logging.info( - f"Quota for {region['RegionName']} is {service_quota['Value']}") + + assert "Value" in service_quota + logging.info(f"Quota for {region['RegionName']} is {service_quota['Value']}") try: - if service_quota['Value'] < args.desired_value: + if service_quota["Value"] < args.desired_value: logging.info( - f"Requesting quota increase for {region['RegionName']} from {service_quota['Value']} to {args.desired_value}") - servicequotas.request_service_quota_increase( ServiceCode="ec2", QuotaCode=service_quota['QuotaCode'], DesiredValue=args.desired_value) + f"Requesting 
quota increase for {region['RegionName']} from {service_quota['Value']} to {args.desired_value}" + ) + servicequotas.request_service_quota_increase( + ServiceCode="ec2", + QuotaCode=service_quota["QuotaCode"], + DesiredValue=args.desired_value, + ) except Exception as e: logging.warn(e) diff --git a/upload-ami/src/upload_ami/smoke_test.py b/upload-ami/src/upload_ami/smoke_test.py index c07d697..c42101b 100644 --- a/upload-ami/src/upload_ami/smoke_test.py +++ b/upload-ami/src/upload_ami/smoke_test.py @@ -3,14 +3,19 @@ import argparse import logging +from mypy_boto3_ec2 import EC2Client +from mypy_boto3_ec2.literals import InstanceTypeType -def smoke_test(image_id, run_id, cancel): - ec2 = boto3.client("ec2") + +def smoke_test(image_id: str, run_id: str, cancel: bool) -> None: + ec2: EC2Client = boto3.client("ec2") images = ec2.describe_images(Owners=["self"], ImageIds=[image_id]) assert len(images["Images"]) == 1 image = images["Images"][0] + assert "Architecture" in image architecture = image["Architecture"] + instance_type: InstanceTypeType if architecture == "x86_64": instance_type = "t3.nano" elif architecture == "arm64": @@ -29,6 +34,7 @@ def smoke_test(image_id, run_id, cancel): ) instance = run_instances["Instances"][0] + assert "InstanceId" in instance instance_id = instance["InstanceId"] try: @@ -46,7 +52,9 @@ def smoke_test(image_id, run_id, cancel): logging.info( f"Waiting for console output to become available ({tries} tries left)" ) - console_output = ec2.get_console_output(InstanceId=instance_id, Latest=True) + console_output = ec2.get_console_output( + InstanceId=instance_id, Latest=True + ) output = console_output.get("Output") tries -= 1 logging.info(f"Console output: {output}") @@ -55,12 +63,14 @@ def smoke_test(image_id, run_id, cancel): raise finally: logging.info(f"Terminating instance {instance_id}") + assert "State" in instance + assert "Name" in instance["State"] if instance["State"]["Name"] != "terminated": 
ec2.terminate_instances(InstanceIds=[instance_id]) ec2.get_waiter("instance_terminated").wait(InstanceIds=[instance_id]) -def main(): +def main() -> None: logging.basicConfig(level=logging.INFO) parser = argparse.ArgumentParser() diff --git a/upload-ami/src/upload_ami/upload_ami.py b/upload-ami/src/upload_ami/upload_ami.py index 6890820..ac53a0b 100644 --- a/upload-ami/src/upload_ami/upload_ami.py +++ b/upload-ami/src/upload_ami/upload_ami.py @@ -1,14 +1,28 @@ import json import hashlib import logging -import os +from typing import Iterable, Literal, TypedDict import boto3 import botocore import botocore.exceptions + +from mypy_boto3_ec2.client import EC2Client +from mypy_boto3_ec2.literals import BootModeValuesType +from mypy_boto3_ec2.type_defs import RegionTypeDef +from mypy_boto3_s3.client import S3Client + from concurrent.futures import ThreadPoolExecutor -def upload_to_s3_if_not_exists(s3, bucket, key, file): +class ImageInfo(TypedDict): + file: str + label: str + system: str + boot_mode: BootModeValuesType + format: str + + +def upload_to_s3_if_not_exists(s3: S3Client, bucket: str, key: str, file: str) -> None: """ Upload file to S3 if it doesn't exist yet @@ -17,13 +33,15 @@ def upload_to_s3_if_not_exists(s3, bucket, key, file): try: logging.info(f"Checking if s3://{bucket}/{key} exists") s3.head_object(Bucket=bucket, Key=key) - except botocore.exceptions.ClientError as e: + except botocore.exceptions.ClientError: logging.info(f"Uploading {file} to s3://{bucket}/{key}") s3.upload_file(file, bucket, key) s3.get_waiter("object_exists").wait(Bucket=bucket, Key=key) -def import_snapshot(ec2, s3_bucket, s3_key, image_format): +def import_snapshot( + ec2: EC2Client, s3_bucket: str, s3_key: str, image_format: str +) -> str: """ Import snapshot from S3 and wait for it to finish @@ -55,11 +73,19 @@ ImportTaskIds=[snapshot_import_task["ImportTaskId"]] ) assert 
len(snapshot_import_tasks["ImportSnapshotTasks"]) != 0 - snapshot_import_task = snapshot_import_tasks["ImportSnapshotTasks"][0] - return snapshot_import_task["SnapshotTaskDetail"]["SnapshotId"] - - -def register_image_if_not_exists(ec2, image_name, image_info, snapshot_id, public): + snapshot_import_task_2 = snapshot_import_tasks["ImportSnapshotTasks"][0] + assert "SnapshotTaskDetail" in snapshot_import_task_2 + assert "SnapshotId" in snapshot_import_task_2["SnapshotTaskDetail"] + return snapshot_import_task_2["SnapshotTaskDetail"]["SnapshotId"] + + +def register_image_if_not_exists( + ec2: EC2Client, + image_name: str, + image_info: ImageInfo, + snapshot_id: str, + public: bool, +) -> str: """ Register image if it doesn't exist yet @@ -69,8 +95,12 @@ def register_image_if_not_exists(ec2, image_name, image_info, snapshot_id, publi Owners=["self"], Filters=[{"Name": "name", "Values": [image_name]}] ) if len(describe_images["Images"]) != 0: + assert len(describe_images["Images"]) == 1 + assert "ImageId" in describe_images["Images"][0] image_id = describe_images["Images"][0]["ImageId"] else: + architecture: Literal["x86_64", "arm64"] + assert "system" in image_info if image_info["system"] == "x86_64-linux": architecture = "x86_64" elif image_info["system"] == "aarch64-linux": @@ -78,9 +108,7 @@ def register_image_if_not_exists(ec2, image_name, image_info, snapshot_id, publi else: raise Exception("Unknown system: " + image_info["system"]) - logging.info( - f"Registering image {image_name} with snapshot {snapshot_id}") - tpmsupport = {} + logging.info(f"Registering image {image_name} with snapshot {snapshot_id}") # TODO(arianvp): Not all instance types support TPM 2.0 yet. We should # upload two images, one with and one without TPM 2.0 support. 
@@ -106,7 +134,6 @@ def register_image_if_not_exists(ec2, image_name, image_info, snapshot_id, publi EnaSupport=True, ImdsSupport="v2.0", SriovNetSupport="simple", - **tpmsupport ) image_id = register_image["ImageId"] @@ -121,7 +148,13 @@ def register_image_if_not_exists(ec2, image_name, image_info, snapshot_id, publi return image_id -def copy_image_to_regions(image_id, image_name, source_region, target_regions, public): +def copy_image_to_regions( + image_id: str, + image_name: str, + source_region: str, + target_regions: Iterable[RegionTypeDef], + public: bool, +) -> dict[str, str]: """ Copy image to all target regions @@ -131,7 +164,9 @@ def copy_image_to_regions(image_id, image_name, source_region, target_regions, p as the client_token for the copy_image task """ - def copy_image(image_id, image_name, source_region, target_region_name): + def copy_image( + image_id: str, image_name: str, source_region: str, target_region_name: str + ) -> tuple[str, str]: """ Copy image to target_region @@ -142,7 +177,7 @@ def copy_image(image_id, image_name, source_region, target_region_name): script a few months later? 
""" - ec2r = boto3.client("ec2", region_name=target_region_name) + ec2r: EC2Client = boto3.client("ec2", region_name=target_region_name) logging.info( f"Copying image {image_id} from {source_region} to {target_region_name}" ) @@ -152,8 +187,7 @@ def copy_image(image_id, image_name, source_region, target_region_name): Name=image_name, ClientToken=image_id, ) - ec2r.get_waiter("image_available").wait( - ImageIds=[copy_image["ImageId"]]) + ec2r.get_waiter("image_available").wait(ImageIds=[copy_image["ImageId"]]) logging.info( f"Finished image {image_id} from {source_region} to {target_region_name} {copy_image['ImageId']}" ) @@ -167,27 +201,35 @@ def copy_image(image_id, image_name, source_region, target_region_name): return (target_region_name, copy_image["ImageId"]) with ThreadPoolExecutor(max_workers=32) as executor: - image_ids = dict( - executor.map( - lambda target_region: copy_image( - image_id, image_name, source_region, target_region["RegionName"] - ), - target_regions, + + def _copy_image(target_region: RegionTypeDef) -> tuple[str, str]: + assert "RegionName" in target_region + return copy_image( + image_id, image_name, source_region, target_region["RegionName"] ) - ) + + image_ids = dict(executor.map(_copy_image, target_regions)) image_ids[source_region] = image_id return image_ids -def upload_ami(image_info, s3_bucket, copy_to_regions, prefix, run_id, public): +def upload_ami( + image_info: ImageInfo, + s3_bucket: str, + copy_to_regions: bool, + prefix: str, + run_id: str, + public: bool, +) -> dict[str, str]: """ Upload NixOS AMI to AWS and return the image ids for each region This function is idempotent because all the functions it calls are idempotent. 
""" - ec2 = boto3.client("ec2") - s3 = boto3.client("s3") + + ec2: EC2Client = boto3.client("ec2") + s3: S3Client = boto3.client("s3") image_file = image_info["file"] label = image_info["label"] @@ -200,31 +242,33 @@ def upload_ami(image_info, s3_bucket, copy_to_regions, prefix, run_id, public): snapshot_id = import_snapshot(ec2, s3_bucket, s3_key, image_format) image_id = register_image_if_not_exists( - ec2, image_name, image_info, snapshot_id, public) + ec2, image_name, image_info, snapshot_id, public + ) - regions = filter(lambda x: x["RegionName"] != - ec2.meta.region_name, ec2.describe_regions()["Regions"]) + regions = filter( + lambda x: x.get("RegionName") != ec2.meta.region_name, + ec2.describe_regions()["Regions"], + ) - image_ids = {} + image_ids: dict[str, str] = {} image_ids[ec2.meta.region_name] = image_id if copy_to_regions: image_ids.update( - copy_image_to_regions(image_id, image_name, - ec2.meta.region_name, regions, public) + copy_image_to_regions( + image_id, image_name, ec2.meta.region_name, regions, public + ) ) - + return image_ids -def main(): +def main() -> None: import argparse parser = argparse.ArgumentParser(description="Upload NixOS AMI to AWS") - parser.add_argument( - "--image-info", help="Path to image info", required=True) - parser.add_argument( - "--s3-bucket", help="S3 bucket to upload to", required=True) + parser.add_argument("--image-info", help="Path to image info", required=True) + parser.add_argument("--s3-bucket", help="S3 bucket to upload to", required=True) parser.add_argument("--debug", action="store_true") parser.add_argument("--cleanup", action="store_true") parser.add_argument("--copy-to-regions", action="store_true") @@ -236,15 +280,17 @@ def main(): level = logging.DEBUG if args.debug else logging.INFO logging.basicConfig(level=level) - sts = boto3.client("sts") - logging.info(sts.get_caller_identity()) - with open(args.image_info, "r") as f: image_info = json.load(f) image_ids = {} image_ids = upload_ami( - 
image_info, args.s3_bucket, args.copy_to_regions, args.prefix, args.run_id, args.public + image_info, + args.s3_bucket, + args.copy_to_regions, + args.prefix, + args.run_id, + args.public, ) print(json.dumps(image_ids))