correct role handling #2

Draft: wants to merge 1 commit into base: master
110 changes: 69 additions & 41 deletions run.py
@@ -60,14 +60,45 @@
# AUXILIARY FUNCTIONS
#################################

def get_aws_credentials(AWS_PROFILE):
session = boto3.Session(profile_name=AWS_PROFILE)
credentials = session.get_credentials()
return credentials.access_key, credentials.secret_key

def generate_task_definition(AWS_PROFILE):
taskRoleArn = False
task_definition = TASK_DEFINITION.copy()
key, secret = get_aws_credentials(AWS_PROFILE)

config = configparser.ConfigParser()
config.read(f"{os.environ['HOME']}/.aws/config")

if config.has_section(AWS_PROFILE):
profile_name = AWS_PROFILE
elif config.has_section(f'profile {AWS_PROFILE}'):
profile_name = f'profile {AWS_PROFILE}'
else:
print ('Problem handling profile')

if config.has_option(profile_name, 'role_arn'):
print ("Using role for credentials", config[profile_name]['role_arn'])
taskRoleArn = config[profile_name]['role_arn']
else:
if config.has_option(profile_name, 'source_profile'):
creds = configparser.ConfigParser()
creds.read(f"{os.environ['HOME']}/.aws/credentials")
source_profile = config[profile_name]['source_profile']
aws_access_key = creds[source_profile]['aws_access_key_id']
aws_secret_key = creds[source_profile]['aws_secret_access_key']
elif config.has_option(profile_name, 'aws_access_key_id'):
aws_access_key = config[profile_name]['aws_access_key_id']
aws_secret_key = config[profile_name]['aws_secret_access_key']
else:
print ("Problem getting credentials")
task_definition['containerDefinitions'][0]['environment'] += [
{
"name": "AWS_ACCESS_KEY_ID",
"value": aws_access_key
},
{
"name": "AWS_SECRET_ACCESS_KEY",
"value": aws_secret_key
}]

sqs = boto3.client('sqs')
queue_name = get_queue_url(sqs)
task_definition['containerDefinitions'][0]['environment'] += [
@@ -79,14 +110,6 @@ def generate_task_definition(AWS_PROFILE):
'name': 'SQS_QUEUE_URL',
'value': queue_name
},
{
"name": "AWS_ACCESS_KEY_ID",
"value": key
},
{
"name": "AWS_SECRET_ACCESS_KEY",
"value": secret
},
{
"name": "AWS_BUCKET",
"value": AWS_BUCKET
@@ -119,8 +142,13 @@ def generate_task_definition(AWS_PROFILE):
return task_definition, taskRoleArn

def update_ecs_task_definition(ecs, ECS_TASK_NAME, AWS_PROFILE):
task_definition = generate_task_definition(AWS_PROFILE)
ecs.register_task_definition(family=ECS_TASK_NAME,containerDefinitions=task_definition['containerDefinitions'])
task_definition, taskRoleArn = generate_task_definition(AWS_PROFILE)
if not taskRoleArn:
ecs.register_task_definition(family=ECS_TASK_NAME,containerDefinitions=task_definition['containerDefinitions'])
elif taskRoleArn:
ecs.register_task_definition(family=ECS_TASK_NAME,containerDefinitions=task_definition['containerDefinitions'],taskRoleArn=taskRoleArn)
else:
print('Mistake in handling role for Task Definition.')
print('Task definition registered')

def get_or_create_cluster(ecs):
@@ -179,14 +207,14 @@ def killdeadAlarms(fleetId,monitorapp,ec2,cloud):
todel.append(eachevent['EventInformation']['InstanceId'])

existing_alarms = [x['AlarmName'] for x in cloud.describe_alarms(AlarmNamePrefix=monitorapp)['MetricAlarms']]

for eachmachine in todel:
monitorname = monitorapp+'_'+eachmachine
if monitorname in existing_alarms:
cloud.delete_alarms(AlarmNames=[monitorname])
print('Deleted', monitorname, 'if it existed')
time.sleep(3)

print('Old alarms deleted')

def generateECSconfig(ECS_CLUSTER,APP_NAME,AWS_BUCKET,s3client):
@@ -232,7 +260,7 @@ def removequeue(queueName):
for eachUrl in queueoutput["QueueUrls"]:
if eachUrl.split('/')[-1] == queueName:
queueUrl=eachUrl

sqs.delete_queue(QueueUrl=queueUrl)

def deregistertask(taskName, ecs):
@@ -262,7 +290,7 @@ def downscaleSpotFleet(queue, spotFleetID, ec2, manual=False):

def export_logs(logs, loggroupId, starttime, bucketId):
result = logs.create_export_task(taskName = loggroupId, logGroupName = loggroupId, fromTime = int(starttime), to = int(time.time()*1000), destination = bucketId, destinationPrefix = 'exportedlogs/'+loggroupId)

logExportId = result['taskId']

while True:
@@ -285,7 +313,7 @@ def __init__(self,name=None):
self.queue = self.sqs.get_queue_by_name(QueueName=SQS_QUEUE_NAME)
else:
self.queue = self.sqs.get_queue_by_name(QueueName=name)
self.inProcess = -1
self.pending = -1

def scheduleBatch(self, data):
Expand Down Expand Up @@ -342,7 +370,7 @@ def submitJob():

# Step 1: Read the job configuration file
jobInfo = loadConfig(sys.argv[2])
templateMessage = {'Metadata': '',
'output_file_location': jobInfo["output_file_location"],
'shared_metadata': jobInfo["shared_metadata"]
}
@@ -357,7 +385,7 @@ def submitJob():
print('Job submitted. Check your queue')

#################################
# SERVICE 3: START CLUSTER
#################################

def startCluster():
@@ -376,7 +404,7 @@ def startCluster():
spotfleetConfig['SpotPrice'] = '%.2f' %MACHINE_PRICE
DOCKER_BASE_SIZE = int(round(float(EBS_VOL_SIZE)/int(TASKS_PER_MACHINE))) - 2
userData=generateUserData(ecsConfigFile,DOCKER_BASE_SIZE)
for LaunchSpecification in range(0,len(spotfleetConfig['LaunchSpecifications'])):
spotfleetConfig['LaunchSpecifications'][LaunchSpecification]["UserData"]=userData
spotfleetConfig['LaunchSpecifications'][LaunchSpecification]['BlockDeviceMappings'][1]['Ebs']["VolumeSize"]= EBS_VOL_SIZE
spotfleetConfig['LaunchSpecifications'][LaunchSpecification]['InstanceType'] = MACHINE_TYPE[LaunchSpecification]
@@ -399,7 +427,7 @@ def startCluster():
createMonitor.write('"MONITOR_LOG_GROUP_NAME" : "'+LOG_GROUP_NAME+'",\n')
createMonitor.write('"MONITOR_START_TIME" : "'+ starttime+'"}\n')
createMonitor.close()

# Step 4: Create a log group for this app and date if one does not already exist
logclient=boto3.client('logs')
loggroupinfo=logclient.describe_log_groups(logGroupNamePrefix=LOG_GROUP_NAME)
@@ -410,13 +438,13 @@ def startCluster():
if LOG_GROUP_NAME+'_perInstance' not in groupnames:
logclient.create_log_group(logGroupName=LOG_GROUP_NAME+'_perInstance')
logclient.put_retention_policy(logGroupName=LOG_GROUP_NAME+'_perInstance', retentionInDays=60)

# Step 5: update the ECS service to be ready to inject docker containers in EC2 instances
print('Updating service')
ecs = boto3.client('ecs')
ecs.update_service(cluster=ECS_CLUSTER, service=APP_NAME+'Service', desiredCount=CLUSTER_MACHINES*TASKS_PER_MACHINE)
print('Service updated.')

# Step 6: Monitor the creation of the instances until all are present
status = ec2client.describe_spot_fleet_instances(SpotFleetRequestId=requestInfo['SpotFleetRequestId'])
#time.sleep(15) # This is now too fast, so sometimes the spot fleet request history throws an error!
@@ -436,7 +464,7 @@ def startCluster():
return
ec2client.cancel_spot_fleet_requests(SpotFleetRequestIds=[requestInfo['SpotFleetRequestId']], TerminateInstances=True)
return

# If everything seems good, just bide your time until you're ready to go
print('.')
time.sleep(20)
@@ -445,39 +473,39 @@ def startCluster():
print('Spot fleet successfully created. Your job should start in a few minutes.')

#################################
# SERVICE 4: MONITOR JOB
#################################

def monitor(cheapest=False):
if len(sys.argv) < 3:
print('Use: run.py monitor spotFleetIdFile')
sys.exit()

if '.json' not in sys.argv[2]:
print('Use: run.py monitor spotFleetIdFile')
sys.exit()

if len(sys.argv) == 4:
cheapest = sys.argv[3]

monitorInfo = loadConfig(sys.argv[2])
monitorcluster=monitorInfo["MONITOR_ECS_CLUSTER"]
monitorapp=monitorInfo["MONITOR_APP_NAME"]
fleetId=monitorInfo["MONITOR_FLEET_ID"]
queueId=monitorInfo["MONITOR_QUEUE_NAME"]

ec2 = boto3.client('ec2')
cloud = boto3.client('cloudwatch')

# Optional Step 0 - decide if you're going to be cheap rather than fast. This means that you'll get 15 minutes
# from the start of the monitor to get as many machines as you get, and then it will set the requested number to 1.
# Benefit: this will always be the cheapest possible way to run, because if machines die they'll die fast,
# Potential downside- if machines are at low availability when you start to run, you'll only ever get a small number
# of machines (as opposed to getting more later when they become available), so it might take VERY long to run if that happens.
if cheapest:
queue = JobQueue(name=queueId)
startcountdown = time.time()
while queue.pendingLoad():
if time.time() - startcountdown > 900:
downscaleSpotFleet(queue, fleetId, ec2, manual=1)
break
@@ -486,7 +514,7 @@ def monitor(cheapest=False):
# Step 1: Create job and count messages periodically
queue = JobQueue(name=queueId)
while queue.pendingLoad():
#Once an hour (except at midnight) check for terminated machines and delete their alarms.
#This is slooooooow, which is why we don't just do it at the end
curtime=datetime.datetime.now().strftime('%H%M')
if curtime[-2:]=='00':
@@ -499,7 +527,7 @@ def monitor(cheapest=False):
if curtime[-1:]=='9':
downscaleSpotFleet(queue, fleetId, ec2)
time.sleep(MONITOR_TIME)

# Step 2: When no messages are pending, stop service
# Reload the monitor info, because for long jobs new fleets may have been started, etc
monitorInfo = loadConfig(sys.argv[2])
Expand All @@ -509,7 +537,7 @@ def monitor(cheapest=False):
queueId=monitorInfo["MONITOR_QUEUE_NAME"]
bucketId=monitorInfo["MONITOR_BUCKET_NAME"]
loggroupId=monitorInfo["MONITOR_LOG_GROUP_NAME"]
starttime=monitorInfo["MONITOR_START_TIME"]

ecs = boto3.client('ecs')
ecs.update_service(cluster=monitorcluster, service=monitorapp+'Service', desiredCount=0)
@@ -560,14 +588,14 @@ def monitor(cheapest=False):
print('All export tasks done')

#################################
# MAIN USER INTERACTION
#################################

if __name__ == '__main__':
if len(sys.argv) < 2:
print('Use: run.py setup | submitJob | startCluster | monitor')
sys.exit()

if sys.argv[1] == 'setup':
setup()
elif sys.argv[1] == 'submitJob':
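For reference, here is a minimal, self-contained sketch of the role handling this PR adds to run.py, assuming ~/.aws/config and ~/.aws/credentials sit in their default locations: the profile is resolved to either a role_arn (passed to ECS as taskRoleArn) or to static keys from its source_profile or inline credentials, and AWS_ACCESS_KEY_ID / AWS_SECRET_ACCESS_KEY are only injected into the container environment in the key-based case. The helper names resolve_profile and register_task are illustrative, not code from this PR.

```python
# Minimal sketch, not code from this PR: resolve_profile() and register_task() are
# illustrative names, and the config/credentials paths are the AWS CLI defaults.
import configparser
import os

import boto3


def resolve_profile(profile):
    """Return (role_arn, access_key, secret_key) for an AWS CLI profile."""
    config = configparser.ConfigParser()
    config.read(os.path.expanduser("~/.aws/config"))

    # ~/.aws/config names sections either "default" or "profile <name>".
    if config.has_section(profile):
        section = profile
    elif config.has_section(f"profile {profile}"):
        section = f"profile {profile}"
    else:
        raise ValueError(f"profile {profile!r} not found in ~/.aws/config")

    if config.has_option(section, "role_arn"):
        # Role-based profile: hand the role to ECS instead of baking keys into the task.
        return config[section]["role_arn"], None, None

    if config.has_option(section, "source_profile"):
        creds = configparser.ConfigParser()
        creds.read(os.path.expanduser("~/.aws/credentials"))
        source = config[section]["source_profile"]
        return None, creds[source]["aws_access_key_id"], creds[source]["aws_secret_access_key"]

    # Fall back to keys stored directly in the config section.
    return None, config[section]["aws_access_key_id"], config[section]["aws_secret_access_key"]


def register_task(family, container_definitions, profile):
    """Register an ECS task definition, attaching a task role when the profile has one."""
    role_arn, key, secret = resolve_profile(profile)
    ecs = boto3.client("ecs")
    if role_arn:
        ecs.register_task_definition(
            family=family,
            containerDefinitions=container_definitions,
            taskRoleArn=role_arn,
        )
    else:
        # Key-based profile: inject static credentials into the container environment.
        container_definitions[0]["environment"] += [
            {"name": "AWS_ACCESS_KEY_ID", "value": key},
            {"name": "AWS_SECRET_ACCESS_KEY", "value": secret},
        ]
        ecs.register_task_definition(
            family=family,
            containerDefinitions=container_definitions,
        )
```

Under those assumptions, calling register_task(ECS_TASK_NAME, task_definition['containerDefinitions'], AWS_PROFILE) would mirror what update_ecs_task_definition does after this change.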
10 changes: 4 additions & 6 deletions worker/run-worker.sh
@@ -5,8 +5,6 @@ echo "Queue $SQS_QUEUE_URL"
echo "Bucket $AWS_BUCKET"

# 1. CONFIGURE AWS CLI
aws configure set aws_access_key_id $AWS_ACCESS_KEY_ID
aws configure set aws_secret_access_key $AWS_SECRET_ACCESS_KEY
aws configure set default.region $AWS_REGION
MY_INSTANCE_ID=$(curl http://169.254.169.254/latest/meta-data/instance-id)
echo "Instance ID $MY_INSTANCE_ID"
@@ -17,15 +15,15 @@ aws ec2 create-tags --resources $VOL_0_ID --tags Key=Name,Value=${APP_NAME}Worke
VOL_1_ID=$(aws ec2 describe-instance-attribute --instance-id $MY_INSTANCE_ID --attribute blockDeviceMapping --output text --query BlockDeviceMappings[1].Ebs.[VolumeId])
aws ec2 create-tags --resources $VOL_1_ID --tags Key=Name,Value=${APP_NAME}Worker

# 2. MOUNT S3
echo $AWS_ACCESS_KEY_ID:$AWS_SECRET_ACCESS_KEY > /credentials.txt
chmod 600 /credentials.txt
mkdir -p /home/ubuntu/bucket
mkdir -p /home/ubuntu/local_output
stdbuf -o0 s3fs $AWS_BUCKET /home/ubuntu/bucket -o passwd_file=/credentials.txt

# 3. SET UP ALARMS
aws cloudwatch put-metric-alarm --alarm-name ${APP_NAME}_${MY_INSTANCE_ID} --alarm-actions arn:aws:swf:${AWS_REGION}:${OWNER_ID}:action/actions/AWS_EC2.InstanceId.Terminate/1.0 --statistic Maximum --period 60 --threshold 1 --comparison-operator LessThanThreshold --metric-name CPUUtilization --namespace AWS/EC2 --evaluation-periods 15 --dimensions "Name=InstanceId,Value=${MY_INSTANCE_ID}"

# 4. DOWNLOAD PLUGIN FILE
wget -P /opt/fiji/Fiji.app/plugins/ $SCRIPT_DOWNLOAD_URL
@@ -35,4 +33,4 @@ wget -P /opt/fiji/Fiji.app/plugins/ $SCRIPT_DOWNLOAD_URL
python3 instance-monitor.py &

# 6. RUN FIJI WORKER
python3 fiji-worker.py |& tee $k.out
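With the `aws configure set aws_access_key_id` / `aws_secret_access_key` lines removed from run-worker.sh, the worker is expected to pick up credentials from the ECS task role (or the EC2 instance profile) via the metadata service. A quick way to sanity-check that from inside the container, offered as a sketch rather than anything in this PR, is to ask STS for the caller identity:

```python
# Minimal check, assuming it runs inside the worker container with no static keys set:
# boto3 resolves credentials from the ECS task role / EC2 instance profile automatically.
import boto3

identity = boto3.client("sts").get_caller_identity()
print(identity["Arn"])  # an assumed-role ARN here means role-based credentials are active
```

Note that the s3fs mount above still writes $AWS_ACCESS_KEY_ID:$AWS_SECRET_ACCESS_KEY to /credentials.txt, so that path continues to depend on those variables when they are set.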