Skip to content

Commit

Permalink
Refactor ip limit
Browse files Browse the repository at this point in the history
  • Loading branch information
wzshiming committed Jun 12, 2024
1 parent 8ba1786 commit c841f6a
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 34 deletions.
49 changes: 15 additions & 34 deletions crproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -525,20 +525,22 @@ func (c *CRProxy) cacheBlobResponse(rw http.ResponseWriter, r *http.Request, inf

stat, err := c.storageDriver.Stat(ctx, blobPath)
if err == nil {
doneCache()

size := stat.Size()
if r.Method == http.MethodHead {
rw.Header().Set("Content-Length", strconv.FormatInt(stat.Size(), 10))
rw.Header().Set("Content-Length", strconv.FormatInt(size, 10))
rw.Header().Set("Content-Type", "application/octet-stream")
doneCache()
return
}
c.accumulativeLimit(rw, r, info, stat.Size())

c.accumulativeLimit(rw, r, info, size)

err = c.redirect(rw, r, blobPath)
if err == nil {
doneCache()
return
}
c.errorResponse(rw, r, ctx.Err())
doneCache()
return
}
if c.logger != nil {
Expand Down Expand Up @@ -700,29 +702,6 @@ func (c *CRProxy) checkLimit(rw http.ResponseWriter, r *http.Request, info *Path
}
}

if c.ipsSpeedLimit != nil && info.Blobs != "" {
address := addr(r.RemoteAddr)
bps, _ := c.speedLimitRecord.LoadOrStore(address, geario.NewBPSAver(c.ipsSpeedLimitDuration))
aver := bps.Aver()
if aver > *c.ipsSpeedLimit {
if c.logger != nil {
c.logger.Println("exceed limit", address, aver, *c.ipsSpeedLimit)
}
if c.limitDelay {
select {
case <-r.Context().Done():
return false
case <-time.After(bps.Next().Sub(time.Now())):
}
} else {
err := ErrorCodeTooManyRequests
rw.Header().Set("X-Retry-After", strconv.FormatInt(bps.Next().Unix(), 10))
errcode.ServeJSON(rw, err)
return false
}
}
}

return true
}

Expand All @@ -731,15 +710,17 @@ func (c *CRProxy) accumulativeLimit(rw http.ResponseWriter, r *http.Request, inf
return
}

if c.blobsSpeedLimit != nil && info.Blobs != "" {
bps, ok := c.speedLimitRecord.Load(info.Blobs)
if ok {
bps.Add(geario.B(size))
if c.ipsSpeedLimit != nil {
dur := GetSleepDuration(geario.B(size), *c.ipsSpeedLimit, c.ipsSpeedLimitDuration)
select {
case <-r.Context().Done():
return
case <-time.After(dur):
}
}

if c.ipsSpeedLimit != nil && info.Blobs != "" {
bps, ok := c.speedLimitRecord.Load(addr(r.RemoteAddr))
if c.blobsSpeedLimit != nil && info.Blobs != "" {
bps, ok := c.speedLimitRecord.Load(info.Blobs)
if ok {
bps.Add(geario.B(size))
}
Expand Down
11 changes: 11 additions & 0 deletions hang_manager.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package crproxy

import (
"time"

"github.com/wzshiming/geario"
)

func GetSleepDuration(s geario.B, limit geario.B, r time.Duration) time.Duration {
return time.Duration(s/(limit/geario.B(r)*geario.B(time.Second))) * time.Second
}
69 changes: 69 additions & 0 deletions hang_manager_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,69 @@
package crproxy

import (
"testing"
"time"

"github.com/wzshiming/geario"
)

func TestGetSleepDuration(t *testing.T) {
type args struct {
s geario.B
limit geario.B
r time.Duration
}
tests := []struct {
name string
args args
want time.Duration
}{
{
args: args{
s: 100,
limit: 100,
r: time.Second,
},
want: time.Second,
},
{
args: args{
s: 200,
limit: 100,
r: time.Second,
},
want: 2 * time.Second,
},
{
args: args{
s: 100,
limit: 50,
r: time.Second,
},
want: 2 * time.Second,
},
{
args: args{
s: 100,
limit: 100,
r: 2 * time.Second,
},
want: 2 * time.Second,
},
{
args: args{
s: 100 * geario.MiB,
limit: geario.MiB,
r: time.Second,
},
want: 100 * time.Second,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := GetSleepDuration(tt.args.s, tt.args.limit, tt.args.r); got != tt.want {
t.Errorf("GetSleepDuration() = %v, want %v", got, tt.want)
}
})
}
}

0 comments on commit c841f6a

Please sign in to comment.