Skip to content

Commit

Permalink
Mypy (#120)
Browse files Browse the repository at this point in the history
* Add mypy and black

* stubs

* Type upload-ami

* Fix all type errors

* Check mypy and black

* Fix
  • Loading branch information
arianvp authored Apr 19, 2024
1 parent 7d37252 commit e0faf6e
Show file tree
Hide file tree
Showing 10 changed files with 223 additions and 87 deletions.
2 changes: 2 additions & 0 deletions upload-ami/.envrc
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#!/bin/sh
use flake .#upload-ami
40 changes: 35 additions & 5 deletions upload-ami/default.nix
Original file line number Diff line number Diff line change
Expand Up @@ -3,19 +3,49 @@
, lib
}:

let pyproject = builtins.fromTOML (builtins.readFile ./pyproject.toml);
let
pyproject = builtins.fromTOML (builtins.readFile ./pyproject.toml);
# str -> { name: str, extras: [str] }
parseDependency = dep:
let
parts = lib.splitString "[" dep;
name = lib.head parts;
extras = lib.optionals (lib.length parts > 1)
(lib.splitString "," (lib.removeSuffix "]" (builtins.elemAt parts 1)));
in
{ name = name; extras = extras; };

# { name: str, extras: [str] } -> [package]
resolvePackages = dep:
let
inherit (parseDependency dep) name extras;
package = python3Packages.${name};
optionalPackages = lib.flatten (map (name: package.optional-dependencies.${name}) extras);
in
[ package ] ++ optionalPackages;


in
buildPythonApplication {
pname = pyproject.project.name;
version = pyproject.project.version;
src = ./.;
pyproject = true;

nativeBuildInputs =
map (name: python3Packages.${name}) pyproject.build-system.requires;
map (name: python3Packages.${name}) pyproject.build-system.requires ++ [
python3Packages.mypy
python3Packages.black
];

propagatedBuildInputs = lib.flatten (map resolvePackages pyproject.project.dependencies);

propagatedBuildInputs =
map (name: python3Packages.${name}) pyproject.project.dependencies;
checkPhase = ''
mypy src
black --check src
'';

passthru.pyproject = pyproject;
passthru.parseDependency = parseDependency;
passthru.resolvePackages = resolvePackages;

}
7 changes: 4 additions & 3 deletions upload-ami/pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ name = "upload-ami"
version = "0.1.0"
dependencies = [
"boto3",
"botocore",
"mypy-boto3-ec2",
"mypy-boto3-s3",
"boto3-stubs[ec2,s3,sts,account,service-quotas]",
"botocore-stubs",
]
[project.scripts]
upload-ami = "upload_ami.upload_ami:main"
Expand All @@ -18,3 +17,5 @@ disable-image-block-public-access = "upload_ami.disable_image_block_public_acces
enable-regions = "upload_ami.enable_regions:main"
request-public-ami-quota-increase = "upload_ami.request_public_ami_quota_increase:main"
describe-images = "upload_ami.describe_images:main"
[tool.mypy]
strict=true
13 changes: 9 additions & 4 deletions upload-ami/src/upload_ami/describe_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,28 @@
import boto3
import json

def main():
from mypy_boto3_ec2 import EC2Client


def main() -> None:
logging.basicConfig(level=logging.INFO)
ec2 = boto3.client("ec2")
ec2: EC2Client = boto3.client("ec2")
regions = ec2.describe_regions()["Regions"]

images = {}

for region in regions:
assert "RegionName" in region
ec2r = boto3.client("ec2", region_name=region["RegionName"])

result = ec2r.describe_images(
Owners=["self"],
ExecutableUsers=["all"],
)
images[region["RegionName"]] = result

print(json.dumps(images, indent=2))


if __name__ == "__main__":
main()
main()
26 changes: 18 additions & 8 deletions upload-ami/src/upload_ami/disable_image_block_public_access.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,35 @@
import logging
import time
from concurrent.futures import ThreadPoolExecutor
from mypy_boto3_ec2.type_defs import RegionTypeDef

def main():

def main() -> None:
logging.basicConfig(level=logging.INFO)
ec2 = boto3.client("ec2")
regions = ec2.describe_regions()["Regions"]

def disable_image_block_public_access(region):
logging.info("disabling image block public access in %s. Can take some minutes to apply", region["RegionName"])
def disable_image_block_public_access(region: RegionTypeDef) -> None:
assert "RegionName" in region
logging.info(
"disabling image block public access in %s. Can take some minutes to apply",
region["RegionName"],
)
ec2 = boto3.client("ec2", region_name=region["RegionName"])
ec2.disable_image_block_public_access()

while True:
state = ec2.get_image_block_public_access_state()["ImageBlockPublicAccessState"]
state = ec2.get_image_block_public_access_state()[
"ImageBlockPublicAccessState"
]
if state == "unblocked":
break
logging.info("waiting for image block public access state %s to be unblocked in %s", state, region["RegionName"])
logging.info(
"waiting for image block public access state %s to be unblocked in %s",
state,
region["RegionName"],
)
time.sleep(30)

with ThreadPoolExecutor(max_workers=len(regions)) as executor:
executor.map(disable_image_block_public_access, regions)


6 changes: 4 additions & 2 deletions upload-ami/src/upload_ami/enable_regions.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import boto3
from mypy_boto3_account import AccountClient
import logging


def main():
def main() -> None:
"""
Enable all regions that are disabled
Due to rate limiting, you might need to run this multiple times.
"""
logging.basicConfig(level=logging.INFO)
account = boto3.client("account")
account: AccountClient = boto3.client("account")
pages = account.get_paginator("list_regions").paginate(
RegionOptStatusContains=["DISABLED"]
)
for page in pages:
for region in page["Regions"]:
assert "RegionName" in region
logging.info(f"enabling region {region['RegionName']}")
account.enable_region(RegionName=region["RegionName"])
15 changes: 12 additions & 3 deletions upload-ami/src/upload_ami/nuke.py
Original file line number Diff line number Diff line change
@@ -1,24 +1,33 @@
import logging
import boto3
from mypy_boto3_ec2 import EC2Client


def main():
def main() -> None:
logging.basicConfig(level=logging.INFO)
ec2 = boto3.client("ec2", region_name="us-east-1")
ec2: EC2Client = boto3.client("ec2", region_name="us-east-1")

regions = ec2.describe_regions()["Regions"]

for region in regions:
assert "RegionName" in region
ec2r = boto3.client("ec2", region_name=region["RegionName"])
logging.info(f"Nuking {region['RegionName']}")
snapshots = ec2r.describe_snapshots(OwnerIds=["self"])
for snapshot in snapshots["Snapshots"]:

assert "SnapshotId" in snapshot
images = ec2r.describe_images(
Owners=["self"],
Filters=[{"Name": "block-device-mapping.snapshot-id", "Values": [snapshot["SnapshotId"]]}],
Filters=[
{
"Name": "block-device-mapping.snapshot-id",
"Values": [snapshot["SnapshotId"]],
}
],
)
for image in images["Images"]:
assert "ImageId" in image
logging.info(f"Deregistering {image['ImageId']}")
ec2r.deregister_image(ImageId=image["ImageId"])

Expand Down
49 changes: 35 additions & 14 deletions upload-ami/src/upload_ami/request_public_ami_quota_increase.py
Original file line number Diff line number Diff line change
@@ -1,33 +1,54 @@
from ast import List
from typing import Iterator
import boto3
import logging

from mypy_boto3_ec2 import EC2Client
from mypy_boto3_service_quotas import ServiceQuotasClient
from mypy_boto3_service_quotas.type_defs import (
ListServiceQuotasResponseTypeDef,
ServiceQuotaTypeDef,
)

def get_public_ami_service_quota(servicequotas):
return next(servicequotas
.get_paginator('list_service_quotas')
.paginate(ServiceCode="ec2")
.search("Quotas[?QuotaName=='Public AMIs']"))

def get_public_ami_service_quota(
servicequotas: ServiceQuotasClient,
) -> ServiceQuotaTypeDef:
paginator = servicequotas.get_paginator("list_service_quotas")
searched: Iterator[ServiceQuotaTypeDef] = paginator.paginate(
ServiceCode="ec2"
).search("Quotas[?QuotaName=='Public AMIs']")
return next(searched)

def main():

def main() -> None:
import argparse

parser = argparse.ArgumentParser()
parser.add_argument("--desired-value", type=int, default=1000)
args = parser.parse_args()
logging.basicConfig(level=logging.INFO)
ec2 = boto3.client("ec2")
ec2: EC2Client = boto3.client("ec2")
regions = ec2.describe_regions()["Regions"]
for region in regions:
servicequotas = boto3.client(
"service-quotas", region_name=region["RegionName"])
assert "RegionName" in region
servicequotas: ServiceQuotasClient = boto3.client(
"service-quotas", region_name=region["RegionName"]
)
service_quota = get_public_ami_service_quota(servicequotas)
logging.info(
f"Quota for {region['RegionName']} is {service_quota['Value']}")

assert "Value" in service_quota
logging.info(f"Quota for {region['RegionName']} is {service_quota['Value']}")
try:
if service_quota['Value'] < args.desired_value:
if service_quota["Value"] < args.desired_value:
logging.info(
f"Requesting quota increase for {region['RegionName']} from {service_quota['Value']} to {args.desired_value}")
servicequotas.request_service_quota_increase( ServiceCode="ec2", QuotaCode=service_quota['QuotaCode'], DesiredValue=args.desired_value)
f"Requesting quota increase for {region['RegionName']} from {service_quota['Value']} to {args.desired_value}"
)
servicequotas.request_service_quota_increase(
ServiceCode="ec2",
QuotaCode=service_quota["QuotaCode"],
DesiredValue=args.desired_value,
)
except Exception as e:
logging.warn(e)

Expand Down
18 changes: 14 additions & 4 deletions upload-ami/src/upload_ami/smoke_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,19 @@
import argparse
import logging

from mypy_boto3_ec2 import EC2Client
from mypy_boto3_ec2.literals import InstanceTypeType

def smoke_test(image_id, run_id, cancel):
ec2 = boto3.client("ec2")

def smoke_test(image_id: str, run_id: str, cancel: bool) -> None:
ec2: EC2Client = boto3.client("ec2")

images = ec2.describe_images(Owners=["self"], ImageIds=[image_id])
assert len(images["Images"]) == 1
image = images["Images"][0]
assert "Architecture" in image
architecture = image["Architecture"]
instance_type: InstanceTypeType
if architecture == "x86_64":
instance_type = "t3.nano"
elif architecture == "arm64":
Expand All @@ -29,6 +34,7 @@ def smoke_test(image_id, run_id, cancel):
)

instance = run_instances["Instances"][0]
assert "InstanceId" in instance
instance_id = instance["InstanceId"]

try:
Expand All @@ -46,7 +52,9 @@ def smoke_test(image_id, run_id, cancel):
logging.info(
f"Waiting for console output to become available ({tries} tries left)"
)
console_output = ec2.get_console_output(InstanceId=instance_id, Latest=True)
console_output = ec2.get_console_output(
InstanceId=instance_id, Latest=True
)
output = console_output.get("Output")
tries -= 1
logging.info(f"Console output: {output}")
Expand All @@ -55,12 +63,14 @@ def smoke_test(image_id, run_id, cancel):
raise
finally:
logging.info(f"Terminating instance {instance_id}")
assert "State" in instance
assert "Name" in instance["State"]
if instance["State"]["Name"] != "terminated":
ec2.terminate_instances(InstanceIds=[instance_id])
ec2.get_waiter("instance_terminated").wait(InstanceIds=[instance_id])


def main():
def main() -> None:
logging.basicConfig(level=logging.INFO)

parser = argparse.ArgumentParser()
Expand Down
Loading

0 comments on commit e0faf6e

Please sign in to comment.