Skip to content

Commit

Permalink
Add internal api
Browse files Browse the repository at this point in the history
  • Loading branch information
wzshiming committed Jul 24, 2024
1 parent cb08acf commit 619aac8
Showing 1 changed file with 101 additions and 30 deletions.
131 changes: 101 additions & 30 deletions cmd/crproxy/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ import (
"net/http/pprof"
"net/url"
"os"
"slices"
"strings"
"sync/atomic"
"time"

"github.com/distribution/distribution/v3/registry/storage/driver/factory"
Expand Down Expand Up @@ -66,6 +68,8 @@ var (
acmeCacheDir string
certFile string
privateKeyFile string

enableInternalAPI bool
)

func init() {
Expand Down Expand Up @@ -105,6 +109,7 @@ func init() {
pflag.StringVar(&acmeCacheDir, "acme-cache-dir", "", "acme cache dir")
pflag.StringVar(&certFile, "cert-file", "", "cert file")
pflag.StringVar(&privateKeyFile, "private-key-file", "", "private key file")
pflag.BoolVar(&enableInternalAPI, "enable-internal-api", false, "enable internal api")
pflag.Parse()
}

Expand Down Expand Up @@ -209,22 +214,44 @@ func main() {
os.Exit(1)
}

lines := bufio.NewReader(bytes.NewReader(f))
hosts := []string{}
for {
line, _, err := lines.ReadLine()
if err == io.EOF {
break
}
h := strings.TrimSpace(string(line))
if len(h) == 0 {
continue
}
hosts = append(hosts, h)
var matcher atomic.Pointer[hostmatcher.Matcher]
m, err := getListFrom(bytes.NewReader(f))
if err != nil {
logger.Println("can't read allow list file %s", allowImageListFromFile)
os.Exit(1)
}
matcher.Store(&m)

if enableInternalAPI {
mux.HandleFunc("PUT /internal/api/allows", func(rw http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
logger.Println("read body failed:", err)
rw.WriteHeader(http.StatusBadRequest)
rw.Write([]byte(err.Error()))
return
}
m, err := getListFrom(bytes.NewReader(body))
if err != nil {
logger.Println("can't read allow list file %s", allowImageListFromFile)
rw.WriteHeader(http.StatusBadRequest)
rw.Write([]byte(err.Error()))
return
}

err = os.WriteFile(allowImageListFromFile, body, 0644)
if err != nil {
logger.Println("write file failed:", err)
rw.WriteHeader(http.StatusBadRequest)
rw.Write([]byte(err.Error()))
return
}

matcher.Store(&m)
})
}
matcher := hostmatcher.NewMatcher(hosts)
opts = append(opts, crproxy.WithBlockFunc(func(info *crproxy.ImageInfo) bool {
return !matcher.Match(info.Host + "/" + info.Name)
return !(*matcher.Load()).Match(info.Host + "/" + info.Name)
}))
} else if len(blockImageList) != 0 || len(allowHostList) != 0 {
allowHostMap := map[string]struct{}{}
Expand Down Expand Up @@ -260,28 +287,49 @@ func main() {
}

if len(privilegedIPList) != 0 || privilegedImageListFromFile != "" {
var matcher hostmatcher.Matcher
var matcher atomic.Pointer[hostmatcher.Matcher]
if privilegedImageListFromFile != "" {
f, err := os.ReadFile(privilegedImageListFromFile)
if err != nil {
logger.Println("can't read privileged list file %s", privilegedImageListFromFile)
os.Exit(1)
}

lines := bufio.NewReader(bytes.NewReader(f))
hosts := []string{}
for {
line, _, err := lines.ReadLine()
if err == io.EOF {
break
}
h := strings.TrimSpace(string(line))
if len(h) == 0 {
continue
}
hosts = append(hosts, h)
m, err := getListFrom(bytes.NewReader(f))
if err != nil {
logger.Println("can't read privileged list file %s", privilegedImageListFromFile)
os.Exit(1)
}
matcher.Store(&m)

if enableInternalAPI {
mux.HandleFunc("PUT /internal/api/privileged", func(rw http.ResponseWriter, r *http.Request) {
body, err := io.ReadAll(r.Body)
if err != nil {
logger.Println("read body failed:", err)
rw.WriteHeader(http.StatusBadRequest)
rw.Write([]byte(err.Error()))
return
}
m, err := getListFrom(bytes.NewReader(body))
if err != nil {
logger.Println("can't read allow list file %s", privilegedImageListFromFile)
rw.WriteHeader(http.StatusBadRequest)
rw.Write([]byte(err.Error()))
return
}

err = os.WriteFile(privilegedImageListFromFile, body, 0644)
if err != nil {
logger.Println("write file failed:", err)
rw.WriteHeader(http.StatusBadRequest)
rw.Write([]byte(err.Error()))
return
}

matcher.Store(&m)
})
}
matcher = hostmatcher.NewMatcher(hosts)
}

set := map[string]struct{}{}
Expand All @@ -295,8 +343,8 @@ func main() {
return true
}
}
if matcher != nil && info != nil {
return matcher.Match(info.Host + "/" + info.Name)
if m := matcher.Load(); m != nil && info != nil {
return (*m).Match(info.Host + "/" + info.Name)
}
return false
}))
Expand Down Expand Up @@ -447,3 +495,26 @@ func getLimit(s string) (geario.B, time.Duration, error) {

return b, d, nil
}

func getListFrom(r io.Reader) (hostmatcher.Matcher, error) {
lines := bufio.NewReader(r)
hosts := []string{}
for {
line, _, err := lines.ReadLine()
if err == io.EOF {
break
}
h := strings.TrimSpace(string(line))
if len(h) == 0 {
continue
}
hosts = append(hosts, h)
}
if len(hosts) == 0 {
return nil, fmt.Errorf("no hosts found")
}
if !slices.IsSorted(hosts) {
return nil, fmt.Errorf("hosts not sorted: %v", hosts)
}
return hostmatcher.NewMatcher(hosts), nil
}

0 comments on commit 619aac8

Please sign in to comment.