Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reusable buffers #174

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 8 additions & 4 deletions pkg/download/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ type BufferMode struct {
// we use this errgroup as a semaphore (via sem.SetLimit())
sem *errgroup.Group
queue *workQueue

bufferedReaderPool *readerPool
}

func GetBufferMode(opts Options) *BufferMode {
Expand All @@ -28,12 +30,14 @@ func GetBufferMode(opts Options) *BufferMode {
sem.SetLimit(opts.maxConcurrency())
queue := newWorkQueue(opts.maxConcurrency())
queue.start()
return &BufferMode{
mode := &BufferMode{
Client: client,
Options: opts,
sem: sem,
queue: queue,
}
mode.bufferedReaderPool = newReaderPool(mode.chunkSize())
return mode
}

func (m *BufferMode) chunkSize() int64 {
Expand Down Expand Up @@ -61,7 +65,7 @@ type firstReqResult struct {
func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, error) {
logger := logging.GetLogger()

br := newBufferedReader(m.chunkSize())
br := m.bufferedReaderPool.Get()

firstReqResultCh := make(chan firstReqResult)
m.queue.submit(func() {
Expand Down Expand Up @@ -114,7 +118,7 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
// integer divide rounding up
numChunks := int((remainingBytes-1)/m.chunkSize() + 1)

readersCh := make(chan io.Reader, numChunks+1)
readersCh := make(chan io.ReadCloser, numChunks+1)
readersCh <- br

startOffset := m.chunkSize()
Expand All @@ -135,7 +139,7 @@ func (m *BufferMode) Fetch(ctx context.Context, url string) (io.Reader, int64, e
end = fileSize - 1
}

br := newBufferedReader(m.chunkSize())
br := m.bufferedReaderPool.Get()
readersCh <- br

m.sem.Go(func() error {
Expand Down
90 changes: 82 additions & 8 deletions pkg/download/buffered_reader.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,42 +5,56 @@ import (
"fmt"
"io"
"net/http"
"sync"
)

// A bufferedReader wraps an http.Response.Body so that it can be eagerly
// downloaded to a buffer before the actual io.Reader consumer can read it.
// It implements io.Reader.
type bufferedReader struct {
// ready channel is closed when we're ready to read
ready chan struct{}
ready bool
c sync.Cond
buf *bytes.Buffer
err error
pool Pool
}

var _ io.Reader = &bufferedReader{}
type Pool interface {
Get() *bufferedReader
Put(br *bufferedReader)
}

var _ io.ReadCloser = &bufferedReader{}

func newBufferedReader(capacity int64) *bufferedReader {
func newBufferedReader(capacity int64, readerPool Pool) *bufferedReader {
return &bufferedReader{
ready: make(chan struct{}),
buf: bytes.NewBuffer(make([]byte, 0, capacity)),
c: sync.Cond{L: new(sync.Mutex)},
buf: bytes.NewBuffer(make([]byte, 0, capacity)),
pool: readerPool,
}
}

// Read implements io.Reader. It blocks until done() has marked the full
// body as buffered, then serves the read from the in-memory buffer. Any
// error recorded on the reader is returned before touching the buffer.
func (b *bufferedReader) Read(p []byte) (int, error) {
	b.waitOnReady()
	if err := b.err; err != nil {
		return 0, err
	}
	return b.buf.Read(p)
}

// done marks the reader as ready and wakes every goroutine currently
// blocked in waitOnReady.
func (b *bufferedReader) done() {
	b.c.L.Lock()
	b.ready = true
	b.c.Broadcast()
	b.c.L.Unlock()
}

func (b *bufferedReader) downloadBody(resp *http.Response) error {
if b.ready {
return fmt.Errorf("bufferedReader has already been marked as ready")
}
expectedBytes := resp.ContentLength

if expectedBytes > int64(b.buf.Cap()) {
Expand All @@ -58,3 +72,63 @@ func (b *bufferedReader) downloadBody(resp *http.Response) error {
}
return nil
}

// waitOnReady blocks the caller until done() has flipped the ready flag.
// The loop guards against spurious wakeups from Cond.Wait.
func (b *bufferedReader) waitOnReady() {
	b.c.L.Lock()
	defer b.c.L.Unlock()
	for !b.ready {
		b.c.Wait()
	}
}

// Close implements io.Closer. It resets the reader to its initial state —
// not ready, no error, buffer emptied (the backing storage is kept) — and
// returns it to its pool for reuse. It always returns nil.
//
// NOTE(review): calling Close twice would Put the same reader into the
// pool twice; callers are expected to Close at most once — confirm.
func (b *bufferedReader) Close() error {
	b.c.L.Lock()
	defer b.c.L.Unlock()

	b.err = nil
	b.buf.Reset()
	b.ready = false

	// A reader created with a nil pool is simply reset, not pooled.
	if p := b.pool; p != nil {
		p.Put(b)
	}

	return nil
}

// readerPool wraps a sync.Pool of *bufferedReader so that readers (and
// their pre-allocated chunk-sized buffers) are reused across downloads.
type readerPool struct {
	pool sync.Pool
}

// Get returns a clean *bufferedReader from the pool, allocating a fresh
// one via the pool's New func when none is available. A reader that came
// back dirty (still marked ready, or holding unread bytes) is discarded
// and another one is drawn instead.
func (p *readerPool) Get() *bufferedReader {
	for {
		r := p.pool.Get().(*bufferedReader)
		if r.ready || r.buf.Len() != 0 {
			// Unclean buffer — leave it for the GC and draw again.
			continue
		}
		// Belt and suspenders: we should not need this, but it guarantees
		// the reader comes back to this pool instead of who-knows-where.
		r.pool = p
		return r
	}
}

// Put returns br to the pool. Nil readers, and readers that are not
// associated with any pool, are silently dropped.
func (p *readerPool) Put(br *bufferedReader) {
	if br == nil {
		return
	}
	if br.pool == nil {
		return
	}
	p.pool.Put(br)
}

// newReaderPool builds a readerPool whose readers each pre-allocate a
// chunkSize-byte buffer and return themselves to this pool on Close.
func newReaderPool(chunkSize int64) *readerPool {
	rp := new(readerPool)
	rp.pool.New = func() interface{} {
		return newBufferedReader(chunkSize, rp)
	}
	return rp
}
210 changes: 210 additions & 0 deletions pkg/download/buffered_reader_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
package download

import (
"bytes"
"errors"
"io"
"math/rand"
"net/http"
"strings"
"sync"
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
)

// TestNewBufferedReader verifies a freshly built reader starts empty, not
// ready, with the requested buffer capacity and the supplied pool wired in.
func TestNewBufferedReader(t *testing.T) {
	const capacity = int64(100)
	rp := newReaderPool(capacity)
	br := newBufferedReader(capacity, rp)
	require.NotNil(t, br)
	assert.Equal(t, capacity, int64(br.buf.Cap()))
	assert.Equal(t, int64(0), int64(br.buf.Len()))
	assert.False(t, br.ready)
	assert.Equal(t, rp, br.pool)
}

// TestBufferedReader_downloadBody checks that a response body is buffered
// in full, and that attempting to buffer again after done() is rejected.
func TestBufferedReader_downloadBody(t *testing.T) {
	br := newBufferedReader(100, nil)
	require.NotNil(t, br)

	payload := []byte("The quick brown fox jumps over the lazy dog.")
	makeResp := func() *http.Response {
		return &http.Response{
			ContentLength: int64(len(payload)),
			Body:          io.NopCloser(bytes.NewReader(payload)),
		}
	}

	assert.NoError(t, br.downloadBody(makeResp()))
	assert.Equal(t, int64(len(payload)), int64(br.buf.Len()))

	// Once the reader is marked ready, a second download must fail.
	br.done()
	require.Error(t, br.downloadBody(makeResp()))
}

// TestBufferedReader_Read exercises Read both when the reader is already
// ready and when Read must block for done(), with and without a stored
// error on the reader.
func TestBufferedReader_Read(t *testing.T) {
	testErr := errors.New("error")
	tc := []struct {
		name         string
		expectedErr  error
		expectedRead int
		bufferErr    error
		waitOnReady  bool
	}{
		{
			name:         "Read with no error",
			expectedErr:  nil,
			expectedRead: 10,
			waitOnReady:  false,
		},
		{
			name:         "Read with error EOF",
			expectedErr:  io.EOF,
			expectedRead: 0,
			waitOnReady:  false,
		},
		{
			name:         "Read waiting on ready",
			expectedErr:  nil,
			expectedRead: 10,
			waitOnReady:  true,
		},
		{
			// Fixed: this case previously duplicated the name of the case
			// above, so failures were ambiguous in test output.
			name:         "Read waiting on ready with error",
			expectedErr:  testErr,
			expectedRead: 0,
			bufferErr:    testErr,
			waitOnReady:  true,
		},
	}
	for _, tt := range tc {
		tt := tt // capture once; previously the code mixed tt and a shadow copy
		t.Run(tt.name, func(t *testing.T) {
			wg := new(sync.WaitGroup)
			wg.Add(1)
			br := newBufferedReader(100, nil)
			require.NotNil(t, br)
			if tt.bufferErr != nil {
				br.err = tt.bufferErr
			}
			if tt.expectedRead > 0 {
				content := []byte(strings.Repeat("a", 100))
				_, _ = br.buf.ReadFrom(bytes.NewReader(content))
			}
			if !tt.waitOnReady {
				br.ready = true
			}
			readBuf := make([]byte, 10)
			go func() {
				defer wg.Done()
				n, err := br.Read(readBuf)
				assert.Equal(t, tt.expectedRead, n)
				assert.Equal(t, tt.expectedErr, err)
			}()
			br.done()
			wg.Wait()
		},
		)
	}
}

// TestBufferedReader_done confirms done() flips the ready flag.
func TestBufferedReader_done(t *testing.T) {
	br := newBufferedReader(100, nil)
	assert.False(t, br.ready)
	br.done()
	assert.True(t, br.ready)
}

// getReaderPool is a test helper that builds a readerPool with a random
// capacity in [1000, 2225] and returns both the pool and that capacity.
func getReaderPool(t *testing.T) (*readerPool, int64) {
	capacity := int64(1000) + rand.Int63n(1226) // 1226 == 2225-1000+1
	rp := newReaderPool(capacity)
	require.NotNil(t, rp)
	return rp, capacity
}

// TestReaderPool_Get verifies that Get returns a clean reader, that a
// reader Put back into the pool is handed out again, and that a dirty
// (ready) reader is skipped in favor of a fresh one.
func TestReaderPool_Get(t *testing.T) {
	rp, capacity := getReaderPool(t)
	buf := rp.Get()
	require.NotNil(t, buf)
	assert.Equal(t, capacity, int64(buf.buf.Cap()))
	assert.Equal(t, int64(0), int64(buf.buf.Len()))
	assert.Equal(t, false, buf.ready)

	rp.Put(buf)
	// Clear the pool pointer; Get must restore it ("suspenders and a belt").
	buf.pool = nil

	newBuf := rp.Get()
	// Fixed: previously asserted NotNil on buf (already checked above)
	// instead of the value actually under test.
	require.NotNil(t, newBuf)
	assert.Equal(t, &buf, &newBuf)

	// A reader returned in a ready state must be rejected by Get.
	buf.ready = true
	rp.pool.Put(buf)
	newBuf = rp.Get()
	require.NotNil(t, newBuf)
	assert.NotEqual(t, &buf, &newBuf)
}

// TestReaderPool_Put verifies a reader put back is handed out again by the
// underlying sync.Pool, and that putting nil is a safe no-op.
func TestReaderPool_Put(t *testing.T) {
	rp, _ := getReaderPool(t)
	// Draw a reader straight from the underlying sync.Pool.
	first := rp.pool.Get().(*bufferedReader)
	require.NotNil(t, first)
	rp.Put(first)
	// After Put, the next raw Get should hand back the very same reader.
	second := rp.pool.Get().(*bufferedReader)
	require.NotNil(t, second)
	assert.Equal(t, &second, &first)
	// Putting nil must not panic or poison the pool.
	rp.Put(nil)
	third := rp.pool.Get()
	require.NotNil(t, third)
	assert.NotEqual(t, &second, &third)
}

// TestNewReaderPool checks that the pool's New func produces readers with
// the configured capacity, empty and not yet ready.
func TestNewReaderPool(t *testing.T) {
	rp, capacity := getReaderPool(t)
	br := rp.pool.Get().(*bufferedReader)
	require.NotNil(t, br)
	assert.Equal(t, capacity, int64(br.buf.Cap()))
	assert.Equal(t, int64(0), int64(br.buf.Len()))
	assert.Equal(t, false, br.ready)
}

// mockPool is a testify mock implementing the Pool interface; Get always
// returns the canned br reader so tests control exactly which reader a
// caller receives.
type mockPool struct {
	mock.Mock
	br *bufferedReader
}

// Get records the call and returns the canned reader.
func (m *mockPool) Get() *bufferedReader {
	m.Called()
	return m.br
}

// Put records that br was handed back to the pool.
func (m *mockPool) Put(br *bufferedReader) {
	m.Called(br)
}

// TestBufferedReader_Close verifies that Close resets the reader's state
// (ready flag, error, buffered bytes) while preserving its capacity, and
// that the reader is routed back through its pool's Put.
func TestBufferedReader_Close(t *testing.T) {
	var rp Pool
	mp := &mockPool{}
	rp = mp
	mp.br = newBufferedReader(1024, rp)
	capacity := int64(1024)

	mp.On("Get").Return(mp.br)
	mp.On("Put", mp.br).Return()

	buf := rp.Get()
	require.NotNil(t, buf)
	content := []byte(strings.Repeat("a", 100))
	_, _ = buf.buf.ReadFrom(bytes.NewReader(content))
	buf.done()
	assert.True(t, buf.ready)
	assert.Nil(t, buf.err)
	assert.Equal(t, int64(100), int64(buf.buf.Len()))
	assert.Equal(t, &rp, &buf.pool)
	assert.Equal(t, capacity, int64(buf.buf.Cap()))

	buf.Close()
	// Fixed: previously never checked that Close cleared the ready flag.
	assert.False(t, buf.ready)
	assert.NotNil(t, buf.pool)
	assert.Nil(t, buf.err)
	assert.Zero(t, buf.buf.Len())
	assert.Equal(t, capacity, int64(buf.buf.Cap()))
	// Fixed: previously the mock expectations were never asserted, so a
	// Close that skipped Put would still pass.
	mp.AssertExpectations(t)
}
Loading