diff --git a/reader-v1.go b/reader-v1.go index 44ff07e..433ee60 100644 --- a/reader-v1.go +++ b/reader-v1.go @@ -18,7 +18,6 @@ import ( "errors" "io" "io/ioutil" - "sync" ) type encReaderV10 struct { @@ -28,6 +27,7 @@ type encReaderV10 struct { buffer packageV10 offset int payloadSize int + stateErr error } // encryptReaderV10 returns an io.Reader wrapping the given io.Reader. @@ -40,12 +40,20 @@ func encryptReaderV10(src io.Reader, config *Config) (*encReaderV10, error) { return &encReaderV10{ authEncV10: ae, src: src, - buffer: make(packageV10, maxPackageSize), + buffer: packageBufferPool.Get().([]byte)[:maxPackageSize], payloadSize: config.PayloadSize, }, nil } +func (r *encReaderV10) recycle() { + recyclePackageBufferPool(r.buffer) + r.buffer = nil +} + func (r *encReaderV10) Read(p []byte) (int, error) { + if r.stateErr != nil { + return 0, r.stateErr + } var n int if r.offset > 0 { // write the buffered package to p remaining := r.buffer.Length() - r.offset // remaining encrypted bytes @@ -61,6 +69,8 @@ func (r *encReaderV10) Read(p []byte) (int, error) { for len(p) >= headerSize+r.payloadSize+tagSize { nn, err := io.ReadFull(r.src, p[headerSize:headerSize+r.payloadSize]) // read plaintext from src if err != nil && err != io.ErrUnexpectedEOF { + r.recycle() + r.stateErr = err return n, err // return if reading from src fails or reached EOF } r.Seal(p, p[headerSize:headerSize+nn]) @@ -70,6 +80,8 @@ func (r *encReaderV10) Read(p []byte) (int, error) { if len(p) > 0 { nn, err := io.ReadFull(r.src, r.buffer[headerSize:headerSize+r.payloadSize]) // read plaintext from src if err != nil && err != io.ErrUnexpectedEOF { + r.stateErr = err + r.recycle() return n, err // return if reading from src fails or reached EOF } r.Seal(r.buffer, r.buffer[headerSize:headerSize+nn]) @@ -87,8 +99,9 @@ type decReaderV10 struct { authDecV10 src io.Reader - buffer packageV10 - offset int + buffer packageV10 + offset int + stateErr error } // decryptReaderV10 returns an io.Reader wrapping the given io.Reader. @@ -101,11 +114,19 @@ func decryptReaderV10(src io.Reader, config *Config) (*decReaderV10, error) { return &decReaderV10{ authDecV10: ad, src: src, - buffer: make(packageV10, maxPackageSize), + buffer: packageBufferPool.Get().([]byte)[:maxPackageSize], }, nil } +func (r *decReaderV10) recycle() { + recyclePackageBufferPool(r.buffer) + r.buffer = nil +} + func (r *decReaderV10) Read(p []byte) (n int, err error) { + if r.stateErr != nil { + return 0, r.stateErr + } if r.offset > 0 { // write the buffered plaintext to p payload := r.buffer.Payload() remaining := len(payload) - r.offset // remaining plaintext bytes @@ -120,10 +141,14 @@ func (r *decReaderV10) Read(p []byte) (n int, err error) { } for len(p) >= maxPayloadSize { if err = r.readPackage(r.buffer); err != nil { + r.stateErr = err + r.recycle() return n, err } length := len(r.buffer.Payload()) if err = r.Open(p[:length], r.buffer[:r.buffer.Length()]); err != nil { // notice: buffer.Length() may be smaller than len(buffer) + r.stateErr = err + r.recycle() return n, err // decryption failed } p = p[length:] @@ -131,10 +156,14 @@ func (r *decReaderV10) Read(p []byte) (n int, err error) { } if len(p) > 0 { if err = r.readPackage(r.buffer); err != nil { + r.stateErr = err + r.recycle() return n, err } payload := r.buffer.Payload() if err = r.Open(payload, r.buffer[:r.buffer.Length()]); err != nil { // notice: buffer.Length() may be smaller than len(buffer) + r.stateErr = err + r.recycle() return n, err // decryption failed } if len(payload) < len(p) { @@ -170,8 +199,7 @@ func (r *decReaderV10) readPackage(dst packageV10) error { type decReaderAtV10 struct { src io.ReaderAt - ad authDecV10 - bufPool sync.Pool + ad authDecV10 } // decryptReaderAtV10 returns an io.ReaderAt wrapping the given io.ReaderAt. @@ -185,12 +213,7 @@ func decryptReaderAtV10(src io.ReaderAt, config *Config) (*decReaderAtV10, error ad: ad, src: src, } - r.bufPool = sync.Pool{ - New: func() interface{} { - b := make([]byte, maxPackageSize) - return &b - }, - } + return r, nil } @@ -204,12 +227,12 @@ func (r *decReaderAtV10) ReadAt(p []byte, offset int64) (n int, err error) { return 0, errUnexpectedSize } - buffer := r.bufPool.Get().(*[]byte) - defer r.bufPool.Put(buffer) + buffer := packageBufferPool.Get().([]byte)[:maxPackageSize] + defer recyclePackageBufferPool(buffer) decReader := decReaderV10{ authDecV10: r.ad, src: §ionReader{r.src, t * maxPackageSize}, - buffer: packageV10(*buffer), + buffer: packageV10(buffer), offset: 0, } decReader.SeqNum = uint32(t) diff --git a/reader-v2.go b/reader-v2.go index e4f0b9f..23cdf44 100644 --- a/reader-v2.go +++ b/reader-v2.go @@ -28,10 +28,24 @@ type encReaderV20 struct { buffer packageV20 offset int lastByte byte + stateErr error firstRead bool } +var packageBufferPool = sync.Pool{New: func() interface{} { return make([]byte, maxPackageSize) }} + +func recyclePackageBufferPool(b []byte) { + if cap(b) < maxPackageSize { + return + } + // Clear so we don't potentially leak info between callers + for i := range b { + b[i] = 0 + } + packageBufferPool.Put(b) +} + // encryptReaderV20 returns an io.Reader wrapping the given io.Reader. // The returned io.Reader encrypts everything it reads using DARE 2.0. func encryptReaderV20(src io.Reader, config *Config) (*encReaderV20, error) { @@ -42,12 +56,20 @@ func encryptReaderV20(src io.Reader, config *Config) (*encReaderV20, error) { return &encReaderV20{ authEncV20: ae, src: src, - buffer: make(packageV20, maxPackageSize), + buffer: packageBufferPool.Get().([]byte)[:maxPackageSize], firstRead: true, }, nil } +func (r *encReaderV20) recycle() { + recyclePackageBufferPool(r.buffer) + r.buffer = nil +} + func (r *encReaderV20) Read(p []byte) (n int, err error) { + if r.stateErr != nil { + return 0, r.stateErr + } if r.firstRead { r.firstRead = false _, err = io.ReadFull(r.src, r.buffer[headerSize:headerSize+1]) @@ -56,6 +78,8 @@ func (r *encReaderV20) Read(p []byte) (n int, err error) { } if err == io.EOF { r.finalized = true + r.stateErr = io.EOF + r.recycle() return 0, io.EOF } r.lastByte = r.buffer[headerSize] @@ -72,6 +96,8 @@ func (r *encReaderV20) Read(p []byte) (n int, err error) { r.offset = 0 } if r.finalized { + r.stateErr = io.EOF + r.recycle() return n, io.EOF } for len(p) >= maxPackageSize { @@ -93,12 +119,16 @@ func (r *encReaderV20) Read(p []byte) (n int, err error) { r.buffer[headerSize] = r.lastByte nn, err := io.ReadFull(r.src, r.buffer[headerSize+1:headerSize+1+maxPayloadSize]) // try to read the max. payload if err != nil && err != io.EOF && err != io.ErrUnexpectedEOF { + r.stateErr = err + r.recycle() return n, err // failed to read from src } if err == io.EOF || err == io.ErrUnexpectedEOF { // read less than 64KB -> final package r.SealFinal(r.buffer, r.buffer[headerSize:headerSize+1+nn]) if len(p) > r.buffer.Length() { n += copy(p, r.buffer[:r.buffer.Length()]) + r.stateErr = io.EOF + r.recycle() return n, io.EOF } } else { @@ -115,8 +145,9 @@ type decReaderV20 struct { authDecV20 src io.Reader - buffer packageV20 - offset int + stateErr error + buffer packageV20 + offset int } // decryptReaderV20 returns an io.Reader wrapping the given io.Reader. @@ -129,11 +160,19 @@ func decryptReaderV20(src io.Reader, config *Config) (*decReaderV20, error) { return &decReaderV20{ authDecV20: ad, src: src, - buffer: make(packageV20, maxPackageSize), + buffer: packageBufferPool.Get().([]byte)[:maxPackageSize], }, nil } +func (r *decReaderV20) recycle() { + recyclePackageBufferPool(r.buffer) + r.buffer = nil +} + func (r *decReaderV20) Read(p []byte) (n int, err error) { + if r.stateErr != nil { + return 0, r.stateErr + } if r.offset > 0 { // write the buffered plaintext to p remaining := len(r.buffer.Payload()) - r.offset if len(p) < remaining { @@ -151,9 +190,13 @@ func (r *decReaderV20) Read(p []byte) (n int, err error) { return n, errUnexpectedEOF // reached EOF but not seen final package yet } if err != nil && err != io.ErrUnexpectedEOF { + r.recycle() + r.stateErr = err return n, err // reading from src failed or reached EOF } if err = r.Open(p, r.buffer[:nn]); err != nil { + r.recycle() + r.stateErr = err return n, err // decryption failed } p = p[len(r.buffer.Payload()):] @@ -162,9 +205,13 @@ func (r *decReaderV20) Read(p []byte) (n int, err error) { if len(p) > 0 { nn, err := io.ReadFull(r.src, r.buffer) if err == io.EOF && !r.finalized { + r.recycle() + r.stateErr = errUnexpectedEOF return n, errUnexpectedEOF // reached EOF but not seen final package yet } if err != nil && err != io.ErrUnexpectedEOF { + r.recycle() + r.stateErr = err return n, err // reading from src failed or reached EOF } if err = r.Open(r.buffer[headerSize:], r.buffer[:nn]); err != nil { @@ -217,8 +264,7 @@ func decryptBufferV20(dst, src []byte, config *Config) ([]byte, error) { type decReaderAtV20 struct { src io.ReaderAt - ad authDecV20 - bufPool sync.Pool + ad authDecV20 } // decryptReaderAtV20 returns an io.ReaderAt wrapping the given io.ReaderAt. @@ -232,12 +278,7 @@ func decryptReaderAtV20(src io.ReaderAt, config *Config) (*decReaderAtV20, error ad: ad, src: src, } - r.bufPool = sync.Pool{ - New: func() interface{} { - b := make([]byte, maxPackageSize) - return &b - }, - } + return r, nil } @@ -251,14 +292,13 @@ func (r *decReaderAtV20) ReadAt(p []byte, offset int64) (n int, err error) { return 0, errUnexpectedSize } - buffer := r.bufPool.Get().(*[]byte) - defer r.bufPool.Put(buffer) decReader := decReaderV20{ authDecV20: r.ad, src: §ionReader{r.src, t * maxPackageSize}, - buffer: packageV20(*buffer), + buffer: packageBufferPool.Get().([]byte)[:maxPackageSize], offset: 0, } + defer decReader.recycle() decReader.SeqNum = uint32(t) if k := offset % int64(maxPayloadSize); k > 0 { if _, err := io.CopyN(ioutil.Discard, &decReader, k); err != nil { diff --git a/sio_test.go b/sio_test.go index 494a216..3e6f55f 100644 --- a/sio_test.go +++ b/sio_test.go @@ -321,7 +321,10 @@ func TestWriter(t *testing.T) { t.Errorf("Version %d: Test %d: Writing failed: %v", version, i, err) } if err := encWriter.Close(); err != nil { - t.Errorf("Version %d: Test: %d: Failed to close writer: %v", version, i, err) + t.Errorf("Version %d: Test: %d: Failed to close encrypt writer: %v", version, i, err) + } + if err := decWriter.Close(); err != nil { + t.Errorf("Version %d: Test: %d: Failed to close decode writer: %v", version, i, err) } if !bytes.Equal(data, output.Bytes()) { t.Errorf("Version %d: Test: %d: Failed to encrypt and decrypt data", version, i) @@ -663,19 +666,22 @@ func BenchmarkDecryptWriter_1MB(b *testing.B) { benchmarkDecryptWrite(1024*102 func benchmarkEncryptRead(size int64, b *testing.B) { data := make([]byte, size) - buffer := make([]byte, 32+size*(size/(64*1024)+32)) config := Config{Key: make([]byte, 32)} b.SetBytes(size) b.ResetTimer() - for i := 0; i < b.N; i++ { - reader, err := EncryptReader(bytes.NewReader(data), config) - if err != nil { - b.Fatal(err) - } - if _, err := io.ReadFull(reader, buffer); err != nil && err != io.ErrUnexpectedEOF { - b.Fatal(err) + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + reader, err := EncryptReader(bytes.NewReader(data), config) + if err != nil { + b.Fatal(err) + } + _, err = io.Copy(ioutil.Discard, reader) + if err != nil && err != io.ErrUnexpectedEOF { + b.Fatal(err) + } } - } + }) } func benchmarkDecryptRead(size int64, b *testing.B) { @@ -695,15 +701,19 @@ func benchmarkDecryptRead(size int64, b *testing.B) { b.SetBytes(size) b.ResetTimer() - for i := 0; i < b.N; i++ { - reader, err := DecryptReader(bytes.NewReader(encrypted.Bytes()), config) - if err != nil { - b.Fatal(err) - } - if _, err := io.ReadFull(reader, data); err != nil && err != io.EOF { - b.Fatal(err) + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + reader, err := DecryptReader(bytes.NewReader(encrypted.Bytes()), config) + if err != nil { + b.Fatal(err) + } + _, err = io.Copy(ioutil.Discard, reader) + if err != nil && err != io.EOF { + b.Fatal(err) + } } - } + }) } func benchmarkDecryptReadAt(size int64, b *testing.B) { @@ -723,39 +733,45 @@ func benchmarkDecryptReadAt(size int64, b *testing.B) { b.SetBytes(size) b.ResetTimer() - for i := 0; i < b.N; i++ { - reader, err := DecryptReaderAt(bytes.NewReader(encrypted.Bytes()), config) - if err != nil { - b.Fatal(err) - } - if _, err := reader.ReadAt(data[:len(data)/2], 0); err != nil && err != io.EOF { - b.Fatal(err) - } - if _, err := reader.ReadAt(data[len(data)/2:], int64(len(data)/2)); err != nil && err != io.EOF { - b.Fatal(err) + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + data := make([]byte, size) + for pb.Next() { + reader, err := DecryptReaderAt(bytes.NewReader(encrypted.Bytes()), config) + if err != nil { + b.Fatal(err) + } + if _, err := reader.ReadAt(data[:len(data)/2], 0); err != nil && err != io.EOF { + b.Fatal(err) + } + if _, err := reader.ReadAt(data[len(data)/2:], int64(len(data)/2)); err != nil && err != io.EOF { + b.Fatal(err) + } } - - } + }) } func benchmarkEncryptWrite(size int64, b *testing.B) { data := make([]byte, size) - buffer := make([]byte, 32+size*(size/(64*1024)+32)) config := Config{Key: make([]byte, 32)} b.SetBytes(size) b.ResetTimer() - for i := 0; i < b.N; i++ { - encryptWriter, err := EncryptWriter(bytes.NewBuffer(buffer[:0]), config) - if err != nil { - b.Fatal(err) - } - if _, err = encryptWriter.Write(data); err != nil { - b.Fatal(err) - } - if err = encryptWriter.Close(); err != nil { - b.Fatal(err) + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + buffer := make([]byte, 32+size*(size/(64*1024)+32)) + for pb.Next() { + encryptWriter, err := EncryptWriter(bytes.NewBuffer(buffer[:0]), config) + if err != nil { + b.Fatal(err) + } + if _, err = encryptWriter.Write(data); err != nil { + b.Fatal(err) + } + if err = encryptWriter.Close(); err != nil { + b.Fatal(err) + } } - } + }) } func benchmarkDecryptWrite(size int64, b *testing.B) { @@ -775,16 +791,20 @@ func benchmarkDecryptWrite(size int64, b *testing.B) { b.SetBytes(size) b.ResetTimer() - for i := 0; i < b.N; i++ { - decryptWriter, err := DecryptWriter(bytes.NewBuffer(data[:0]), config) - if err != nil { - b.Fatal(err) - } - if _, err := decryptWriter.Write(encrypted.Bytes()); err != nil { - b.Fatal(err) - } - if err := decryptWriter.Close(); err != nil { - b.Fatal(err) + b.ReportAllocs() + b.RunParallel(func(pb *testing.PB) { + data := make([]byte, size) + for pb.Next() { + decryptWriter, err := DecryptWriter(bytes.NewBuffer(data[:0]), config) + if err != nil { + b.Fatal(err) + } + if _, err := decryptWriter.Write(encrypted.Bytes()); err != nil { + b.Fatal(err) + } + if err := decryptWriter.Close(); err != nil { + b.Fatal(err) + } } - } + }) } diff --git a/writer-v1.go b/writer-v1.go index 64eb811..e97a711 100644 --- a/writer-v1.go +++ b/writer-v1.go @@ -20,8 +20,9 @@ type decWriterV10 struct { authDecV10 dst io.Writer - buffer packageV10 - offset int + buffer packageV10 + offset int + closeErr error } // decryptWriterV10 returns an io.WriteCloser wrapping the given io.Writer. @@ -37,7 +38,7 @@ func decryptWriterV10(dst io.Writer, config *Config) (*decWriterV10, error) { return &decWriterV10{ authDecV10: ad, dst: dst, - buffer: make(packageV10, maxPackageSize), + buffer: packageBufferPool.Get().([]byte)[:maxPackageSize], }, nil } @@ -93,7 +94,15 @@ func (w *decWriterV10) Write(p []byte) (n int, err error) { return n, nil } -func (w *decWriterV10) Close() error { +func (w *decWriterV10) Close() (err error) { + if w.buffer == nil { + return w.closeErr + } + defer func() { + w.closeErr = err + recyclePackageBufferPool(w.buffer) + w.buffer = nil + }() if w.offset > 0 { if w.offset <= headerSize+tagSize { return errInvalidPayloadSize // the payload is always > 0 @@ -102,13 +111,14 @@ func (w *decWriterV10) Close() error { if w.offset < headerSize+header.Len()+tagSize { return errInvalidPayloadSize // there is less data than specified by the header } - if err := w.Open(w.buffer.Payload(), w.buffer[:w.buffer.Length()]); err != nil { + if err = w.Open(w.buffer.Payload(), w.buffer[:w.buffer.Length()]); err != nil { return err } - if err := flush(w.dst, w.buffer.Payload()); err != nil { // write to underlying io.Writer + if err = flush(w.dst, w.buffer.Payload()); err != nil { // write to underlying io.Writer return err } } + if dst, ok := w.dst.(io.Closer); ok { return dst.Close() } @@ -122,6 +132,7 @@ type encWriterV10 struct { buffer packageV10 offset int payloadSize int + closeErr error } // encryptWriterV10 returns an io.WriteCloser wrapping the given io.Writer. @@ -138,7 +149,7 @@ func encryptWriterV10(dst io.Writer, config *Config) (*encWriterV10, error) { return &encWriterV10{ authEncV10: ae, dst: dst, - buffer: make(packageV10, maxPackageSize), + buffer: packageBufferPool.Get().([]byte)[:maxPackageSize], payloadSize: config.PayloadSize, }, nil } @@ -174,7 +185,16 @@ func (w *encWriterV10) Write(p []byte) (n int, err error) { return } -func (w *encWriterV10) Close() error { +func (w *encWriterV10) Close() (err error) { + if w.buffer == nil { + return w.closeErr + } + defer func() { + w.closeErr = err + recyclePackageBufferPool(w.buffer) + w.buffer = nil + }() + if w.offset > 0 { w.Seal(w.buffer[:], w.buffer[headerSize:headerSize+w.offset]) if err := flush(w.dst, w.buffer[:w.buffer.Length()]); err != nil { // write to underlying io.Writer @@ -192,7 +212,7 @@ func flush(w io.Writer, p []byte) error { if err != nil { return err } - if n != len(p) { // not neccasary if the w follows the io.Writer doc *precisely* + if n != len(p) { // not necessary if the w follows the io.Writer doc *precisely* return io.ErrShortWrite } return nil diff --git a/writer-v2.go b/writer-v2.go index 3ae200b..c48afd8 100644 --- a/writer-v2.go +++ b/writer-v2.go @@ -22,8 +22,9 @@ type encWriterV20 struct { authEncV20 dst io.Writer - buffer packageV20 - offset int + buffer packageV20 + offset int + closeErr error } // encryptWriterV20 returns an io.WriteCloser wrapping the given io.Writer. @@ -40,7 +41,7 @@ func encryptWriterV20(dst io.Writer, config *Config) (*encWriterV20, error) { return &encWriterV20{ authEncV20: ae, dst: dst, - buffer: make(packageV20, maxPackageSize), + buffer: packageBufferPool.Get().([]byte)[:maxPackageSize], }, nil } @@ -79,10 +80,19 @@ func (w *encWriterV20) Write(p []byte) (n int, err error) { return n, nil } -func (w *encWriterV20) Close() error { +func (w *encWriterV20) Close() (err error) { + if w.buffer == nil { + return w.closeErr + } + defer func() { + w.closeErr = err + recyclePackageBufferPool(w.buffer) + w.buffer = nil + }() + if w.offset > 0 { // true if at least one Write call happened w.SealFinal(w.buffer, w.buffer[headerSize:headerSize+w.offset]) - if err := flush(w.dst, w.buffer[:headerSize+w.offset+tagSize]); err != nil { // write to underlying io.Writer + if err = flush(w.dst, w.buffer[:headerSize+w.offset+tagSize]); err != nil { // write to underlying io.Writer return err } w.offset = 0 @@ -97,8 +107,9 @@ type decWriterV20 struct { authDecV20 dst io.Writer - buffer packageV20 - offset int + buffer packageV20 + offset int + closeErr error } // decryptWriterV20 returns an io.WriteCloser wrapping the given io.Writer. @@ -114,7 +125,7 @@ func decryptWriterV20(dst io.Writer, config *Config) (*decWriterV20, error) { return &decWriterV20{ authDecV20: ad, dst: dst, - buffer: make(packageV20, maxPackageSize), + buffer: packageBufferPool.Get().([]byte)[:maxPackageSize], }, nil } @@ -157,19 +168,28 @@ func (w *decWriterV20) Write(p []byte) (n int, err error) { return n, nil } -func (w *decWriterV20) Close() error { +func (w *decWriterV20) Close() (err error) { + if w.buffer == nil { + return w.closeErr + } + defer func() { + w.closeErr = err + recyclePackageBufferPool(w.buffer) + w.buffer = nil + }() if w.offset > 0 { if w.offset <= headerSize+tagSize { // the payload is always > 0 return errInvalidPayloadSize } - if err := w.Open(w.buffer[headerSize:w.offset-tagSize], w.buffer[:w.offset]); err != nil { + if err = w.Open(w.buffer[headerSize:w.offset-tagSize], w.buffer[:w.offset]); err != nil { return err } - if err := flush(w.dst, w.buffer[headerSize:w.offset-tagSize]); err != nil { // write to underlying io.Writer + if err = flush(w.dst, w.buffer[headerSize:w.offset-tagSize]); err != nil { // write to underlying io.Writer return err } w.offset = 0 } + if closer, ok := w.dst.(io.Closer); ok { return closer.Close() }