Skip to content

Commit

Permalink
Merge pull request #8 from d--j/body-in-memory
Browse files Browse the repository at this point in the history
Capture Body in memory
  • Loading branch information
d--j authored Apr 12, 2023
2 parents 29b2f01 + c008a2d commit a988b8e
Show file tree
Hide file tree
Showing 7 changed files with 276 additions and 14 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,8 +64,8 @@ func main() {
subject, _ := trx.Headers().Subject()
if !strings.HasPrefix(subject, "[⚠️EXTERNAL] ") {
subject = "[⚠️EXTERNAL] " + subject
trx.Headers().SetSubject(subject)
}
trx.Headers().SetSubject(subject)
}
return mailfilter.Accept, nil
},
Expand Down
108 changes: 108 additions & 0 deletions internal/body/body.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// Package body implements a write-once read-multiple [io.ReadSeekCloser] that is backed by a temporary file when too much data gets written into it.
package body

import (
"bytes"
"io"
"os"
)

// New returns an empty Body that buffers written data in memory and moves it to a
// temporary file once more than maxMem bytes have been written.
//
// If maxMem is less than 1, every non-empty Write ends up in a temporary file.
func New(maxMem int) *Body {
	b := new(Body)
	b.maxMem = maxMem
	return b
}

// Body is an [io.ReadSeekCloser] and [io.Writer] that starts buffering all data written to it in memory
// but when more than a configured amount of bytes is written to it Body will switch to writing to a temporary file.
//
// After a call to Read or Seek no more data can be written to Body; Write panics in that case.
// Body is an [io.Seeker] so you can read it multiple times or get the size of the Body.
type Body struct {
// maxMem is the in-memory limit in bytes; Write moves the data to a temporary file
// as soon as buf holds more than this.
maxMem int
// buf collects the written data for as long as no temporary file exists.
buf bytes.Buffer
// mem reads from the bytes held in buf; it is created by the first Read or Seek
// when no temporary file is in use.
mem *bytes.Reader
// file is the temporary file holding the data; nil while the data is kept in memory.
file *os.File
// reading is set by the first Read or Seek; afterwards Write panics.
reading bool
}

// Write implements the io.Writer interface.
// Write will create a temporary file on-the-fly when you write more than the configured amount of bytes.
//
// Write panics when it is called after Read or Seek.
//
// When moving the buffered data into the temporary file fails, Write returns the error
// together with n == len(p): p has been accepted into the in-memory buffer, which is
// left intact, and no temporary file is kept. A later Write will try to spill again.
func (b *Body) Write(p []byte) (n int, err error) {
	if b.reading {
		panic("cannot write after read")
	}
	if b.file != nil {
		return b.file.Write(p)
	}
	n, _ = b.buf.Write(p) // bytes.Buffer.Write always returns a nil error
	if b.buf.Len() <= b.maxMem {
		return n, nil
	}

	f, err := os.CreateTemp("", "body-*")
	if err != nil {
		return n, err
	}
	// Write from the buffer's backing slice instead of draining the buffer so that
	// a failed or partial write does not lose the data that is still only in memory.
	if _, err = f.Write(b.buf.Bytes()); err != nil {
		// Best-effort cleanup; the write error is the one worth reporting.
		_ = f.Close()
		_ = os.Remove(f.Name())
		return n, err
	}
	// Only switch to file-backed mode once the file holds everything written so far.
	b.file = f
	b.buf.Reset()
	return n, nil
}

// switchToReading puts the Body into read mode on first use: a file-backed Body is
// rewound to its start, a memory-backed Body gets a reader over the buffered bytes.
// Calls after the first one do nothing and return nil.
func (b *Body) switchToReading() error {
	if b.reading {
		return nil
	}
	b.reading = true
	if b.file == nil {
		b.mem = bytes.NewReader(b.buf.Bytes())
		return nil
	}
	_, err := b.file.Seek(0, io.SeekStart)
	return err
}

// Read implements the io.Reader interface.
// The first Read (or Seek) switches the Body to read mode; Write must not be called afterwards.
func (b *Body) Read(p []byte) (n int, err error) {
	if err = b.switchToReading(); err != nil {
		return 0, err
	}
	if b.file == nil {
		return b.mem.Read(p)
	}
	return b.file.Read(p)
}

// Close implements the io.Closer interface.
// A memory-backed Body drops its buffered data. A file-backed Body closes its temporary
// file and deletes it; an error from closing takes precedence over an error from deleting,
// and a temporary file that no longer exists is not reported as an error.
func (b *Body) Close() error {
	if b.file == nil {
		b.mem = nil
		b.buf.Reset()
		return nil
	}
	closeErr := b.file.Close()
	// Always try to delete the file, even if closing it failed.
	removeErr := os.Remove(b.file.Name())
	switch {
	case closeErr != nil:
		return closeErr
	case os.IsNotExist(removeErr):
		return nil
	default:
		return removeErr
	}
}

// Seek implements the io.Seeker interface.
// The first Seek (or Read) switches the Body to read mode; Write must not be called afterwards.
func (b *Body) Seek(offset int64, whence int) (int64, error) {
	err := b.switchToReading()
	if err != nil {
		return 0, err
	}
	if b.file == nil {
		return b.mem.Seek(offset, whence)
	}
	return b.file.Seek(offset, whence)
}
158 changes: 158 additions & 0 deletions internal/body/body_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
package body

import (
"bytes"
"io"
"os"
"testing"
)

// getBody builds a Body with the given in-memory limit and writes data to it.
// The result of that initial Write is ignored.
func getBody(maxMem int, data []byte) *Body {
	body := New(maxMem)
	_, _ = body.Write(data)
	return body
}

func TestBody_Close(t *testing.T) {
fileAlreadyRemoved := getBody(2, []byte("test"))
_ = os.Remove(fileAlreadyRemoved.file.Name())
tests := []struct {
name string
body *Body
wantErr bool
}{
{"noop", getBody(10, nil), false},
{"mem", getBody(10, []byte("test")), false},
{"file", getBody(2, []byte("test")), false},
{"file-already-removed", fileAlreadyRemoved, false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if err := tt.body.Close(); (err != nil) != tt.wantErr {
t.Errorf("Close() error = %v, wantErr %v", err, tt.wantErr)
}
})
}
}

func TestBody(t *testing.T) {
t.Run("mem", func(t *testing.T) {
b := getBody(10, []byte("test"))
defer b.Close()
_, err := b.Write([]byte("test"))
if err != nil {
t.Fatal("b.Write got error", err)
}
if b.file != nil {
t.Fatal("b.file needs to be nil")
}
var buf [10]byte
n, err := b.Read(buf[:])
if err != nil {
t.Fatal("b.Read got error", err)
}
if !bytes.Equal([]byte("testtest"), buf[:n]) {
t.Fatalf("b.Read got %q expected %q", buf[:n], []byte("testtest"))
}
pos, err := b.Seek(0, io.SeekStart)
if err != nil {
t.Fatal("b.Seek got error", err)
}
if pos != 0 {
t.Fatal("b.Seek got pos", pos)
}
n, err = b.Read(buf[:])
if err != nil {
t.Fatal("b.Read got error", err)
}
if !bytes.Equal([]byte("testtest"), buf[:n]) {
t.Fatalf("b.Read got %q expected %q", buf[:n], []byte("testtest"))
}
})
t.Run("file", func(t *testing.T) {
b := getBody(2, []byte("test"))
defer func() {
if b != nil {
b.Close()
}
}()
if b.file == nil {
t.Fatal("b.file is nil")
}
_, err := b.Write([]byte("test"))
if err != nil {
t.Fatal("b.Write got error", err)
}
var buf [10]byte
n, err := b.Read(buf[:])
if err != nil {
t.Fatal("b.Read got error", err)
}
if !bytes.Equal([]byte("testtest"), buf[:n]) {
t.Fatalf("b.Read got %q expected %q", buf[:n], []byte("testtest"))
}
pos, err := b.Seek(0, io.SeekStart)
if err != nil {
t.Fatal("b.Seek got error", err)
}
if pos != 0 {
t.Fatal("b.Seek got pos", pos)
}
n, err = b.Read(buf[:])
if err != nil {
t.Fatal("b.Read got error", err)
}
if !bytes.Equal([]byte("testtest"), buf[:n]) {
t.Fatalf("b.Read got %q expected %q", buf[:n], []byte("testtest"))
}
name := b.file.Name()
err = b.Close()
b = nil
if err != nil {
t.Fatal("b.Close got error", err)
}
_, err = os.Stat(name)
if err == nil || !os.IsNotExist(err) {
t.Fatalf("got %v expected to not find file", err)
}
})
t.Run("panic on Write after Read", func(t *testing.T) {
defer func() { _ = recover() }()
b := getBody(10, []byte("test"))
var buf [10]byte
_, _ = b.Read(buf[:])
_, _ = b.Write([]byte("test"))
t.Errorf("did not panic")
})
t.Run("panic on Write after Seek", func(t *testing.T) {
defer func() { _ = recover() }()
b := getBody(10, []byte("test"))
_, _ = b.Seek(0, io.SeekEnd)
_, _ = b.Write([]byte("test"))
t.Errorf("did not panic")
})
t.Run("temp file fail", func(t *testing.T) {
tmpdir := os.Getenv("TMPDIR")
tmp := os.Getenv("TMP")
_ = os.Setenv("TMPDIR", "/this does not exist")
_ = os.Setenv("TMP", "/this does not exist")
defer func() {
_ = os.Setenv("TMPDIR", tmpdir)
_ = os.Setenv("TMP", tmp)
}()
b := getBody(6, []byte("test"))
_, err := b.Write([]byte("test"))
if err == nil {
_ = b.Close()
t.Fatal("b.Write got nil error")
}
})
t.Run("file close fail", func(t *testing.T) {
b := getBody(2, []byte("test"))
_ = b.file.Close()
err := b.Close()
if err == nil {
t.Fatal("b.Close got nil error")
}
})
}
2 changes: 1 addition & 1 deletion mailfilter/example_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,8 @@ func ExampleNew() {
subject, _ := trx.Headers().Subject()
if !strings.HasPrefix(subject, "[⚠️EXTERNAL] ") {
subject = "[⚠️EXTERNAL] " + subject
trx.Headers().SetSubject(subject)
}
trx.Headers().SetSubject(subject)
}
return mailfilter.Accept, nil
},
Expand Down
2 changes: 1 addition & 1 deletion mailfilter/testtrx/trx.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ type Modification struct {
// Trx implements [mailfilter.Trx] for unit tests.
// Use this struct when you want to test your decision functions.
// You can use the fluent Set* methods of this struct to build up the transaction you want to test.
// After you passed the Trx to your decision function, you can call [Trx.Modifications] and [Trx.Log] to
// After you passed the Trx to your decision function, you can call [Trx.Modifications] to
// check that your decision function did what was expected of it.
type Trx struct {
mta mailfilter.MTA
Expand Down
10 changes: 3 additions & 7 deletions mailfilter/transaction.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,10 @@ import (
"bytes"
"context"
"io"
"os"
"regexp"

"github.com/d--j/go-milter"
"github.com/d--j/go-milter/internal/body"
"github.com/d--j/go-milter/internal/header"
"github.com/d--j/go-milter/internal/rcptto"
"github.com/d--j/go-milter/mailfilter/addr"
Expand Down Expand Up @@ -58,7 +58,7 @@ type transaction struct {
headers *header.Header
origHeaders *header.Header
enforceHeaderOrder bool
body *os.File
body *body.Body
replacementBody io.Reader
queueId string
hasDecision bool
Expand Down Expand Up @@ -92,7 +92,6 @@ func (t *transaction) cleanup() {
t.closeReplacementBody()
if t.body != nil {
_ = t.body.Close()
_ = os.Remove(t.body.Name())
t.body = nil
}
}
Expand Down Expand Up @@ -249,10 +248,7 @@ func (t *transaction) addHeader(key string, raw []byte) {

func (t *transaction) addBodyChunk(chunk []byte) (err error) {
if t.body == nil {
t.body, err = os.CreateTemp("", "body-*")
if err != nil {
return
}
t.body = body.New(200 * 1024)
}
_, err = t.body.Write(chunk)
return
Expand Down
8 changes: 4 additions & 4 deletions milterutil/buffer.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,8 @@ import (
"sync"
)

// FixedBufferScanner is a wrapper around a bufio.Scanner that produces fixed size chunks of data
// given an io.Reader.
// FixedBufferScanner is a wrapper around a [bufio.Scanner] that produces fixed size chunks of data
// given an [io.Reader].
type FixedBufferScanner struct {
bufferSize uint32
buffer []byte
Expand Down Expand Up @@ -54,9 +54,9 @@ func (f *FixedBufferScanner) Err() error {
}

// Close needs to be called when you are done with the FixedBufferScanner because we maintain a shared pool
// of FixedBufferScanner.
// of FixedBufferScanner objects.
//
// Close does not close the underlying io.Reader. It is the responsibility of the caller to do this.
// Close does not close the underlying [io.Reader]. It is the responsibility of the caller to do this.
func (f *FixedBufferScanner) Close() {
f.pool.Put(f)
}
Expand Down

0 comments on commit a988b8e

Please sign in to comment.