diff --git a/registry/registry.go b/registry/registry.go index 89692ce1..04c793a7 100644 --- a/registry/registry.go +++ b/registry/registry.go @@ -68,11 +68,7 @@ func NewInsecure(registryURL, username, password string, logFlag int) (*Registry * error handling this library relies on. */ func WrapTransport(transport http.RoundTripper, url, username, password string) http.RoundTripper { - tokenTransport := &TokenTransport{ - Transport: transport, - Username: username, - Password: password, - } + tokenTransport := NewTokenTransport(transport, username, password) basicAuthTransport := &BasicTransport{ Transport: tokenTransport, URL: url, diff --git a/registry/tokenpool.go b/registry/tokenpool.go new file mode 100644 index 00000000..0be1a391 --- /dev/null +++ b/registry/tokenpool.go @@ -0,0 +1,39 @@ +package registry + +import ( + "strings" + "sync" +) + +type TokenPool struct { + tokens map[string]string + rwm *sync.RWMutex +} + +func NewTokenPool() *TokenPool { + return &TokenPool{ + tokens: make(map[string]string), + rwm: &sync.RWMutex{}, + } +} + +func (t *TokenPool) GetToken(scope string) string { + if scope == "" { + return "" + } + + if _, ok := t.tokens[scope]; ok { + return t.tokens[scope] + } + return "" +} + +func (t *TokenPool) SetToken(scope, token string) { + // repository:gds-eip/eip-api:pull + if l := strings.Split(scope, ":"); len(l) == 3 { + t.rwm.Lock() + t.tokens[l[2]] = token + t.rwm.Unlock() + } + +} diff --git a/registry/tokentransport.go b/registry/tokentransport.go index 05760951..aa673aaa 100644 --- a/registry/tokentransport.go +++ b/registry/tokentransport.go @@ -6,15 +6,44 @@ import ( "io/ioutil" "net/http" "net/url" + "regexp" + "strings" ) type TokenTransport struct { Transport http.RoundTripper Username string Password string + tokens *TokenPool +} + +func NewTokenTransport(transport http.RoundTripper, username string, password string) *TokenTransport { + return &TokenTransport{ + Transport: transport, + Username: username, + Password: password, + tokens: NewTokenPool(), + } +} + +var scopeReg = regexp.MustCompile(`^/v2/([A-Za-z0-9/-]+)/\b(tags|manifests|blobs)\b`) + +func (t *TokenTransport) GetScope(u string) string { + sc := scopeReg.Find([]byte(u)) + if sc == nil { + return "" + } else { + return string(sc) + } } func (t *TokenTransport) RoundTrip(req *http.Request) (*http.Response, error) { + sc := t.GetScope(req.URL.EscapedPath()) + token := t.tokens.GetToken(sc) + if len(token) > 0 { + req.Header.Set("Authorization", fmt.Sprintf("Bearer %s", token)) + } + resp, err := t.Transport.RoundTrip(req) if err != nil { return resp, err @@ -39,6 +68,7 @@ func (t *TokenTransport) authAndRetry(authService *authService, req *http.Reques return authResp, err } + t.tokens.SetToken(authService.Scope, token) retryResp, err := t.retry(req, token) return retryResp, err } @@ -94,7 +124,12 @@ func (authService *authService) Request(username, password string) (*http.Reques q := url.Query() q.Set("service", authService.Service) if authService.Scope != "" { - q.Set("scope", authService.Scope) + sl := strings.Split(authService.Scope, ":") + if len(sl) >= 3 { + q.Set("scope", fmt.Sprintf("%s:%s:pull,push", sl[1], sl[2])) + } else { + q.Set("scope", authService.Scope) + } } url.RawQuery = q.Encode()