From 46ac91cc68f8aaf45380f60a412f1a9810bbda71 Mon Sep 17 00:00:00 2001 From: akutz Date: Mon, 1 May 2017 20:51:32 -0500 Subject: [PATCH] Enable Override of EBS Region, AZ, & Endpoint Props This patch introduces the ability for the EBS driver to override the client's provide region, availability zone, and endpoint properties by specifying them in the server-side config with a suffix of !important (like CSS). --- drivers/storage/ebs/ebs.go | 6 ++ drivers/storage/ebs/ebs_config_compat.go | 9 +++ drivers/storage/ebs/storage/ebs_storage.go | 64 ++++++++++++++++++---- 3 files changed, 69 insertions(+), 10 deletions(-) diff --git a/drivers/storage/ebs/ebs.go b/drivers/storage/ebs/ebs.go index 8ba4ea55..74f8fe69 100644 --- a/drivers/storage/ebs/ebs.go +++ b/drivers/storage/ebs/ebs.go @@ -53,6 +53,9 @@ const ( // Endpoint is a key constant. Endpoint = "endpoint" + // AvaiZone is a key constant. + AvaiZone = "availabilityZone" + // MaxRetries is a key constant. MaxRetries = "maxRetries" @@ -92,6 +95,7 @@ func init() { r.Key(gofig.String, "", "", "", Name+"."+SecretKey) r.Key(gofig.String, "", "", "", Name+"."+Region) r.Key(gofig.String, "", "", "", Name+"."+Endpoint) + r.Key(gofig.String, "", "", "", Name+"."+AvaiZone) r.Key(gofig.Int, "", DefaultMaxRetries, "", Name+"."+MaxRetries) r.Key(gofig.String, "", "", "Tag prefix for EBS naming", Name+"."+Tag) r.Key(gofig.String, "", "", "", Name+"."+KmsKeyID) @@ -106,6 +110,7 @@ func init() { r.Key(gofig.String, "", "", "", NameEC2+"."+SecretKey) r.Key(gofig.String, "", "", "", NameEC2+"."+Region) r.Key(gofig.String, "", "", "", NameEC2+"."+Endpoint) + r.Key(gofig.String, "", "", "", NameEC2+"."+AvaiZone) r.Key(gofig.Int, "", DefaultMaxRetries, "", NameEC2+"."+MaxRetries) r.Key(gofig.String, "", "", "Tag prefix for EBS naming", NameEC2+"."+Tag) r.Key(gofig.String, "", "", "", NameEC2+"."+KmsKeyID) @@ -114,6 +119,7 @@ func init() { r.Key(gofig.String, "", "", "", NameAWS+"."+SecretKey) r.Key(gofig.String, "", "", "", NameAWS+"."+Region) r.Key(gofig.String, "", "", "", NameAWS+"."+Endpoint) + r.Key(gofig.String, "", "", "", NameAWS+"."+AvaiZone) r.Key(gofig.Int, "", DefaultMaxRetries, "", NameAWS+"."+MaxRetries) r.Key(gofig.String, "", "", "Tag prefix for EBS naming", NameAWS+"."+Tag) r.Key(gofig.String, "", "", "", NameAWS+"."+KmsKeyID) diff --git a/drivers/storage/ebs/ebs_config_compat.go b/drivers/storage/ebs/ebs_config_compat.go index 56453e89..d2409e6e 100644 --- a/drivers/storage/ebs/ebs_config_compat.go +++ b/drivers/storage/ebs/ebs_config_compat.go @@ -20,6 +20,9 @@ const ( // ConfigEBSRegion is a config key. ConfigEBSRegion = ConfigEBS + "." + Region + // ConfigEBSAvaiZone is a config key. + ConfigEBSAvaiZone = ConfigEBS + "." + AvaiZone + // ConfigEBSEndpoint is a config key. ConfigEBSEndpoint = ConfigEBS + "." + Endpoint @@ -44,6 +47,9 @@ const ( // ConfigEC2SecretKey is a config key. ConfigEC2SecretKey = ConfigEC2 + "." + SecretKey + // ConfigEC2AvaiZone is a config key. + ConfigEC2AvaiZone = ConfigEC2 + "." + AvaiZone + // ConfigEC2Region is a config key. ConfigEC2Region = ConfigEC2 + "." + Region @@ -71,6 +77,9 @@ const ( // ConfigAWSSecretKey is a config key. ConfigAWSSecretKey = ConfigAWS + "." + SecretKey + // ConfigAWSAvaiZone is a config key. + ConfigAWSAvaiZone = ConfigAWS + "." + AvaiZone + // ConfigAWSRegion is a config key. ConfigAWSRegion = ConfigAWS + "." + Region diff --git a/drivers/storage/ebs/storage/ebs_storage.go b/drivers/storage/ebs/storage/ebs_storage.go index 8fea1584..c6589101 100644 --- a/drivers/storage/ebs/storage/ebs_storage.go +++ b/drivers/storage/ebs/storage/ebs_storage.go @@ -6,6 +6,7 @@ import ( "crypto/md5" "fmt" "hash" + "regexp" "strings" "sync" "time" @@ -46,6 +47,7 @@ type driver struct { config gofig.Config region *string endpoint *string + avaiZone *string maxRetries *int accessKey string kmsKeyID string @@ -78,6 +80,9 @@ func (d *driver) Init(context types.Context, config gofig.Config) error { ebs.BackCompat(config) d.config = config d.accessKey = d.getAccessKey() + if v := d.getAvaiZone(); v != "" { + d.avaiZone = &v + } if v := d.getRegion(); v != "" { d.region = &v } @@ -126,18 +131,27 @@ func (d *driver) Login(ctx types.Context) (interface{}, error) { defer sessionsL.Unlock() var ( - endpoint *string - ckey string - hkey = md5.New() - akey = d.accessKey - region = d.mustRegion(ctx) + endpoint *string + endpointIsImportant bool + ckey string + hkey = md5.New() + akey = d.accessKey + region = d.mustRegion(ctx) ) - if region != nil { - szEndpint := fmt.Sprintf("ec2.%s.amazonaws.com", *region) - endpoint = &szEndpint - } else { - endpoint = d.endpoint + if d.endpoint != nil { + if v, ok := isImportant(*d.endpoint); ok { + endpoint = &v + endpointIsImportant = true + } + } + if !endpointIsImportant { + if region != nil { + szEndpint := fmt.Sprintf("ec2.%s.amazonaws.com", *region) + endpoint = &szEndpint + } else { + endpoint = d.endpoint + } } writeHkey(hkey, region) @@ -213,7 +227,22 @@ func mustInstanceIDID(ctx types.Context) *string { return &context.MustInstanceID(ctx).ID } +var rxImportant = regexp.MustCompile(`^(?i)(.+?)\s+!important$`) + +func isImportant(s string) (string, bool) { + m := rxImportant.FindStringSubmatch(s) + if len(m) == 0 { + return "", false + } + return m[1], true +} + func (d *driver) mustRegion(ctx types.Context) *string { + if d.region != nil { + if region, ok := isImportant(*d.region); ok { + return ®ion + } + } if iid, ok := context.InstanceID(ctx); ok { if v, ok := iid.Fields[ebs.InstanceIDFieldRegion]; ok && v != "" { return &v @@ -223,6 +252,11 @@ func (d *driver) mustRegion(ctx types.Context) *string { } func (d *driver) mustAvailabilityZone(ctx types.Context) *string { + if d.avaiZone != nil { + if az, ok := isImportant(*d.avaiZone); ok { + return &az + } + } if iid, ok := context.InstanceID(ctx); ok { if v, ok := iid.Fields[ebs.InstanceIDFieldAvailabilityZone]; ok { if v != "" { @@ -1389,6 +1423,16 @@ func (d *driver) getRegion() string { return d.config.GetString(ebs.ConfigEC2Region) } +func (d *driver) getAvaiZone() string { + if az := d.config.GetString(ebs.ConfigEBSAvaiZone); az != "" { + return az + } + if az := d.config.GetString(ebs.ConfigAWSAvaiZone); az != "" { + return az + } + return d.config.GetString(ebs.ConfigEC2AvaiZone) +} + func (d *driver) getEndpoint() string { if endpoint := d.config.GetString(ebs.ConfigEBSEndpoint); endpoint != "" { return endpoint