From 5853a1ca2357d2f65cd5929aef6de4f075c3bf4b Mon Sep 17 00:00:00 2001 From: Daniel Jagszent Date: Thu, 2 Mar 2023 15:44:08 +0100 Subject: [PATCH 1/4] feat: introduce mail filter abstraction --- README.md | 108 +-- client_test.go | 5 +- go.mod | 3 + go.sum | 3 + log.go | 5 +- macro.go | 4 +- mailfilter/addr.go | 131 ++++ mailfilter/addr_test.go | 212 ++++++ mailfilter/backend.go | 230 ++++++ mailfilter/backend_test.go | 425 +++++++++++ mailfilter/decision.go | 46 ++ mailfilter/decision_test.go | 111 +++ mailfilter/example_test.go | 46 ++ mailfilter/header.go | 494 +++++++++++++ mailfilter/header_test.go | 1232 ++++++++++++++++++++++++++++++++ mailfilter/mailfilter.go | 144 ++++ mailfilter/option.go | 79 ++ mailfilter/transaction.go | 308 ++++++++ mailfilter/transaction_test.go | 284 ++++++++ milter.go | 6 +- modifier.go | 45 +- options.go | 4 +- response.go | 8 +- server.go | 85 ++- session.go | 12 +- session_test.go | 1 - 26 files changed, 3912 insertions(+), 119 deletions(-) create mode 100644 mailfilter/addr.go create mode 100644 mailfilter/addr_test.go create mode 100644 mailfilter/backend.go create mode 100644 mailfilter/backend_test.go create mode 100644 mailfilter/decision.go create mode 100644 mailfilter/decision_test.go create mode 100644 mailfilter/example_test.go create mode 100644 mailfilter/header.go create mode 100644 mailfilter/header_test.go create mode 100644 mailfilter/mailfilter.go create mode 100644 mailfilter/option.go create mode 100644 mailfilter/transaction.go create mode 100644 mailfilter/transaction_test.go diff --git a/README.md b/README.md index 22f637b..d9fbf26 100644 --- a/README.md +++ b/README.md @@ -6,81 +6,81 @@ A Go library to write mail filters. -With this library you can write both the client (MTA/SMTP-Server) and server (milter filter) -in pure Go without sendmail's libmilter. - ## Features +* With this library you can write both the client (MTA/SMTP-Server) and server (milter filter) + in pure Go without sendmail's libmilter. +* Easy wrapper of the milter protocol that abstracts away many milter protocol quirks + and lets you write mail filters with little effort. +* UTF-8 support +* IDNA support * Client & Server support milter protocol version 6 with all features. E.g.: * all milter events including DATA, UNKNOWN, ABORT and QUIT NEW CONNECTION * milter can skip e.g. body chunks when it does not need all chunks * milter can send progress notifications when response can take some time * milter can automatically instruct the MTA which macros it needs. -* UTF-8 support + +## Installation + +```shell +go get -u github.com/d--j/go-milter +``` ## Usage +The following example is a milter filter that adds `[⚠️EXTERNAL] ` to the subject of all messages of unauthenticated users. + +See [GoDoc](https://godoc.org/github.com/d--j/go-milter/mailfilter) for more documentation and an example for a milter client or a raw milter server. + ```go package main import ( - "log" - "net" - "sync" + "context" + "flag" + "log" + "strings" - "github.com/d--j/go-milter" + "github.com/d--j/go-milter/mailfilter" ) -type ExampleBackend struct { - milter.NoOpMilter -} - -func (b *ExampleBackend) RcptTo(rcptTo string, esmtpArgs string, m *milter.Modifier) (*milter.Response, error) { - // reject the mail when it goes to other-spammer@example.com and is a local delivery - if rcptTo == "other-spammer@example.com" && m.Macros.Get(milter.MacroRcptMailer) == "local" { - return milter.RejectWithCodeAndReason(550, "We do not like you\r\nvery much, please go away") - } - return milter.RespContinue, nil -} - func main() { - // create socket to listen on - socket, err := net.Listen("tcp4", "127.0.0.1:6785") - if err != nil { - log.Fatal(err) - } - defer socket.Close() - - // define the backend, required actions, protocol options and macros we want - server := milter.NewServer( - milter.WithMilter(func() milter.Milter { - return &ExampleBackend{} - }), - milter.WithProtocol(milter.OptNoConnect|milter.OptNoHelo|milter.OptNoMailFrom|milter.OptNoBody|milter.OptNoHeaders|milter.OptNoEOH|milter.OptNoUnknown|milter.OptNoData), - milter.WithAction(milter.OptChangeFrom|milter.OptAddRcpt|milter.OptRemoveRcpt), - milter.WithMaroRequest(milter.StageRcpt, []milter.MacroName{milter.MacroRcptMailer}), - ) - defer server.Close() - - // start the milter - var wgDone sync.WaitGroup - wgDone.Add(1) - go func(socket net.Listener) { - if err := server.Serve(socket); err != nil { - log.Fatal(err) - } - wgDone.Done() - }(socket) - - log.Printf("Started milter on %s:%s", socket.Addr().Network(), socket.Addr().String()) - - // quit when milter quits - wgDone.Wait() + // parse commandline arguments + var protocol, address string + flag.StringVar(&protocol, "proto", "tcp", "Protocol family (unix or tcp)") + flag.StringVar(&address, "addr", "127.0.0.1:10003", "Bind to address or unix domain socket") + flag.Parse() + + // create and start the mail filter + mailFilter, err := mailfilter.New(protocol, address, + func(_ context.Context, trx *mailfilter.Transaction) (mailfilter.Decision, error) { + // Reject message when it was sent to our SPAM trap + if trx.HasRcptTo("spam-trap@スパム.example.com") { + return mailfilter.CustomErrorResponse(550, "5.7.1 No thank you"), nil + } + // Prefix subject with [⚠️EXTERNAL] when user is not logged in + if trx.MailFrom.AuthenticatedUser() == "" { + subject, _ := trx.Headers.Subject() + if !strings.HasPrefix(subject, "[⚠️EXTERNAL] ") { + subject = "[⚠️EXTERNAL] " + subject + } + trx.Headers.SetSubject(subject) + } + return mailfilter.Accept, nil + }, + // optimization: we do not need the body of the message for our decision + mailfilter.WithoutBody(), + ) + if err != nil { + log.Fatal(err) + } + log.Printf("Started milter on %s:%s", mailFilter.Addr().Network(), mailFilter.Addr().String()) + + // wait for the mail filter to end + mailFilter.Wait() } ``` -See [![GoDoc](https://godoc.org/github.com/d--j/go-milter?status.svg)](https://godoc.org/github.com/d--j/go-milter) for more documentation and an example for a milter client. - ## License BSD 2-Clause diff --git a/client_test.go b/client_test.go index 23317fd..2a16b70 100644 --- a/client_test.go +++ b/client_test.go @@ -128,6 +128,10 @@ func (mm *MockMilter) Header(name string, value string, m *Modifier) (*Response, if mm.HdrMod != nil { mm.HdrMod(m) } + if mm.Hdr == nil { + mm.Hdr = make(nettextproto.MIMEHeader) + } + mm.Hdr.Add(name, value) return mm.HdrResp, mm.HdrErr } @@ -135,7 +139,6 @@ func (mm *MockMilter) Headers(m *Modifier) (*Response, error) { if mm.HdrsMod != nil { mm.HdrsMod(m) } - mm.Hdr = m.Headers return mm.HdrsResp, mm.HdrsErr } diff --git a/go.mod b/go.mod index 8734ad9..90da13c 100644 --- a/go.mod +++ b/go.mod @@ -4,5 +4,8 @@ go 1.18 require ( github.com/emersion/go-message v0.16.0 + golang.org/x/net v0.7.0 golang.org/x/text v0.7.0 ) + +require github.com/emersion/go-textwrapper v0.0.0-20200911093747-65d896831594 // indirect diff --git a/go.sum b/go.sum index 692a773..0bf7c3a 100644 --- a/go.sum +++ b/go.sum @@ -1,6 +1,9 @@ github.com/emersion/go-message v0.16.0 h1:uZLz8ClLv3V5fSFF/fFdW9jXjrZkXIpE1Fn8fKx7pO4= github.com/emersion/go-message v0.16.0/go.mod h1:pDJDgf/xeUIF+eicT6B/hPX/ZbEorKkUMPOxrPVG2eQ= +github.com/emersion/go-textwrapper v0.0.0-20200911093747-65d896831594 h1:IbFBtwoTQyw0fIM5xv1HF+Y+3ZijDR839WMulgxCcUY= github.com/emersion/go-textwrapper v0.0.0-20200911093747-65d896831594/go.mod h1:aqO8z8wPrjkscevZJFVE1wXJrLpC5LtJG7fqLOsPb2U= +golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g= +golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo= golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8= diff --git a/log.go b/log.go index 08b1873..3691b00 100644 --- a/log.go +++ b/log.go @@ -10,9 +10,8 @@ func logWarning(format string, v ...interface{}) { } // LogWarning is called by this library when it wants to output a warning. -// Warnings can happen even when the library user did everything right (because the other end did something wrong -// but we recovered from it) +// Warnings can happen even when the library user did everything right (because the other end did something wrong) // -// The default implementation uses log.Print to output the warning. +// The default implementation uses [log.Print] to output the warning. // You can re-assign LogWarning to something more suitable for your application. But do not assign nil to it. var LogWarning = logWarning diff --git a/macro.go b/macro.go index 23bd8e0..6d3cecd 100644 --- a/macro.go +++ b/macro.go @@ -83,7 +83,7 @@ type Macros interface { // A MacroBag is safe for concurrent use by multiple goroutines. // It has special handling for the date related macros and can be copied. // -// The zero value of MacroBag is invalid. Use NewMacroBag to create an empty MacroBag. +// The zero value of MacroBag is invalid. Use [NewMacroBag] to create an empty MacroBag. type MacroBag struct { macros map[MacroName]string mutex sync.RWMutex @@ -138,7 +138,7 @@ func (m *MacroBag) Set(name MacroName, value string) { } // Copy copies the macros to a new MacroBag. -// The time.Time values set by SetCurrentDate and SetHeaderDate do not get copied. +// The time.Time values set by [MacroBag.SetCurrentDate] and [MacroBag.SetHeaderDate] do not get copied. func (m *MacroBag) Copy() *MacroBag { m.mutex.Lock() defer m.mutex.Unlock() diff --git a/mailfilter/addr.go b/mailfilter/addr.go new file mode 100644 index 0000000..5df9693 --- /dev/null +++ b/mailfilter/addr.go @@ -0,0 +1,131 @@ +package mailfilter + +import ( + "strings" + + "golang.org/x/net/idna" +) + +// split an user@domain address into user and domain. +// Includes the input address as third array element to quickly check if splitting must be re-done +func split(addr string) []string { + at := strings.LastIndex(addr, "@") + if at < 0 { + return []string{addr, "", addr} + } + + return []string{addr[:at], addr[at+1:], addr} +} + +type addr struct { + Addr string + Args string + parts []string + asciiDomain string + unicodeDomain string +} + +func (a *addr) initParts() { + if len(a.parts) != 3 || a.parts[2] != a.Addr { + a.parts = split(a.Addr) + a.asciiDomain = "" + a.unicodeDomain = "" + } +} + +func (a *addr) Local() string { + a.initParts() + return a.parts[0] +} + +func (a *addr) Domain() string { + a.initParts() + return a.parts[1] +} + +func (a *addr) AsciiDomain() string { + domain := a.Domain() + if domain == "" { + return "" + } + if a.asciiDomain != "" { + return a.asciiDomain + } + + asciiDomain, err := idna.Lookup.ToASCII(domain) + if err != nil { + a.asciiDomain = domain + return domain + } + a.asciiDomain = asciiDomain + return asciiDomain +} + +func (a *addr) UnicodeDomain() string { + domain := a.Domain() + if domain == "" { + return "" + } + if a.unicodeDomain != "" { + return a.unicodeDomain + } + + unicodeDomain, err := idna.Lookup.ToUnicode(domain) + if err != nil { + a.unicodeDomain = domain + return domain + } + a.unicodeDomain = unicodeDomain + return unicodeDomain +} + +type MailFrom struct { + addr + transport string + authenticatedUser string + authenticationMethod string +} + +func (m *MailFrom) Transport() string { + return m.transport +} + +func (m *MailFrom) AuthenticatedUser() string { + return m.authenticatedUser +} + +func (m *MailFrom) AuthenticationMethod() string { + return m.authenticationMethod +} + +type RcptTo struct { + addr + transport string +} + +func (r *RcptTo) Transport() string { + return r.transport +} + +func calculateRcptToDiff(orig []RcptTo, changed []RcptTo) (deletions []RcptTo, additions []RcptTo) { + foundOrig := make(map[string]*RcptTo) + foundChanged := make(map[string]bool) + for _, r := range orig { + foundOrig[r.Addr] = &r + } + for _, r := range changed { + if o := foundOrig[r.Addr]; o == nil && !foundChanged[r.Addr] { + additions = append(additions, r) + } else if o != nil && o.Args != r.Args && !foundChanged[r.Addr] { + deletions = append(deletions, *o) + additions = append(additions, r) + } + foundChanged[r.Addr] = true + } + for _, r := range orig { + if !foundChanged[r.Addr] { + deletions = append(deletions, r) + } + } + return +} diff --git a/mailfilter/addr_test.go b/mailfilter/addr_test.go new file mode 100644 index 0000000..bcc3c36 --- /dev/null +++ b/mailfilter/addr_test.go @@ -0,0 +1,212 @@ +package mailfilter + +import ( + "reflect" + "testing" + "unsafe" +) + +func Test_addr_AsciiDomain(t *testing.T) { + tests := []struct { + name string + Addr string + want string + }{ + {"empty", "", ""}, + {"no domain", "root", ""}, + {"normal", "root@localhost", "localhost"}, + {"IDNA", "root@スパム.example.com", "xn--zck5b2b.example.com"}, + {"IDNA encoded", "root@xn--zck5b2b.example.com", "xn--zck5b2b.example.com"}, + {"IDNA broken", "root@スパム\u0000\u0000\u0000\u0000.example.com", "スパム\u0000\u0000\u0000\u0000.example.com"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := addr{ + Addr: tt.Addr, + } + if got := a.AsciiDomain(); got != tt.want { + t.Errorf("AsciiDomain() = %v, want %v", got, tt.want) + } + }) + } + t.Run("cache", func(t *testing.T) { + a := addr{ + Addr: "root@localhost", + } + got1 := a.AsciiDomain() + got2 := a.AsciiDomain() + + hdr1 := (*reflect.StringHeader)(unsafe.Pointer(&got1)) + hdr2 := (*reflect.StringHeader)(unsafe.Pointer(&got2)) + + if hdr1.Data != hdr2.Data { + t.Errorf("AsciiDomain() did not cache value") + } + }) +} + +func Test_addr_Domain(t *testing.T) { + tests := []struct { + name string + Addr string + want string + }{ + {"empty", "", ""}, + {"no domain", "root", ""}, + {"normal", "root@localhost", "localhost"}, + {"IDNA", "root@スパム.example.com", "スパム.example.com"}, + {"IDNA encoded", "root@xn--zck5b2b.example.com", "xn--zck5b2b.example.com"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := addr{ + Addr: tt.Addr, + } + if got := a.Domain(); got != tt.want { + t.Errorf("Domain() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_addr_Local(t *testing.T) { + tests := []struct { + name string + Addr string + want string + }{ + {"empty", "", ""}, + {"no domain", "root", "root"}, + {"normal", "root@localhost", "root"}, + {"IDNA", "root@スパム.example.com", "root"}, + {"IDNA encoded", "root@xn--zck5b2b.example.com", "root"}, + {"bogus", "local root@localhost", "local root"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := addr{ + Addr: tt.Addr, + } + if got := a.Local(); got != tt.want { + t.Errorf("Local() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_addr_UnicodeDomain(t *testing.T) { + tests := []struct { + name string + Addr string + want string + }{ + {"empty", "", ""}, + {"no domain", "root", ""}, + {"normal", "root@localhost", "localhost"}, + {"IDNA", "root@スパム.example.com", "スパム.example.com"}, + {"IDNA encoded", "root@xn--zck5b2b.example.com", "スパム.example.com"}, + {"IDNA broken", "root@xn--zck5b2b\u0000\u0000\u0000\u0000.example.com", "xn--zck5b2b\u0000\u0000\u0000\u0000.example.com"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + a := addr{ + Addr: tt.Addr, + } + if got := a.UnicodeDomain(); got != tt.want { + t.Errorf("UnicodeDomain() = %v, want %v", got, tt.want) + } + }) + } + t.Run("cache", func(t *testing.T) { + a := addr{ + Addr: "root@localhost", + } + got1 := a.UnicodeDomain() + got2 := a.UnicodeDomain() + + hdr1 := (*reflect.StringHeader)(unsafe.Pointer(&got1)) + hdr2 := (*reflect.StringHeader)(unsafe.Pointer(&got2)) + + if hdr1.Data != hdr2.Data { + t.Errorf("UnicodeDomain() did not cache value") + } + }) +} + +func TestMailFrom(t *testing.T) { + m := MailFrom{ + addr: addr{Addr: "root@localhost", Args: "A=B"}, + transport: "smtpd", + authenticatedUser: "root", + authenticationMethod: "PLAIN", + } + if v := m.Transport(); v != "smtpd" { + t.Errorf("Transoprt() = %q, want %q", v, "smtpd") + } + if v := m.AuthenticatedUser(); v != "root" { + t.Errorf("AuthenticatedUser() = %q, want %q", v, "root") + } + if v := m.AuthenticationMethod(); v != "PLAIN" { + t.Errorf("AuthenticationMethod() = %q, want %q", v, "PLAIN") + } +} + +func TestRcptTo(t *testing.T) { + m := RcptTo{ + addr: addr{Addr: "root@localhost", Args: "A=B"}, + transport: "lmtp", + } + if v := m.Transport(); v != "lmtp" { + t.Errorf("Transoprt() = %q, want %q", v, "lmtp") + } +} + +func Test_calculateRcptToDiff(t *testing.T) { + type args struct { + orig []RcptTo + changed []RcptTo + } + tests := []struct { + name string + args args + wantDeletions []RcptTo + wantAdditions []RcptTo + }{ + {"nil", args{nil, nil}, nil, nil}, + {"empty", args{[]RcptTo{}, []RcptTo{}}, nil, nil}, + {"remove", args{[]RcptTo{{addr: addr{Addr: "one"}}}, []RcptTo{}}, []RcptTo{{addr: addr{Addr: "one"}}}, nil}, + {"add", args{[]RcptTo{}, []RcptTo{{addr: addr{Addr: "one"}}}}, nil, []RcptTo{{addr: addr{Addr: "one"}}}}, + {"add double", args{[]RcptTo{}, []RcptTo{{addr: addr{Addr: "one"}}, {addr: addr{Addr: "one"}}}}, nil, []RcptTo{{addr: addr{Addr: "one"}}}}, + {"change", args{[]RcptTo{{addr: addr{Addr: "one"}}}, []RcptTo{{addr: addr{Addr: "one", Args: "A=B"}}}}, []RcptTo{{addr: addr{Addr: "one"}}}, []RcptTo{{addr: addr{Addr: "one", Args: "A=B"}}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotDeletions, gotAdditions := calculateRcptToDiff(tt.args.orig, tt.args.changed) + if !reflect.DeepEqual(gotDeletions, tt.wantDeletions) { + t.Errorf("calculateRcptToDiff() gotDeletions = %v, want %v", gotDeletions, tt.wantDeletions) + } + if !reflect.DeepEqual(gotAdditions, tt.wantAdditions) { + t.Errorf("calculateRcptToDiff() gotAdditions = %v, want %v", gotAdditions, tt.wantAdditions) + } + }) + } +} + +func Test_split(t *testing.T) { + tests := []struct { + name string + addr string + want []string + }{ + {"empty", "", []string{"", "", ""}}, + {"no domain", "root", []string{"root", "", "root"}}, + {"normal", "root@localhost", []string{"root", "localhost", "root@localhost"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := split(tt.addr); !reflect.DeepEqual(got, tt.want) { + t.Errorf("split() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/mailfilter/backend.go b/mailfilter/backend.go new file mode 100644 index 0000000..6a2bf87 --- /dev/null +++ b/mailfilter/backend.go @@ -0,0 +1,230 @@ +package mailfilter + +import ( + "fmt" + "strings" + "time" + + "github.com/d--j/go-milter" + "golang.org/x/net/context" +) + +type backend struct { + milter.NoOpMilter + opts options + leadingSpace bool + decision DecisionModificationFunc + transaction *Transaction +} + +func (b *backend) decideOrContinue(stage DecisionAt, m *milter.Modifier) (*milter.Response, error) { + if b.opts.decisionAt == stage { + b.makeDecision(m) + if !b.transaction.hasModifications() { + if b.transaction.decisionErr != nil { + return b.error(b.transaction.decisionErr) + } + return b.transaction.response(), nil + } + } + return milter.RespContinue, nil +} + +func (b *backend) error(err error) (*milter.Response, error) { + b.Cleanup() + switch b.opts.errorHandling { + case Error: + return nil, err + case AcceptWhenError: + milter.LogWarning("milter: accept message despite error: %s", err) + return milter.RespAccept, err + case TempFailWhenError: + milter.LogWarning("milter: temp fail message because of error: %s", err) + return milter.RespTempFail, err + case RejectWhenError: + milter.LogWarning("milter: reject message because of error: %s", err) + return milter.RespReject, err + default: + panic(b.opts.errorHandling) + } +} + +func (b *backend) makeDecision(m *milter.Modifier) { + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + ctx, cancel := context.WithCancel(context.Background()) + done := make(chan struct{}) + go func() { + b.transaction.makeDecision(ctx, b.decision) + done <- struct{}{} + }() + for { + select { + case <-done: + return + case <-ticker.C: + err := m.Progress() + if err != nil { + // instruct decision function to abort + cancel() + // wait for decision function + <-done + // if there was no error in the decision function (e.g. it did not actually check ctx.Done()) + // set the Progress error so that we will not actually think we should continue + if b.transaction.decisionErr == nil { + b.transaction.decisionErr = err + } + return + } + } + } +} + +func (b *backend) Connect(host string, family string, port uint16, addr string, m *milter.Modifier) (*milter.Response, error) { + b.Cleanup() + b.transaction.Connect = Connect{ + Host: host, + Family: family, + Port: port, + Addr: addr, + IfName: m.Macros.Get(milter.MacroIfName), + IfAddr: m.Macros.Get(milter.MacroIfAddr), + } + return b.decideOrContinue(DecisionAtConnect, m) +} + +func (b *backend) Helo(name string, m *milter.Modifier) (*milter.Response, error) { + if b.transaction.hasDecision { + return milter.RespContinue, nil + } + b.transaction.Helo = Helo{ + Name: name, + TlsVersion: m.Macros.Get(milter.MacroTlsVersion), + Cipher: m.Macros.Get(milter.MacroCipher), + CipherBits: m.Macros.Get(milter.MacroCipherBits), + CertSubject: m.Macros.Get(milter.MacroCertSubject), + CertIssuer: m.Macros.Get(milter.MacroCertIssuer), + } + return b.decideOrContinue(DecisionAtHelo, m) +} + +func (b *backend) MailFrom(from string, esmtpArgs string, m *milter.Modifier) (*milter.Response, error) { + if b.transaction.hasDecision { + return milter.RespContinue, nil + } + b.transaction.mailFrom = MailFrom{ + addr: addr{Addr: from, Args: esmtpArgs}, + transport: m.Macros.Get(milter.MacroMailMailer), + authenticatedUser: m.Macros.Get(milter.MacroAuthAuthen), + authenticationMethod: m.Macros.Get(milter.MacroAuthType), + } + return b.decideOrContinue(DecisionAtMailFrom, m) +} + +func (b *backend) RcptTo(rcptTo string, esmtpArgs string, m *milter.Modifier) (*milter.Response, error) { + if b.transaction.hasDecision { + return milter.RespSkip, nil + } + b.transaction.rcptTos = append(b.transaction.rcptTos, RcptTo{ + addr: addr{Addr: rcptTo, Args: esmtpArgs}, + transport: m.Macros.Get(milter.MacroRcptMailer), + }) + return milter.RespContinue, nil +} + +func (b *backend) Data(m *milter.Modifier) (*milter.Response, error) { + if b.transaction.hasDecision { + return milter.RespContinue, nil + } + b.transaction.QueueId = m.Macros.Get(milter.MacroQueueId) + return b.decideOrContinue(DecisionAtData, m) +} + +func (b *backend) Header(name string, value string, _ *milter.Modifier) (*milter.Response, error) { + if b.transaction.hasDecision { + return milter.RespSkip, nil + } + name = strings.Trim(name, " \t\r\n") + if b.leadingSpace { + // the MTA did not actually *not* swallow the space, so we add a space because it is required + if len(value) > 0 && value[0] != ' ' && value[0] != '\t' { + value = " " + value + } + } else { + // we only add a space when the first character is not a tab - sendmail swallows the first space + if len(value) == 0 || value[0] != '\t' { + value = " " + value + } + } + if name == "" || value == "" { + milter.LogWarning("milter: skip header %q because we got an empty value or name", name) + } else { + b.transaction.addHeader(name, fmt.Sprintf("%s:%s", name, value)) + } + return milter.RespContinue, nil +} + +func (b *backend) Headers(m *milter.Modifier) (*milter.Response, error) { + if b.transaction.hasDecision { + return milter.RespContinue, nil + } + return b.decideOrContinue(DecisionAtEndOfHeaders, m) +} + +func (b *backend) BodyChunk(chunk []byte, _ *milter.Modifier) (*milter.Response, error) { + if b.transaction.hasDecision || b.opts.skipBody { + return milter.RespSkip, nil + } + err := b.transaction.addBodyChunk(chunk) + if err != nil { + return b.error(err) + } + return milter.RespContinue, nil +} + +func (b *backend) readyForNewMessage() { + if b.transaction != nil { + connect, helo := b.transaction.Connect, b.transaction.Helo + b.Cleanup() + b.transaction.Connect, b.transaction.Helo = connect, helo + } else { + b.Cleanup() + } +} + +func (b *backend) EndOfMessage(m *milter.Modifier) (*milter.Response, error) { + if !b.transaction.hasDecision && b.transaction.QueueId == "" { + b.transaction.QueueId = m.Macros.Get(milter.MacroQueueId) + } + if !b.transaction.hasDecision { + b.makeDecision(m) + } + + if b.transaction.decisionErr != nil { + return b.error(b.transaction.decisionErr) + } + + if err := b.transaction.sendModifications(m); err != nil { + return b.error(err) + } + + response := b.transaction.response() + + b.readyForNewMessage() + + return response, nil +} + +func (b *backend) Abort(_ *milter.Modifier) error { + b.readyForNewMessage() + return nil +} + +func (b *backend) Cleanup() { + if b.transaction != nil { + b.transaction.cleanup() + } + b.transaction = &Transaction{} +} + +var _ milter.Milter = &backend{} diff --git a/mailfilter/backend_test.go b/mailfilter/backend_test.go new file mode 100644 index 0000000..c13852b --- /dev/null +++ b/mailfilter/backend_test.go @@ -0,0 +1,425 @@ +package mailfilter + +import ( + "context" + "errors" + "io" + "reflect" + "testing" + "time" + + "github.com/d--j/go-milter" + "github.com/d--j/go-milter/internal/wire" +) + +type mockSession struct { + modifications []*wire.Message + progressCalled int + macros *milter.MacroBag + WritePacket, WriteProgress func(msg *wire.Message) error +} + +func (s *mockSession) writePacket(msg *wire.Message) error { + s.modifications = append(s.modifications, msg) + return nil +} + +func (s *mockSession) writeProgress(_ *wire.Message) error { + s.progressCalled++ + return nil +} + +func (s *mockSession) newModifier() *milter.Modifier { + if s.macros == nil { + m := milter.NewMacroBag() + m.Set(milter.MacroIfName, "ifname") + m.Set(milter.MacroIfAddr, "127.0.0.3") + m.Set(milter.MacroTlsVersion, "tls-version") + m.Set(milter.MacroCipher, "cipher") + m.Set(milter.MacroCipherBits, "cipher-bits") + m.Set(milter.MacroCertSubject, "cert-subject") + m.Set(milter.MacroCertIssuer, "cert-issuer") + m.Set(milter.MacroMailMailer, "mail-mailer") + m.Set(milter.MacroAuthAuthen, "auth-authen") + m.Set(milter.MacroAuthType, "auth-type") + m.Set(milter.MacroRcptMailer, "rcpt-mailer") + m.Set(milter.MacroQueueId, "Q123") + s.macros = m + } + if s.WritePacket == nil { + s.WritePacket = s.writePacket + } + if s.WriteProgress == nil { + s.WriteProgress = s.writeProgress + } + return milter.NewTestModifier(s.macros, s.WritePacket, s.WriteProgress, milter.AllClientSupportedActionMasks, milter.DataSize64K) +} + +func newMockBackend() (*backend, *mockSession) { + return &backend{ + opts: options{ + decisionAt: DecisionAtEndOfMessage, + errorHandling: Error, + }, + leadingSpace: false, + decision: nil, + transaction: &Transaction{}, + }, &mockSession{} +} + +func assertContinue(t *testing.T, resp *milter.Response, err error) { + t.Helper() + if err != nil { + t.Fatalf("got err %s", err) + } + if resp != milter.RespContinue { + t.Fatalf("got resp %v expected continue", resp) + } +} + +func Test_backend_Abort(t *testing.T) { + t.Parallel() + b, s := newMockBackend() + trx := Transaction{Connect: Connect{Host: "host"}, Helo: Helo{Name: "name"}} + b.transaction = &trx + if err := b.Abort(s.newModifier()); err != nil { + t.Errorf("expected nil, got %s", err) + } + if b.transaction == &trx { + t.Errorf("expected new transaction") + } + if b.transaction.Connect.Host != "host" || b.transaction.Helo.Name != "name" { + t.Errorf("expected Connect and Helo to persist") + } + b.transaction = nil + if err := b.Abort(s.newModifier()); err != nil { + t.Errorf("expected nil, got %s", err) + } + if b.transaction == nil { + t.Errorf("expected new transaction") + } +} + +func Test_backend_BodyChunk(t *testing.T) { + t.Parallel() + b, s := newMockBackend() + resp, err := b.BodyChunk([]byte("test"), s.newModifier()) + assertContinue(t, resp, err) + resp, err = b.BodyChunk([]byte("test"), s.newModifier()) + assertContinue(t, resp, err) + if b.transaction.body == nil { + t.Fatal("body file is nil") + } + _, _ = b.transaction.body.Seek(0, io.SeekStart) + data, err := io.ReadAll(b.transaction.body) + b.transaction.cleanup() + if string(data) != "testtest" { + t.Fatalf("got %q, expected %q", data, "testtest") + } +} + +func Test_backend_Cleanup(t *testing.T) { + t.Parallel() + b, _ := newMockBackend() + trx := Transaction{} + b.transaction = &trx + b.Cleanup() + if b.transaction == &trx { + t.Errorf("expected new transaction") + } +} + +func Test_backend_Connect(t *testing.T) { + t.Parallel() + b, s := newMockBackend() + resp, err := b.Connect("host", "family", 123, "127.0.0.2", s.newModifier()) + assertContinue(t, resp, err) + expect := Connect{ + Host: "host", + Family: "family", + Port: 123, + Addr: "127.0.0.2", + IfName: "ifname", + IfAddr: "127.0.0.3", + } + got := b.transaction.Connect + if !reflect.DeepEqual(got, expect) { + t.Fatalf("Connect() = %v, expected %v", got, expect) + } +} + +func Test_backend_Data(t *testing.T) { + t.Parallel() + b, s := newMockBackend() + resp, err := b.Data(s.newModifier()) + assertContinue(t, resp, err) + expect := "Q123" + got := b.transaction.QueueId + if !reflect.DeepEqual(got, expect) { + t.Fatalf("Data() = %q, expected %q", got, expect) + } +} + +func Test_backend_EndOfMessage(t *testing.T) { + t.Parallel() + b, s := newMockBackend() + expectedErr := errors.New("error") + b.decision = func(_ context.Context, trx *Transaction) (Decision, error) { + if trx.QueueId != "Q123" { + t.Fatalf("QueueId = %q, expected %q", trx.QueueId, "Q123") + } + return nil, expectedErr + } + resp, err := b.EndOfMessage(s.newModifier()) + if resp != nil || err != expectedErr { + t.Fatalf("wrong return %v, %v", resp, err) + } + b.Cleanup() + b.transaction.addHeader("subject", "subject: test") + b.decision = func(_ context.Context, trx *Transaction) (Decision, error) { + if subj := trx.Headers.Get("Subject"); subj != " test" { + t.Fatalf("Subject = %q, expected %q", subj, " test") + } + return nil, expectedErr + } + resp, err = b.EndOfMessage(s.newModifier()) + if resp != nil || err != expectedErr { + t.Fatalf("wrong return %v, %v", resp, err) + } + b.Cleanup() + b.decision = func(_ context.Context, trx *Transaction) (Decision, error) { + return TempFail, nil + } + resp, err = b.EndOfMessage(s.newModifier()) + if resp != milter.RespTempFail || err != nil { + t.Fatalf("wrong return %v, %v", resp, err) + } + b.Cleanup() + b.decision = func(_ context.Context, trx *Transaction) (Decision, error) { + return Reject, nil + } + resp, err = b.EndOfMessage(s.newModifier()) + if resp != milter.RespReject || err != nil { + t.Fatalf("wrong return %v, %v", resp, err) + } + b.Cleanup() + b.decision = func(_ context.Context, trx *Transaction) (Decision, error) { + return Discard, nil + } + resp, err = b.EndOfMessage(s.newModifier()) + if resp != milter.RespDiscard || err != nil { + t.Fatalf("wrong return %v, %v", resp, err) + } + b.Cleanup() + b.decision = func(_ context.Context, trx *Transaction) (Decision, error) { + return CustomErrorResponse(400, "not right now"), nil + } + resp, err = b.EndOfMessage(s.newModifier()) + if resp == nil || resp.Response().Code != wire.Code(wire.ActReplyCode) || err != nil { + t.Fatalf("wrong return %v, %v", resp, err) + } + b.Cleanup() + b.decision = func(_ context.Context, trx *Transaction) (Decision, error) { + return CustomErrorResponse(200, "not right now"), nil + } + resp, err = b.EndOfMessage(s.newModifier()) + if resp != milter.RespTempFail || err != nil { + t.Fatalf("wrong return %v, %v", resp, err) + } +} + +func Test_backend_Header(t *testing.T) { + t.Parallel() + b, s := newMockBackend() + resp, err := b.Header("from", "root", s.newModifier()) + assertContinue(t, resp, err) + b.leadingSpace = true + resp, err = b.Header("To", " root, nobody", s.newModifier()) + assertContinue(t, resp, err) + resp, err = b.Header("To", "root, nobody", s.newModifier()) + assertContinue(t, resp, err) + b.leadingSpace = false + resp, err = b.Header("To", "\troot, nobody", s.newModifier()) + assertContinue(t, resp, err) + expect := []*headerField{ + {0, "From", "from: root"}, + {1, "To", "To: root, nobody"}, + {2, "To", "To: root, nobody"}, + {3, "To", "To:\troot, nobody"}, + } + got := b.transaction.headers.fields + if !reflect.DeepEqual(got, expect) { + t.Fatalf("Header() = %s, expected %s", outputFields(got), outputFields(expect)) + } +} + +func Test_backend_Headers(t *testing.T) { +} + +func Test_backend_Helo(t *testing.T) { + t.Parallel() + b, s := newMockBackend() + resp, err := b.Helo("helohost", s.newModifier()) + assertContinue(t, resp, err) + expect := Helo{ + Name: "helohost", + TlsVersion: "tls-version", + Cipher: "cipher", + CipherBits: "cipher-bits", + CertSubject: "cert-subject", + CertIssuer: "cert-issuer", + } + got := b.transaction.Helo + if !reflect.DeepEqual(got, expect) { + t.Fatalf("Helo() = %v, expected %v", got, expect) + } +} + +func Test_backend_MailFrom(t *testing.T) { + t.Parallel() + b, s := newMockBackend() + resp, err := b.MailFrom("root@localhost", "A=B", s.newModifier()) + assertContinue(t, resp, err) + expect := MailFrom{ + addr: addr{Addr: "root@localhost", Args: "A=B"}, + transport: "mail-mailer", + authenticatedUser: "auth-authen", + authenticationMethod: "auth-type", + } + got := b.transaction.mailFrom + if !reflect.DeepEqual(got, expect) { + t.Fatalf("MailFrom() = %v, expected %v", got, expect) + } +} + +func Test_backend_RcptTo(t *testing.T) { + t.Parallel() + b, s := newMockBackend() + resp, err := b.RcptTo("root@localhost", "A=B", s.newModifier()) + assertContinue(t, resp, err) + s.macros.Set(milter.MacroRcptMailer, "2") + resp, err = b.RcptTo("nobody@localhost", "", s.newModifier()) + assertContinue(t, resp, err) + expect := []RcptTo{{ + addr: addr{Addr: "root@localhost", Args: "A=B"}, + transport: "rcpt-mailer", + }, { + addr: addr{Addr: "nobody@localhost", Args: ""}, + transport: "2", + }} + got := b.transaction.rcptTos + if !reflect.DeepEqual(got, expect) { + t.Fatalf("RcptTo() = %v, expected %v", got, expect) + } +} + +func Test_backend_decideOrContinue(t *testing.T) { + t.Parallel() + b, s := newMockBackend() + resp, err := b.decideOrContinue(DecisionAtHelo, s.newModifier()) + assertContinue(t, resp, err) + b.opts.decisionAt = DecisionAtHelo + b.decision = func(ctx context.Context, trx *Transaction) (Decision, error) { + return Accept, nil + } + resp, err = b.decideOrContinue(DecisionAtHelo, s.newModifier()) + if err != nil { + t.Fatalf("got err %s", err) + } + if resp != milter.RespAccept { + t.Fatalf("got resp %v expected accept", resp) + } + b.Cleanup() + b.decision = func(ctx context.Context, trx *Transaction) (Decision, error) { + return nil, io.EOF + } + _, err = b.decideOrContinue(DecisionAtHelo, s.newModifier()) + if err != io.EOF { + t.Fatalf("got err %v, want io.EOF", err) + } +} + +func Test_backend_error(t *testing.T) { + savedWarning := milter.LogWarning + defer func() { + milter.LogWarning = savedWarning + }() + warningCalled := 0 + milter.LogWarning = func(_ string, _ ...interface{}) { + warningCalled++ + } + expected := errors.New("error") + b, _ := newMockBackend() + resp, err := b.error(expected) + if err != expected || resp != nil { + t.Fatalf("error() wrong return values %v, %v", resp, err) + } + if warningCalled != 0 { + t.Fatalf("wrong warningCalled value %d", warningCalled) + } + b.opts.errorHandling = AcceptWhenError + resp, err = b.error(expected) + if err != expected || resp != milter.RespAccept { + t.Fatalf("error() wrong return values %v, %v", resp, err) + } + if warningCalled != 1 { + t.Fatalf("wrong warningCalled value %d", warningCalled) + } + b.opts.errorHandling = TempFailWhenError + resp, err = b.error(expected) + if err != expected || resp != milter.RespTempFail { + t.Fatalf("error() wrong return values %v, %v", resp, err) + } + if warningCalled != 2 { + t.Fatalf("wrong warningCalled value %d", warningCalled) + } + b.opts.errorHandling = RejectWhenError + resp, err = b.error(expected) + if err != expected || resp != milter.RespReject { + t.Fatalf("error() wrong return values %v, %v", resp, err) + } + if warningCalled != 3 { + t.Fatalf("wrong warningCalled value %d", warningCalled) + } + + defer func() { _ = recover() }() + b.opts.errorHandling = 99 + _, _ = b.error(expected) + t.Errorf("did not panic") +} + +func Test_backend_makeDecision(t *testing.T) { + t.Parallel() + b, s := newMockBackend() + b.decision = func(ctx context.Context, trx *Transaction) (Decision, error) { + return Accept, nil + } + b.makeDecision(s.newModifier()) + if b.transaction.decision != Accept || b.transaction.decisionErr != nil { + t.Fatal("values not set") + } + if s.progressCalled > 0 { + t.Fatal("progress called") + } + b.Cleanup() + b.decision = func(ctx context.Context, trx *Transaction) (Decision, error) { + time.Sleep(time.Second + 30*time.Millisecond) + return Accept, nil + } + b.makeDecision(s.newModifier()) + if b.transaction.decision != Accept || b.transaction.decisionErr != nil { + t.Fatal("values not set") + } + if s.progressCalled != 1 { + t.Fatal("progress not called") + } + b.Cleanup() + expect := errors.New("error") + s.WriteProgress = func(_ *wire.Message) error { + return expect + } + b.makeDecision(s.newModifier()) + if b.transaction.decision != Accept || b.transaction.decisionErr != expect { + t.Fatal("values not set") + } +} diff --git a/mailfilter/decision.go b/mailfilter/decision.go new file mode 100644 index 0000000..157f41f --- /dev/null +++ b/mailfilter/decision.go @@ -0,0 +1,46 @@ +package mailfilter + +import "strconv" + +type Decision interface { + getCode() uint16 + getReason() string +} + +type decision string + +func (d decision) getCode() uint16 { + c, _ := strconv.ParseUint(string(d[:3]), 10, 16) + return uint16(c) +} + +func (d decision) getReason() string { + return string(d[4:]) +} + +const ( + Accept decision = "250 accept" + Reject decision = "550 5.7.1 Command rejected" + TempFail decision = "451 4.7.1 Service unavailable - try again later" + Discard decision = "250 discard" +) + +type customResponse struct { + code uint16 + reason string +} + +func (c customResponse) getCode() uint16 { + return c.code +} + +func (c customResponse) getReason() string { + return c.reason +} + +func CustomErrorResponse(code uint16, reason string) Decision { + return &customResponse{ + code: code, + reason: reason, + } +} diff --git a/mailfilter/decision_test.go b/mailfilter/decision_test.go new file mode 100644 index 0000000..b49fc6f --- /dev/null +++ b/mailfilter/decision_test.go @@ -0,0 +1,111 @@ +package mailfilter + +import ( + "reflect" + "testing" +) + +func TestCustomErrorResponse(t *testing.T) { + type args struct { + code uint16 + reason string + } + tests := []struct { + name string + args args + want Decision + }{ + {"works", args{400, "test"}, &customResponse{400, "test"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CustomErrorResponse(tt.args.code, tt.args.reason); !reflect.DeepEqual(got, tt.want) { + t.Errorf("CustomErrorResponse() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_customResponse_getCode(t *testing.T) { + type fields struct { + code uint16 + reason string + } + tests := []struct { + name string + fields fields + want uint16 + }{ + {"works", fields{400, "test"}, 400}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := customResponse{ + code: tt.fields.code, + reason: tt.fields.reason, + } + if got := c.getCode(); got != tt.want { + t.Errorf("getCode() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_customResponse_getReason(t *testing.T) { + type fields struct { + code uint16 + reason string + } + tests := []struct { + name string + fields fields + want string + }{ + {"works", fields{400, "test"}, "test"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + c := customResponse{ + code: tt.fields.code, + reason: tt.fields.reason, + } + if got := c.getReason(); got != tt.want { + t.Errorf("getReason() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_decision_getCode(t *testing.T) { + tests := []struct { + name string + d decision + want uint16 + }{ + {"works", Accept, 250}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.d.getCode(); got != tt.want { + t.Errorf("getCode() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_decision_getReason(t *testing.T) { + tests := []struct { + name string + d decision + want string + }{ + {"works", Accept, "accept"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := tt.d.getReason(); got != tt.want { + t.Errorf("getReason() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/mailfilter/example_test.go b/mailfilter/example_test.go new file mode 100644 index 0000000..badbc73 --- /dev/null +++ b/mailfilter/example_test.go @@ -0,0 +1,46 @@ +package mailfilter_test + +import ( + "context" + "flag" + "log" + "strings" + + "github.com/d--j/go-milter/mailfilter" +) + +func ExampleNew() { + // parse commandline arguments + var protocol, address string + flag.StringVar(&protocol, "proto", "tcp", "Protocol family (unix or tcp)") + flag.StringVar(&address, "addr", "127.0.0.1:10003", "Bind to address or unix domain socket") + flag.Parse() + + // create and start the mail filter + mailFilter, err := mailfilter.New(protocol, address, + func(_ context.Context, trx *mailfilter.Transaction) (mailfilter.Decision, error) { + // Reject message when it was sent to our SPAM trap + if trx.HasRcptTo("spam-trap@スパム.example.com") { + return mailfilter.CustomErrorResponse(550, "5.7.1 No thank you"), nil + } + // Prefix subject with [⚠️EXTERNAL] when user is not logged in + if trx.MailFrom.AuthenticatedUser() == "" { + subject, _ := trx.Headers.Subject() + if !strings.HasPrefix(subject, "[⚠️EXTERNAL] ") { + subject = "[⚠️EXTERNAL] " + subject + } + trx.Headers.SetSubject(subject) + } + return mailfilter.Accept, nil + }, + // optimization: we do not need the body of the message for our decision + mailfilter.WithoutBody(), + ) + if err != nil { + log.Fatal(err) + } + log.Printf("Started milter on %s:%s", mailFilter.Addr().Network(), mailFilter.Addr().String()) + + // wait for the mail filter to end + mailFilter.Wait() +} diff --git a/mailfilter/header.go b/mailfilter/header.go new file mode 100644 index 0000000..0c79817 --- /dev/null +++ b/mailfilter/header.go @@ -0,0 +1,494 @@ +package mailfilter + +import ( + "io" + netMail "net/mail" + "net/textproto" + "strings" + "time" + + "github.com/emersion/go-message/mail" +) + +const helperKey = "Helper" +const dateLayout = "Mon, 02 Jan 2006 15:04:05 -0700" + +func newHelper() *mail.Header { + helper := mail.HeaderFromMap(map[string][]string{helperKey: {" "}}) + return &helper +} + +type headerField struct { + index int + canonicalKey string + raw string +} + +func (f *headerField) key() string { + return f.raw[:len(f.canonicalKey)] +} + +func (f *headerField) value() string { + return f.raw[len(f.canonicalKey)+1:] +} + +func (f *headerField) deleted() bool { + return len(f.raw) <= len(f.canonicalKey)+1 +} + +type Header struct { + fields []*headerField + helper *mail.Header +} + +func (h *Header) copy() *Header { + h2 := Header{} + h2.fields = make([]*headerField, len(h.fields)) + for i, f := range h.fields { + c := *f + h2.fields[i] = &c + } + return &h2 +} + +func (h *Header) addRaw(key string, raw string) { + h.fields = append(h.fields, &headerField{len(h.fields), textproto.CanonicalMIMEHeaderKey(key), raw}) +} + +func (h *Header) Add(key string, value string) { + h.fields = append(h.fields, &headerField{-1, textproto.CanonicalMIMEHeaderKey(key), getRaw(key, value)}) +} + +func (h *Header) Get(key string) string { + canonicalKey := textproto.CanonicalMIMEHeaderKey(key) + for _, f := range h.fields { + if f.canonicalKey == canonicalKey { + return f.value() + } + } + return "" +} + +func (h *Header) GetText(key string) (string, error) { + if h.helper == nil { + h.helper = newHelper() + } + canonicalKey := textproto.CanonicalMIMEHeaderKey(key) + for _, f := range h.fields { + if f.canonicalKey == canonicalKey { + h.helper.Set(helperKey, f.value()) + return h.helper.Text(helperKey) + } + } + return "", nil +} + +func (h *Header) GetAddressList(key string) ([]*mail.Address, error) { + if h.helper == nil { + h.helper = newHelper() + } + canonicalKey := textproto.CanonicalMIMEHeaderKey(key) + for _, f := range h.fields { + if f.canonicalKey == canonicalKey { + h.helper.Set(helperKey, f.value()) + return h.helper.AddressList(helperKey) + } + } + return []*mail.Address{}, nil +} + +func (h *Header) Set(key string, value string) { + canonicalKey := textproto.CanonicalMIMEHeaderKey(key) + for i := range h.fields { + if h.fields[i].canonicalKey == canonicalKey { + h.fields[i] = &headerField{ + index: h.fields[i].index, + canonicalKey: canonicalKey, + raw: getRaw(h.fields[i].key(), value), + } + return + } + } + if value != "" { + h.Add(key, value) + } +} + +func (h *Header) SetText(key string, value string) { + if h.helper == nil { + h.helper = newHelper() + } + h.helper.SetText(helperKey, value) + h.Set(key, h.helper.Get(helperKey)) +} + +func (h *Header) SetAddressList(key string, addresses []*mail.Address) { + if h.helper == nil { + h.helper = newHelper() + } + h.helper.SetAddressList(helperKey, addresses) + h.Set(key, h.helper.Get(helperKey)) +} + +func (h *Header) Subject() (string, error) { + return h.GetText("Subject") +} + +func (h *Header) SetSubject(value string) { + h.SetText("Subject", value) +} + +func (h *Header) Date() (time.Time, error) { + return netMail.ParseDate(h.Get("Date")) +} + +// SetDate sets the Date header to the value. +// The zero value of [time.Time] as valid. This will delete the Date header when it exists. +func (h *Header) SetDate(value time.Time) { + if value.IsZero() { + h.Set("Date", "") + } else { + h.Set("Date", value.Format(dateLayout)) + } +} + +func (h *Header) Fields() *HeaderFields { + return &HeaderFields{ + cursor: -1, + skip: 0, + h: h, + helper: newHelper(), + } +} + +func (h *Header) Reader() io.Reader { + const crlf = "\r\n" + readers := make([]io.Reader, 0, len(h.fields)*2+1) + for _, f := range h.fields { + if !f.deleted() { // skip deleted + readers = append(readers, strings.NewReader(f.raw)) + readers = append(readers, strings.NewReader(crlf)) + } + } + readers = append(readers, strings.NewReader(crlf)) + return io.MultiReader(readers...) +} + +type HeaderFields struct { + cursor int + skip int + h *Header + helper *mail.Header +} + +func (f *HeaderFields) Next() bool { + f.cursor += f.skip // skip the InsertAfter headers + f.skip = 0 + f.cursor += 1 + return f.cursor < len(f.h.fields) +} + +// Len returns the number of fields in the header. +// This also includes deleted headers fields. +// Initially no fields are deleted so Len returns the actual number of header fields. +func (f *HeaderFields) Len() int { + return len(f.h.fields) +} + +func (f *HeaderFields) index() int { + if f.cursor < 0 || f.cursor >= len(f.h.fields) { + panic("index called before call to Next() or after Next() returned false") + } + return f.cursor +} + +func (f *HeaderFields) raw() string { + return f.h.fields[f.index()].raw +} + +func (f *HeaderFields) Key() string { + return f.h.fields[f.index()].key() +} + +func (f *HeaderFields) CanonicalKey() string { + return f.h.fields[f.index()].canonicalKey +} + +// IsDeleted returns true when a previous header modification deleted this header. +// You can "undelete" the header by just calling [HeaderFields.Set] with a non-empty value. +func (f *HeaderFields) IsDeleted() bool { + return f.h.fields[f.index()].deleted() +} + +func (f *HeaderFields) Get() string { + return f.h.fields[f.index()].value() +} + +func (f *HeaderFields) GetText() (string, error) { + f.helper.Set(helperKey, f.Get()) + return f.helper.Text(helperKey) +} + +func (f *HeaderFields) GetAddressList() ([]*mail.Address, error) { + f.helper.Set(helperKey, f.Get()) + return f.helper.AddressList(helperKey) +} + +func getRaw(key string, value string) string { + if len(value) > 0 && !(value[0] == ' ' || value[0] == '\t') { + return key + ": " + value + } else { + return key + ":" + value + } +} + +func (f *HeaderFields) Set(value string) { + idx := f.index() + f.h.fields[idx] = &headerField{f.h.fields[idx].index, f.CanonicalKey(), getRaw(f.Key(), value)} +} + +func (f *HeaderFields) text(value string) string { + f.helper.SetText(helperKey, value) + return f.helper.Get(helperKey) +} + +func (f *HeaderFields) SetText(value string) { + f.Set(f.text(value)) +} + +func (f *HeaderFields) addressList(value []*mail.Address) string { + f.helper.SetAddressList(helperKey, value) + return f.helper.Get(helperKey) +} + +func (f *HeaderFields) SetAddressList(value []*mail.Address) { + f.Set(f.addressList(value)) +} + +func (f *HeaderFields) Del() { + f.Set("") +} + +func (f *HeaderFields) Replace(key string, value string) { + idx := f.index() + f.h.fields[idx] = &headerField{f.h.fields[idx].index, textproto.CanonicalMIMEHeaderKey(key), getRaw(key, value)} +} + +func (f *HeaderFields) ReplaceText(key string, value string) { + f.Replace(key, f.text(value)) +} + +func (f *HeaderFields) ReplaceAddressList(key string, value []*mail.Address) { + f.Replace(key, f.addressList(value)) +} + +func (f *HeaderFields) insert(index int, key string, value string) { + tail := make([]*headerField, 1, 1+len(f.h.fields)-index) + tail[0] = &headerField{-1, textproto.CanonicalMIMEHeaderKey(key), getRaw(key, value)} + tail = append(tail, f.h.fields[index:]...) + f.h.fields = append(f.h.fields[:index], tail...) +} + +func (f *HeaderFields) InsertBefore(key string, value string) { + f.insert(f.index(), key, value) + f.cursor += 1 +} + +func (f *HeaderFields) InsertTextBefore(key string, value string) { + f.InsertBefore(key, f.text(value)) +} + +func (f *HeaderFields) InsertAddressListBefore(key string, value []*mail.Address) { + f.InsertBefore(key, f.addressList(value)) +} + +func (f *HeaderFields) InsertAfter(key string, value string) { + f.skip += 1 + f.insert(f.index()+f.skip, key, value) +} + +func (f *HeaderFields) InsertTextAfter(key string, value string) { + f.InsertAfter(key, f.text(value)) +} + +func (f *HeaderFields) InsertAddressListAfter(key string, value []*mail.Address) { + f.InsertAfter(key, f.addressList(value)) +} + +const ( + kindEqual = iota + kindChange + kindInsert +) + +type headerFieldDiff struct { + kind int + field *headerField + index int +} + +func diffHeaderFieldsMiddle(orig []*headerField, changed []*headerField, index int) (diffs []headerFieldDiff) { + // either orig and changed are empty or the first element is different + origLen, changedLen := len(orig), len(changed) + changedI := 0 + switch { + case origLen == 0 && changedLen == 0: + return nil + case origLen == 0: + // orig empty -> everything must be inserts + for _, c := range changed { + diffs = append(diffs, headerFieldDiff{kindInsert, c, index}) + } + return + case changedLen == 0: + // This should not happen since we do not delete headerField entries + // but if the user completely replaces the headers it could indeed happen. + // Panic in this case so the programming error surfaces. + panic("internal structure error: do not completely replace transaction.Headers – use its methods to alter it") + default: // origLen > 0 && changedLen > 0 + o := orig[0] + if o.index < 0 { + panic("internal structure error: all elements in orig need to have an index bigger than -1: do not completely replace transaction.Headers – use its methods to alter it") + } + found := false + // find o in changed + for i, c := range changed { + if c.index == o.index { + found = true + index = o.index + changedI = i + for i = 0; i < changedI; i++ { + diffs = append(diffs, headerFieldDiff{kindInsert, changed[i], index - 1}) + } + if changed[changedI].raw == o.raw { + diffs = append(diffs, headerFieldDiff{kindEqual, o, o.index}) + } else if changed[changedI].key() == o.key() { + diffs = append(diffs, headerFieldDiff{kindChange, changed[changedI], o.index}) + } else { + // a HeaderFields.Replace call, delete the original + diffs = append(diffs, headerFieldDiff{ + kind: kindChange, + field: &headerField{ + index: o.index, + canonicalKey: o.canonicalKey, + raw: o.key() + ":", + }, + index: o.index, + }) + // insert changed in front of deleted header + diffs = append(diffs, headerFieldDiff{kindInsert, &headerField{ + index: -1, + canonicalKey: changed[changedI].canonicalKey, + raw: changed[changedI].raw, + }, index}) + index-- // in this special case we actually do not need to increase the index below + } + changedI++ + break + } else if c.index > o.index { + break + } + } + // if o not in changed we need to delete it + if !found { + diffs = append(diffs, headerFieldDiff{ + kind: kindChange, + field: &headerField{ + index: o.index, + canonicalKey: o.canonicalKey, + raw: o.key() + ":", + }, + index: o.index, + }) + } + // we only consumed the first element of orig + index++ + restDiffs := diffHeaderFields(orig[1:], changed[changedI:], index) + if len(restDiffs) > 0 { + diffs = append(diffs, restDiffs...) + } + return + } +} + +func diffHeaderFields(orig []*headerField, changed []*headerField, index int) (diffs []headerFieldDiff) { + origLen, changedLen := len(orig), len(changed) + // find common prefix + commonPrefixLen, commonSuffixLen := 0, 0 + for i := 0; i < origLen && i < changedLen; i++ { + if orig[i].raw != changed[i].raw || orig[i].index != changed[i].index { + break + } + commonPrefixLen += 1 + index = orig[i].index + } + // find common suffix (down to the commonPrefixLen element) + i, j := origLen-1, changedLen-1 + for i > commonPrefixLen-1 && j > commonPrefixLen-1 { + if orig[i].raw != changed[j].raw || orig[i].index != changed[j].index { + break + } + commonSuffixLen += 1 + i-- + j-- + } + for i := 0; i < commonPrefixLen; i++ { + diffs = append(diffs, headerFieldDiff{kindEqual, orig[i], orig[i].index}) + } + // find the changed parts, recursively calls diffHeaderFields afterwards + middleDiffs := diffHeaderFieldsMiddle(orig[commonPrefixLen:origLen-commonSuffixLen], changed[commonPrefixLen:changedLen-commonSuffixLen], index) + if len(middleDiffs) > 0 { + diffs = append(diffs, middleDiffs...) + } + for i := origLen - commonSuffixLen; i < origLen; i++ { + diffs = append(diffs, headerFieldDiff{kindEqual, orig[i], orig[i].index}) + } + return +} + +type headerOp struct { + Index int + Name string + Value string +} + +// calculateHeaderModifications finds differences between orig and changed. +// The differences are expressed as change and insert operations – to be mapped to milter modification actions. +// Deletions are changes to an empty value. +func calculateHeaderModifications(orig *Header, changed *Header) (changeOps []headerOp, insertOps []headerOp) { + origFields := orig.Fields() + origLen := origFields.Len() + origIndexByKeyCounter := make(map[string]int) + origIndexByKey := make([]int, origLen) + for i := 0; origFields.Next(); i++ { + origIndexByKeyCounter[origFields.CanonicalKey()] += 1 + origIndexByKey[i] = origIndexByKeyCounter[origFields.CanonicalKey()] + } + diffs := diffHeaderFields(orig.fields, changed.fields, -1) + for _, diff := range diffs { + switch diff.kind { + case kindInsert: + insertOps = append(insertOps, headerOp{ + Index: diff.index + 1, + Name: diff.field.key(), + Value: diff.field.value(), + }) + case kindChange: + if diff.index < origLen { + changeOps = append(changeOps, headerOp{ + Index: origIndexByKey[diff.index], + Name: diff.field.key(), + Value: diff.field.value(), + }) + } else { // should not happen but just make inserts out of it + insertOps = append(insertOps, headerOp{ + Index: diff.index + 1, + Name: diff.field.key(), + Value: diff.field.value(), + }) + } + } + } + + return +} diff --git a/mailfilter/header_test.go b/mailfilter/header_test.go new file mode 100644 index 0000000..5bbb85c --- /dev/null +++ b/mailfilter/header_test.go @@ -0,0 +1,1232 @@ +package mailfilter + +import ( + "fmt" + "io" + "reflect" + "strings" + "testing" + "time" + + "github.com/emersion/go-message/mail" +) + +func testHeader() *Header { + return &Header{fields: []*headerField{ + { + index: 0, + canonicalKey: "From", + raw: "From: ", + }, + { + index: 1, + canonicalKey: "To", + raw: "To: , ", + }, + { + index: 2, + canonicalKey: "Subject", + raw: "subject: =?UTF-8?Q?=F0=9F=9F=A2?=", // 🟢 + }, + { + index: 3, + canonicalKey: "Date", + raw: "DATE:\tWed, 01 Mar 2023 15:47:33 +0100", + }, + }} +} + +var root, nobody = mail.Address{ + Name: "", + Address: "root@localhost", +}, mail.Address{ + Name: "", + Address: "nobody@localhost", +} + +func TestHeaderFields_CanonicalKey(t *testing.T) { + type fields struct { + cursor int + h *Header + } + tests := []struct { + name string + fields fields + want string + }{ + {"From", fields{0, testHeader()}, "From"}, + {"To", fields{1, testHeader()}, "To"}, + {"Subject", fields{2, testHeader()}, "Subject"}, + {"Date", fields{3, testHeader()}, "Date"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper(), + } + if got := f.CanonicalKey(); got != tt.want { + t.Errorf("CanonicalKey() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeaderFields_Del(t *testing.T) { + type fields struct { + cursor int + h *Header + } + tests := []struct { + name string + fields fields + want *headerField + }{ + {"First", fields{0, testHeader()}, &headerField{0, "From", "From:"}}, + {"Third", fields{2, testHeader()}, &headerField{2, "Subject", "subject:"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper()} + f.Del() + got := f.h.fields[f.index()] + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Del() = %+v, want %+v", got, tt.want) + } + }) + } +} + +func TestHeaderFields_Get(t *testing.T) { + type fields struct { + cursor int + h *Header + } + tests := []struct { + name string + fields fields + want string + }{ + {"From", fields{0, testHeader()}, " "}, + {"To", fields{1, testHeader()}, " , "}, + {"Subject", fields{2, testHeader()}, " =?UTF-8?Q?=F0=9F=9F=A2?="}, + {"Date", fields{3, testHeader()}, "\tWed, 01 Mar 2023 15:47:33 +0100"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper(), + } + if got := f.Get(); got != tt.want { + t.Errorf("Get() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeaderFields_GetAddressList(t *testing.T) { + type fields struct { + cursor int + h *Header + } + tests := []struct { + name string + fields fields + want []*mail.Address + wantErr bool + }{ + {"From", fields{0, testHeader()}, []*mail.Address{&root}, false}, + {"To", fields{1, testHeader()}, []*mail.Address{&root, &nobody}, false}, + {"Subject", fields{2, testHeader()}, nil, true}, + {"Date", fields{3, testHeader()}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper(), + } + got, err := f.GetAddressList() + if (err != nil) != tt.wantErr { + t.Errorf("GetAddressList() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetAddressList() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeaderFields_GetText(t *testing.T) { + type fields struct { + cursor int + h *Header + } + tests := []struct { + name string + fields fields + want string + wantErr bool + }{ + {"From", fields{0, testHeader()}, " ", false}, + {"To", fields{1, testHeader()}, " , ", false}, + {"Subject", fields{2, testHeader()}, " 🟢", false}, + {"Date", fields{3, testHeader()}, "\tWed, 01 Mar 2023 15:47:33 +0100", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper(), + } + got, err := f.GetText() + if (err != nil) != tt.wantErr { + t.Errorf("GetText() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("GetText() got = %v, want %v", got, tt.want) + } + }) + } +} + +func outputFields(fields []*headerField) string { + h := Header{fields: fields} + bytes, _ := io.ReadAll(h.Reader()) + return string(bytes) +} + +func TestHeaderFields_InsertAfter(t *testing.T) { + type fields struct { + cursor int + h *Header + } + type args struct { + key string + value string + } + addOne := []args{{"Test", "one"}} + expectOne := []*headerField{{-1, "Test", "Test: one"}} + addTwo := []args{{"Test", "one"}, {"Test", "two"}} + expectTwo := []*headerField{{-1, "Test", "Test: one"}, {-1, "Test", "Test: two"}} + tests := []struct { + name string + fields fields + args []args + want []*headerField + wantSkip int + wantNext bool + }{ + {"From", fields{0, testHeader()}, addOne, append(testHeader().fields[:1], append(expectOne, testHeader().fields[1:]...)...), 1, true}, + {"To", fields{1, testHeader()}, addOne, append(testHeader().fields[:2], append(expectOne, testHeader().fields[2:]...)...), 1, true}, + {"Subject", fields{2, testHeader()}, addOne, append(testHeader().fields[:3], append(expectOne, testHeader().fields[3:]...)...), 1, true}, + {"Date", fields{3, testHeader()}, addOne, append(testHeader().fields[:4], append(expectOne, testHeader().fields[4:]...)...), 1, false}, + {"From2", fields{0, testHeader()}, addTwo, append(testHeader().fields[:1], append(expectTwo, testHeader().fields[1:]...)...), 2, true}, + {"Date2", fields{3, testHeader()}, addTwo, append(testHeader().fields[:4], append(expectTwo, testHeader().fields[4:]...)...), 2, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper(), + } + for _, arg := range tt.args { + f.InsertAfter(arg.key, arg.value) + } + got := tt.fields.h.fields + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("InsertAfter() = %q, want %q", outputFields(got), outputFields(tt.want)) + } + gotSkip := f.skip + if gotSkip != tt.wantSkip { + t.Errorf("InsertAfter() = %v, wantSkip %v", gotSkip, tt.wantSkip) + } + gotNext := f.Next() + if gotNext != tt.wantNext { + t.Errorf("InsertAfter() = %v, want %v", gotNext, tt.wantNext) + } + }) + } +} + +func TestHeaderFields_InsertBefore(t *testing.T) { + type fields struct { + cursor int + h *Header + } + type args struct { + key string + value string + } + addOne := []args{{"Test", "one"}} + expectOne := []*headerField{{-1, "Test", "Test: one"}} + addTwo := []args{{"Test", "one"}, {"Test", "two"}} + expectTwo := []*headerField{{-1, "Test", "Test: one"}, {-1, "Test", "Test: two"}} + tests := []struct { + name string + fields fields + args []args + want []*headerField + wantSkip int + wantNext bool + }{ + {"From", fields{0, testHeader()}, addOne, append(testHeader().fields[:0], append(expectOne, testHeader().fields[0:]...)...), 0, true}, + {"To", fields{1, testHeader()}, addOne, append(testHeader().fields[:1], append(expectOne, testHeader().fields[1:]...)...), 0, true}, + {"Subject", fields{2, testHeader()}, addOne, append(testHeader().fields[:2], append(expectOne, testHeader().fields[2:]...)...), 0, true}, + {"Date", fields{3, testHeader()}, addOne, append(testHeader().fields[:3], append(expectOne, testHeader().fields[3:]...)...), 0, false}, + {"From2", fields{0, testHeader()}, addTwo, append(testHeader().fields[:0], append(expectTwo, testHeader().fields[0:]...)...), 0, true}, + {"Date2", fields{3, testHeader()}, addTwo, append(testHeader().fields[:3], append(expectTwo, testHeader().fields[3:]...)...), 0, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper(), + } + for _, arg := range tt.args { + f.InsertBefore(arg.key, arg.value) + } + got := tt.fields.h.fields + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("InsertBefore() = %q, want %q", outputFields(got), outputFields(tt.want)) + } + gotSkip := f.skip + if gotSkip != tt.wantSkip { + t.Errorf("InsertAfter() = %v, wantSkip %v", gotSkip, tt.wantSkip) + } + gotNext := f.Next() + if gotNext != tt.wantNext { + t.Errorf("InsertAfter() = %v, want %v", gotNext, tt.wantNext) + } + }) + } +} + +func TestHeaderFields_Key(t *testing.T) { + type fields struct { + cursor int + h *Header + } + tests := []struct { + name string + fields fields + want string + }{ + {"From", fields{0, testHeader()}, "From"}, + {"To", fields{1, testHeader()}, "To"}, + {"Subject", fields{2, testHeader()}, "subject"}, + {"Date", fields{3, testHeader()}, "DATE"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper(), + } + if got := f.Key(); got != tt.want { + t.Errorf("Key() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeaderFields_Len(t *testing.T) { + type fields struct { + cursor int + h *Header + } + tests := []struct { + name string + fields fields + want int + }{ + {"works", fields{0, testHeader()}, 4}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper(), + } + if got := f.Len(); got != tt.want { + t.Errorf("Len() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeaderFields_Next(t *testing.T) { + type fields struct { + cursor int + skip int + h *Header + } + tests := []struct { + name string + fields fields + want bool + }{ + {"From", fields{0, 0, testHeader()}, true}, + {"To", fields{1, 0, testHeader()}, true}, + {"Subject", fields{2, 0, testHeader()}, true}, + {"Date", fields{3, 0, testHeader()}, false}, + {"From1", fields{0, 1, testHeader()}, true}, + {"To1", fields{1, 1, testHeader()}, true}, + {"Subject1", fields{2, 1, testHeader()}, false}, + {"Date1", fields{3, 1, testHeader()}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + skip: tt.fields.skip, + h: tt.fields.h, + helper: newHelper(), + } + if got := f.Next(); got != tt.want { + t.Errorf("Next() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeaderFields_Replace(t *testing.T) { + type fields struct { + cursor int + h *Header + } + type args struct { + key string + value string + } + tests := []struct { + name string + fields fields + args args + want []*headerField + }{ + {"works", fields{0, testHeader()}, args{"new", "header"}, append([]*headerField{{0, "New", "new: header"}}, testHeader().fields[1:]...)}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper(), + } + f.Replace(tt.args.key, tt.args.value) + got := tt.fields.h.fields + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Replace() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeaderFields_Set(t *testing.T) { + type fields struct { + cursor int + h *Header + } + type args struct { + value string + } + tests := []struct { + name string + fields fields + args args + want *headerField + }{ + {"First", fields{0, testHeader()}, args{"set"}, &headerField{0, "From", "From: set"}}, + {"Third", fields{2, testHeader()}, args{"\tset"}, &headerField{2, "Subject", "subject:\tset"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper(), + } + f.Set(tt.args.value) + got := f.h.fields[f.index()] + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Set() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeaderFields_SetAddressList(t *testing.T) { + type fields struct { + cursor int + h *Header + } + type args struct { + value []*mail.Address + } + tests := []struct { + name string + fields fields + args args + want *headerField + }{ + {"One", fields{0, testHeader()}, args{[]*mail.Address{&nobody}}, &headerField{0, "From", "From: "}}, + {"Two", fields{1, testHeader()}, args{[]*mail.Address{&nobody, &root}}, &headerField{1, "To", "To: , "}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper(), + } + f.SetAddressList(tt.args.value) + got := f.h.fields[f.index()] + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SetAddressList() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeaderFields_SetText(t *testing.T) { + type fields struct { + cursor int + h *Header + } + type args struct { + value string + } + tests := []struct { + name string + fields fields + args args + want *headerField + }{ + {"Set", fields{0, testHeader()}, args{"set"}, &headerField{0, "From", "From: set"}}, + {"UTF-8", fields{2, testHeader()}, args{"🔴"}, &headerField{2, "Subject", "subject: =?utf-8?q?=F0=9F=94=B4?="}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &HeaderFields{ + cursor: tt.fields.cursor, + h: tt.fields.h, + helper: newHelper(), + } + f.SetText(tt.args.value) + got := f.h.fields[f.index()] + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SetText() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeader_Add(t *testing.T) { + type args struct { + key string + value string + } + tests := []struct { + name string + fields []*headerField + args args + want []*headerField + }{ + {"works", testHeader().fields, args{"key", "value"}, append(testHeader().fields, &headerField{-1, "Key", "key: value"})}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + h.Add(tt.args.key, tt.args.value) + got := h.fields + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Add() = %q, want %q", outputFields(got), outputFields(tt.want)) + } + }) + } +} + +func TestHeader_Date(t *testing.T) { + brokenDate := testHeader() + brokenDate.fields[3].raw = "Date: broken" + tests := []struct { + name string + fields []*headerField + want time.Time + wantErr bool + }{ + {"Date", testHeader().fields, time.Date(2023, time.March, 1, 15, 47, 33, 0, time.FixedZone("CET", 60*60)), false}, + {"Broken", brokenDate.fields, time.Time{}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + got, err := h.Date() + if (err != nil) != tt.wantErr { + t.Errorf("Date() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !got.Equal(tt.want) { + t.Errorf("Date() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeader_Fields(t *testing.T) { + h := testHeader() + fields := h.Fields() + for fields.Next() { + if fields.CanonicalKey() == "From" { + fields.InsertBefore("X-From1", fields.Get()) + fields.InsertBefore("X-From2", fields.Get()) + if fields.CanonicalKey() != "From" { + t.Error("InsertBefore changed cursor") + } + } + + if fields.CanonicalKey() == "To" { + fields.InsertAfter("X-After1", "value1") + fields.InsertAfter("X-After2", "value2") + if fields.CanonicalKey() != "To" { + t.Error("InsertAfter changed cursor") + } + } + + if fields.CanonicalKey() == "Subject" { + fields.Del() + } + + switch fields.CanonicalKey() { + case "X-From1", "X-From2", "X-After1", "X-After2": + t.Error("iterated over inserted key", fields.CanonicalKey()) + } + } + + if fields.Next() { + t.Error("Next() call should return false") + } + + b, _ := io.ReadAll(h.Reader()) + got := string(b) + expect := "X-From1: \r\nX-From2: \r\nFrom: \r\nTo: , \r\nX-After1: value1\r\nX-After2: value2\r\nDATE:\tWed, 01 Mar 2023 15:47:33 +0100\r\n\r\n" + if got != expect { + t.Errorf("got %q, expect %q", got, expect) + } + + fields2 := h.Fields() + for fields2.Next() { + if fields2.CanonicalKey() == "X-From1" { + fields2.SetAddressList([]*mail.Address{&nobody}) + fields2.ReplaceAddressList("X-To", []*mail.Address{&nobody, &root}) + fields2.ReplaceText("X-Text", "🟡") + } + if fields2.CanonicalKey() == "Subject" { + if !fields2.IsDeleted() { + t.Error("Subject should be deleted") + } + // before/after order scrambled on purpose + fields2.InsertTextBefore("X-Before1", "🟡") + fields2.InsertTextAfter("X-After1", "🟡") + fields2.InsertAddressListAfter("X-After2", []*mail.Address{&nobody}) + fields2.InsertAddressListBefore("X-Before2", []*mail.Address{&nobody}) + } + switch fields2.CanonicalKey() { + case "X-From2", "X-After1", "X-After2": + fields2.Del() + } + } + + b, _ = io.ReadAll(h.Reader()) + got = string(b) + expect = "X-Text: =?utf-8?q?=F0=9F=9F=A1?=\r\nFrom: \r\nTo: , \r\nX-Before1: =?utf-8?q?=F0=9F=9F=A1?=\r\nX-Before2: \r\nX-After1: =?utf-8?q?=F0=9F=9F=A1?=\r\nX-After2: \r\nDATE:\tWed, 01 Mar 2023 15:47:33 +0100\r\n\r\n" + if got != expect { + t.Errorf("got %q, expect %q", got, expect) + } +} + +func TestHeader_Get(t *testing.T) { + type args struct { + key string + } + tests := []struct { + name string + fields []*headerField + args args + want string + }{ + {"works", testHeader().fields, args{"fRoM"}, " "}, + {"not found", testHeader().fields, args{"not-there"}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + if got := h.Get(tt.args.key); got != tt.want { + t.Errorf("Get() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeader_GetAddressList(t *testing.T) { + type args struct { + key string + } + tests := []struct { + name string + fields []*headerField + args args + want []*mail.Address + wantErr bool + }{ + {"From", testHeader().fields, args{"fRoM"}, []*mail.Address{&root}, false}, + {"To", testHeader().fields, args{"tO"}, []*mail.Address{&root, &nobody}, false}, + {"Subject", testHeader().fields, args{"SUBJECT"}, nil, true}, + {"Date", testHeader().fields, args{"Date"}, nil, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + got, err := h.GetAddressList(tt.args.key) + if (err != nil) != tt.wantErr { + t.Errorf("GetAddressList() error = %v, wantErr %v", err, tt.wantErr) + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("GetAddressList() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeader_GetText(t *testing.T) { + brokenSubject := testHeader() + brokenSubject.fields[2].raw = "Subject: =?e-404?Q?=F0=9F=9F=A2?=" + type args struct { + key string + } + tests := []struct { + name string + fields []*headerField + args args + want string + wantErr bool + }{ + {"works", testHeader().fields, args{"subJeCt"}, " 🟢", false}, + {"broken", brokenSubject.fields, args{"subJeCt"}, " =?e-404?Q?=F0=9F=9F=A2?=", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + got, err := h.GetText(tt.args.key) + if (err != nil) != tt.wantErr { + t.Errorf("GetText() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("GetText() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeader_Reader(t *testing.T) { + tests := []struct { + name string + fields []*headerField + want string + }{ + {"works", testHeader().fields, "From: \r\nTo: , \r\nsubject: =?UTF-8?Q?=F0=9F=9F=A2?=\r\nDATE:\tWed, 01 Mar 2023 15:47:33 +0100\r\n\r\n"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + b, err := io.ReadAll(h.Reader()) + if err != nil { + t.Fatal(err) + } + got := string(b) + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Reader() = %q, want %q", got, tt.want) + } + }) + } +} + +func TestHeader_Set(t *testing.T) { + type args struct { + key string + value string + } + tests := []struct { + name string + fields []*headerField + args args + want []*headerField + }{ + {"found", testHeader().fields, args{"suBJect", "value"}, append(testHeader().fields[:2], append([]*headerField{{2, "Subject", "subject: value"}}, testHeader().fields[3:]...)...)}, + {"not-found", testHeader().fields, args{"x-spam", "value"}, append(testHeader().fields, &headerField{-1, "X-Spam", "x-spam: value"})}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + h.Set(tt.args.key, tt.args.value) + got := h.fields + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Set() = %q, want %q", outputFields(got), outputFields(tt.want)) + } + }) + } +} + +func TestHeader_SetAddressList(t *testing.T) { + type args struct { + key string + addresses []*mail.Address + } + tests := []struct { + name string + fields []*headerField + args args + want []*headerField + }{ + {"works", testHeader().fields, args{"x-to", []*mail.Address{&root}}, append(testHeader().fields, &headerField{-1, "X-To", "x-to: "})}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + h.SetAddressList(tt.args.key, tt.args.addresses) + got := h.fields + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SetAddressList() = %q, want %q", outputFields(got), outputFields(tt.want)) + } + }) + } +} + +func TestHeader_SetDate(t *testing.T) { + type args struct { + value time.Time + } + tests := []struct { + name string + fields []*headerField + args args + want []*headerField + }{ + {"works", testHeader().fields, args{time.Date(1980, time.January, 1, 12, 0, 0, 0, time.UTC)}, append(testHeader().fields[:3], &headerField{3, "Date", "DATE: Tue, 01 Jan 1980 12:00:00 +0000"})}, + {"zero-ok", testHeader().fields, args{time.Time{}}, append(testHeader().fields[:3], &headerField{3, "Date", "DATE:"})}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + h.SetDate(tt.args.value) + got := h.fields + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SetDate() = %q, want %q", outputFields(got), outputFields(tt.want)) + } + }) + } +} + +func TestHeader_SetSubject(t *testing.T) { + type args struct { + value string + } + tests := []struct { + name string + fields []*headerField + args args + want []*headerField + }{ + {"works", testHeader().fields, args{"set"}, append(testHeader().fields[:2], &headerField{2, "Subject", "subject: set"}, testHeader().fields[3])}, + {"zero-ok", testHeader().fields, args{""}, append(testHeader().fields[:2], &headerField{2, "Subject", "subject:"}, testHeader().fields[3])}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + h.SetSubject(tt.args.value) + got := h.fields + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SetSubject() = %q, want %q", outputFields(got), outputFields(tt.want)) + } + }) + } +} + +func TestHeader_SetText(t *testing.T) { + type args struct { + key string + value string + } + tests := []struct { + name string + fields []*headerField + args args + want []*headerField + }{ + {"works", testHeader().fields, args{"SubJect", "set"}, append(testHeader().fields[:2], &headerField{2, "Subject", "subject: set"}, testHeader().fields[3])}, + {"zero-ok", testHeader().fields, args{"Subject", ""}, append(testHeader().fields[:2], &headerField{2, "Subject", "subject:"}, testHeader().fields[3])}, + {"add", testHeader().fields, args{"x-red", "🔴"}, append(testHeader().fields, &headerField{-1, "X-Red", "x-red: =?utf-8?q?=F0=9F=94=B4?="})}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + h.SetText(tt.args.key, tt.args.value) + got := h.fields + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("SetText() = %q, want %q", outputFields(got), outputFields(tt.want)) + } + }) + } +} + +func TestHeader_Subject(t *testing.T) { + brokenSubject := testHeader() + brokenSubject.fields[2].raw = "Subject: =?e-404?Q?=F0=9F=9F=A2?=" + tests := []struct { + name string + fields []*headerField + want string + wantErr bool + }{ + {"works", testHeader().fields, " 🟢", false}, + {"broken", brokenSubject.fields, " =?e-404?Q?=F0=9F=9F=A2?=", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + got, err := h.Subject() + if (err != nil) != tt.wantErr { + t.Errorf("Subject() error = %v, wantErr %v", err, tt.wantErr) + return + } + if got != tt.want { + t.Errorf("Subject() got = %v, want %v", got, tt.want) + } + }) + } +} + +func TestHeader_addRaw(t *testing.T) { + type args struct { + key string + raw string + } + tests := []struct { + name string + fields []*headerField + args args + want []*headerField + }{ + {"works", nil, args{key: "TEST", raw: "TEST: value"}, []*headerField{&headerField{canonicalKey: "Test", raw: "TEST: value"}}}, + {"empty-is-ok", nil, args{key: "TEST", raw: "TEST:"}, []*headerField{&headerField{canonicalKey: "Test", raw: "TEST:"}}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + h := &Header{ + fields: tt.fields, + } + h.addRaw(tt.args.key, tt.args.raw) + }) + } +} + +func TestHeader_copy(t *testing.T) { + h := Header{fields: []*headerField{{0, "Test", "Test:"}}} + h2 := h.copy() + h.fields[0].canonicalKey = "Changed" + if len(h2.fields) != len(h.fields) { + t.Fatal("did not copy fields") + } + if h.fields[0].canonicalKey == h2.fields[0].canonicalKey { + t.Fatal("did not copy deep copy fields") + } +} + +func outputDiff(diff []headerFieldDiff) string { + s := strings.Builder{} + for i, d := range diff { + s.WriteString(fmt.Sprintf("%02d %02d ", i, d.index)) + switch d.kind { + case kindEqual: + s.WriteString("equal ") + case kindInsert: + s.WriteString("insert ") + case kindChange: + s.WriteString("change ") + } + s.WriteString(fmt.Sprintf("(c:%s raw:%q idx:%d)\n", d.field.canonicalKey, d.field.raw, d.field.index)) + } + return s.String() +} + +func Test_diffHeaderFields(t *testing.T) { + orig := testHeader() + addOne := testHeader() + addOne.Add("X-Test", "1") + addOneInFront := testHeader() + fields := addOneInFront.Fields() + fields.Next() + fields.InsertBefore("X-Test", "1") + equals := []headerFieldDiff{ + {kindEqual, orig.fields[0], 0}, + {kindEqual, orig.fields[1], 1}, + {kindEqual, orig.fields[2], 2}, + {kindEqual, orig.fields[3], 3}, + } + complexChanges := testHeader() + fields = complexChanges.Fields() + for fields.Next() { + fields.InsertBefore("X-Test", "1") + fields.InsertAfter("X-Test", "1") + if fields.CanonicalKey() == "Subject" { + fields.Set("changed") + } + if fields.CanonicalKey() == "Date" { + fields.Replace("X-Test", "1") + } + } + xTest := headerField{-1, "X-Test", "X-Test: 1"} + subjectChanged := headerField{2, "Subject", "subject: changed"} + dateDel := headerField{3, "Date", "DATE:"} + + type args struct { + orig []*headerField + changed []*headerField + } + tests := []struct { + name string + args args + wantDiffs []headerFieldDiff + }{ + {"equal", args{orig.fields, orig.fields}, equals}, + {"add-one", args{orig.fields, addOne.fields}, append(equals, headerFieldDiff{kindInsert, &xTest, 3})}, + {"add-one-in-front", args{orig.fields, addOneInFront.fields}, append([]headerFieldDiff{{kindInsert, &xTest, -1}}, equals...)}, + {"complex", args{orig.fields, complexChanges.fields}, []headerFieldDiff{ + {kindInsert, &xTest, -1}, + equals[0], + {kindInsert, &xTest, 0}, + {kindInsert, &xTest, 0}, + equals[1], + {kindInsert, &xTest, 1}, + {kindInsert, &xTest, 1}, + {kindChange, &subjectChanged, 2}, + {kindInsert, &xTest, 2}, + {kindInsert, &xTest, 2}, + {kindChange, &dateDel, 3}, + {kindInsert, &xTest, 3}, + {kindInsert, &xTest, 3}, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotDiffs := diffHeaderFields(tt.args.orig, tt.args.changed, -1); !reflect.DeepEqual(gotDiffs, tt.wantDiffs) { + t.Errorf("diffHeaderFields() = %s, want %s", outputDiff(gotDiffs), outputDiff(tt.wantDiffs)) + } + }) + } +} + +func Test_calculateHeaderModifications(t *testing.T) { + orig := testHeader() + addOne := testHeader() + addOne.Add("X-Test", "1") + addOneInFront := testHeader() + fields := addOneInFront.Fields() + fields.Next() + fields.InsertBefore("X-Test", "1") + complexChanges := testHeader() + fields = complexChanges.Fields() + for fields.Next() { + fields.InsertBefore("X-Test", "1") + fields.InsertAfter("X-Test", "1") + if fields.CanonicalKey() == "Subject" { + fields.Set("changed") + } + if fields.CanonicalKey() == "Date" { + fields.Replace("X-Test", "1") + } + } + type args struct { + orig *Header + changed *Header + } + tests := []struct { + name string + args args + wantChangeOps []headerOp + wantInsertOps []headerOp + }{ + {"equal", args{orig, orig}, nil, nil}, + {"add-one", args{orig, addOne}, nil, []headerOp{{Index: 4, Name: "X-Test", Value: " 1"}}}, + {"add-one-in-front", args{orig, addOneInFront}, nil, []headerOp{{Index: 0, Name: "X-Test", Value: " 1"}}}, + {"complex", args{orig, complexChanges}, []headerOp{ + {Index: 1, Name: "subject", Value: " changed"}, + {Index: 1, Name: "DATE", Value: ""}, + }, []headerOp{ + {Index: 0, Name: "X-Test", Value: " 1"}, + {Index: 1, Name: "X-Test", Value: " 1"}, + {Index: 1, Name: "X-Test", Value: " 1"}, + {Index: 2, Name: "X-Test", Value: " 1"}, + {Index: 2, Name: "X-Test", Value: " 1"}, + {Index: 3, Name: "X-Test", Value: " 1"}, + {Index: 3, Name: "X-Test", Value: " 1"}, + {Index: 4, Name: "X-Test", Value: " 1"}, + {Index: 4, Name: "X-Test", Value: " 1"}, + }}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + gotChangeOps, gotInsertOps := calculateHeaderModifications(tt.args.orig, tt.args.changed) + if !reflect.DeepEqual(gotChangeOps, tt.wantChangeOps) { + t.Errorf("calculateHeaderModifications() gotChangeOps = %+v, want %+v", gotChangeOps, tt.wantChangeOps) + } + if !reflect.DeepEqual(gotInsertOps, tt.wantInsertOps) { + t.Errorf("calculateHeaderModifications() gotInsertOps = %+v, want %+v", gotInsertOps, tt.wantInsertOps) + } + }) + } +} + +func Test_getRaw(t *testing.T) { + type args struct { + key string + value string + } + tests := []struct { + name string + args args + want string + }{ + {"empty", args{"TO", ""}, "TO:"}, + {"no space", args{"TO", ""}, "TO: "}, + {"space", args{"TO", " "}, "TO: "}, + {"tab", args{"TO", "\t"}, "TO:\t"}, + {"two spaces", args{"TO", " "}, "TO: "}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := getRaw(tt.args.key, tt.args.value); got != tt.want { + t.Errorf("getRaw() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_headerField_deleted(t *testing.T) { + type fields struct { + canonicalKey string + raw string + } + tests := []struct { + name string + fields fields + want bool + }{ + {"deleted", fields{"To", "To:"}, true}, + {"not deleted", fields{"To", "To: "}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &headerField{ + canonicalKey: tt.fields.canonicalKey, + raw: tt.fields.raw, + } + if got := f.deleted(); got != tt.want { + t.Errorf("deleted() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_headerField_key(t *testing.T) { + type fields struct { + canonicalKey string + raw string + } + tests := []struct { + name string + fields fields + want string + }{ + {"same as canonical", fields{"To", "To: "}, "To"}, + {"different", fields{"To", "TO: "}, "TO"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &headerField{ + canonicalKey: tt.fields.canonicalKey, + raw: tt.fields.raw, + } + if got := f.key(); got != tt.want { + t.Errorf("key() = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_headerField_value(t *testing.T) { + type fields struct { + canonicalKey string + raw string + } + tests := []struct { + name string + fields fields + want string + }{ + {"value", fields{"To", "To: "}, " "}, + {"empty", fields{"To", "To:"}, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + f := &headerField{ + canonicalKey: tt.fields.canonicalKey, + raw: tt.fields.raw, + } + if got := f.value(); got != tt.want { + t.Errorf("value() = %v, want %v", got, tt.want) + } + }) + } +} diff --git a/mailfilter/mailfilter.go b/mailfilter/mailfilter.go new file mode 100644 index 0000000..7a39303 --- /dev/null +++ b/mailfilter/mailfilter.go @@ -0,0 +1,144 @@ +// Package mailfilter allows you to write milter filters without boilerplate code +package mailfilter + +import ( + "context" + "fmt" + "log" + "log/syslog" + "net" + "sync" + + "github.com/d--j/go-milter" +) + +type DecisionModificationFunc func(ctx context.Context, transaction *Transaction) (decision Decision, err error) + +type MailFilter struct { + wgDone sync.WaitGroup + socket net.Listener + server *milter.Server +} + +// New creates and starts a new [MailFilter] with a socket listening on network and address. +// decision is the callback that should implement the filter logic. +// opts are optional [Option] function that configure/fine-tune the mail filter. +func New(network, address string, decision DecisionModificationFunc, opts ...Option) (*MailFilter, error) { + resolvedOptions := options{ + decisionAt: DecisionAtEndOfMessage, + errorHandling: TempFailWhenError, + } + + for _, o := range opts { + o(&resolvedOptions) + } + + if resolvedOptions.syslogPrefix != "" { + sysLogger, err := syslog.NewLogger(syslog.LOG_MAIL, 0) + if err != nil { + return nil, err + } + sysLogger.SetPrefix(resolvedOptions.syslogPrefix) + milter.LogWarning = func(format string, v ...interface{}) { + log.Printf(fmt.Sprintf("milter: warning: %s", format), v...) + sysLogger.Printf(format, v...) + } + } + + actions := milter.AllClientSupportedActionMasks + protocols := milter.OptHeaderLeadingSpace | milter.OptNoUnknown + + switch resolvedOptions.decisionAt { + case DecisionAtConnect: + protocols = protocols | milter.OptNoHelo | milter.OptNoMailFrom | milter.OptNoRcptTo | milter.OptNoData | milter.OptNoHeaders | milter.OptNoEOH | milter.OptNoBody + case DecisionAtHelo: + protocols = protocols | milter.OptNoConnReply | milter.OptNoMailFrom | milter.OptNoRcptTo | milter.OptNoData | milter.OptNoHeaders | milter.OptNoEOH | milter.OptNoBody + case DecisionAtMailFrom: + protocols = protocols | milter.OptNoConnReply | milter.OptNoHeloReply | milter.OptNoRcptTo | milter.OptNoData | milter.OptNoHeaders | milter.OptNoEOH | milter.OptNoBody + case DecisionAtData: + protocols = protocols | milter.OptNoConnReply | milter.OptNoHeloReply | milter.OptNoRcptReply | milter.OptNoHeaders | milter.OptNoEOH | milter.OptNoBody + case DecisionAtEndOfHeaders: + protocols = protocols | milter.OptNoConnReply | milter.OptNoHeloReply | milter.OptNoRcptReply | milter.OptNoHeaderReply | milter.OptNoBody + default: + protocols = protocols | milter.OptNoConnReply | milter.OptNoHeloReply | milter.OptNoRcptReply | milter.OptNoHeaderReply | milter.OptNoEOHReply | milter.OptNoBodyReply + } + if resolvedOptions.skipBody { + protocols = protocols | milter.OptNoBody + } + macroStages := make([][]milter.MacroName, 0, 5) + macroStages = append(macroStages, []milter.MacroName{milter.MacroIfName, milter.MacroIfAddr}) // StageConnect + if resolvedOptions.decisionAt > DecisionAtConnect { + // StageHelo + macroStages = append(macroStages, []milter.MacroName{milter.MacroTlsVersion, milter.MacroCipher, milter.MacroCipherBits, milter.MacroCertSubject, milter.MacroCertIssuer}) + } + if resolvedOptions.decisionAt > DecisionAtHelo { // StageMail + macroStages = append(macroStages, []milter.MacroName{milter.MacroMailMailer, milter.MacroAuthAuthen, milter.MacroAuthType}) + } + if resolvedOptions.decisionAt > DecisionAtMailFrom { + macroStages = append(macroStages, []milter.MacroName{milter.MacroRcptMailer}) // StageRcpt + // try two different stages to get the queue ID, normally at the beginning of the DATA command it is already assigned + // but if it is not, try at the end of the message + macroStages = append(macroStages, []milter.MacroName{milter.MacroQueueId}) //StageData + macroStages = append(macroStages, []milter.MacroName{milter.MacroQueueId}) //StageEOM + } + + milterOptions := []milter.Option{ + milter.WithDynamicMilter(func(version uint32, action milter.OptAction, protocol milter.OptProtocol, maxData milter.DataSize) milter.Milter { + return &backend{ + opts: resolvedOptions, + decision: decision, + leadingSpace: protocol&milter.OptHeaderLeadingSpace != 0, + } + }), + milter.WithActions(actions), + milter.WithProtocols(protocols), + } + for i, macros := range macroStages { + milterOptions = append(milterOptions, milter.WithMacroRequest(milter.MacroStage(i), macros)) + } + + // create socket to listen on + socket, err := net.Listen(network, address) + if err != nil { + return nil, err + } + + // create server with assembled options + server := milter.NewServer(milterOptions...) + + f := &MailFilter{ + socket: socket, + server: server, + } + + // start the milter + f.wgDone.Add(1) + go func(socket net.Listener) { + if err := server.Serve(socket); err != nil { + milter.LogWarning("server.Server() error: %s", err) + } + f.wgDone.Done() + }(socket) + + return f, nil +} + +// Addr returns the [net.Addr] of the listening socket of this [MailFilter]. +// This method returns nil when the socket is not set. +func (f *MailFilter) Addr() net.Addr { + if f.socket == nil { + return nil + } + return f.socket.Addr() +} + +// Wait waits for the end of the [MailFilter] server. +func (f *MailFilter) Wait() { + f.wgDone.Wait() + _ = f.server.Close() +} + +// Close stops the [MailFilter] server. +func (f *MailFilter) Close() { + _ = f.server.Close() +} diff --git a/mailfilter/option.go b/mailfilter/option.go new file mode 100644 index 0000000..b1dcdb3 --- /dev/null +++ b/mailfilter/option.go @@ -0,0 +1,79 @@ +package mailfilter + +// DecisionAt defines when the filter decision is made. +type DecisionAt int + +const ( + // The DecisionAtConnect constant makes the mail filter call the decision function after the connect event. + DecisionAtConnect DecisionAt = iota + + // The DecisionAtHelo constant makes the mail filter call the decision function after the HELO/EHLO event. + DecisionAtHelo + + // The DecisionAtMailFrom constant makes the mail filter call the decision function after the MAIL FROM event. + DecisionAtMailFrom + + // The DecisionAtData constant makes the mail filter call the decision function after the DATA event (all RCPT TO were sent). + DecisionAtData + + // The DecisionAtEndOfHeaders constant makes the mail filter call the decision function after the EOH event (all headers were sent). + DecisionAtEndOfHeaders + + // The DecisionAtEndOfMessage constant makes the mail filter call the decision function at the end of the SMTP transaction. + // This is the default. + DecisionAtEndOfMessage +) + +type ErrorHandling int + +const ( + // Error just throws the error + Error ErrorHandling = iota + // AcceptWhenError accepts the transaction despite the error (it gets logged). + AcceptWhenError + // TempFailWhenError temporarily rejects the transaction (and logs the error). + TempFailWhenError + // RejectWhenError rejects the transaction (and logs the error). + RejectWhenError +) + +type options struct { + decisionAt DecisionAt + errorHandling ErrorHandling + skipBody bool + syslogPrefix string +} + +type Option func(opt *options) + +// WithDecisionAt sets the decision point for the [MailFilter]. +// The default is [DecisionAtEndOfMessage]. +func WithDecisionAt(decisionAt DecisionAt) Option { + return func(opt *options) { + opt.decisionAt = decisionAt + } +} + +// WithErrorHandling sets the error handling for the [MailFilter]. +// The default is [TempFailWhenError]. +func WithErrorHandling(errorHandling ErrorHandling) Option { + return func(opt *options) { + opt.errorHandling = errorHandling + } +} + +// WithoutBody configures the [MailFilter] to not request and collect the mail body. +func WithoutBody() Option { + return func(opt *options) { + opt.skipBody = true + } +} + +// WithSyslog enables logging to syslog with a prefix of prefix. +// This is a global option. +// All calls to [github.com/d--j/go-milter.LogWarning] will be also send to the syslog. +func WithSyslog(prefix string) Option { + return func(opt *options) { + opt.syslogPrefix = prefix + } +} diff --git a/mailfilter/transaction.go b/mailfilter/transaction.go new file mode 100644 index 0000000..4a27db6 --- /dev/null +++ b/mailfilter/transaction.go @@ -0,0 +1,308 @@ +package mailfilter + +import ( + "context" + "io" + "os" + + "github.com/d--j/go-milter" +) + +type Connect struct { + Host string // The host name the MTA figured out for the remote client. + Family string // "unknown", "unix", "tcp4" or "tcp6" + Port uint16 // If Family is "tcp4" or "tcp6" the remote port of client connecting to the MTA + Addr string // If Family "unix" the path to the unix socket. If "tcp4" or "tcp6" the IPv4 or IPv6 address of the remote client connecting to the MTA + IfName string // The Name of the network interface the MTA connection was accepted at. Might be empty. + IfAddr string // The IP address of the network interface the MTA connection was accepted at. Might be empty. +} + +type Helo struct { + Name string // The HELO/EHLO hostname the client provided + TlsVersion string // TLSv1.3, TLSv1.2, ... or empty when no STARTTLS was used. Might even be empty when STARTTLS was used (when the MTA does not support the corresponding macro – almost all do). + Cipher string // The Cipher that client and MTA negotiated. + CipherBits string // The bits of the cipher used. E.g. 256. Might be "RSA equivalent" bits for e.g. elliptic curve ciphers. + CertSubject string // If MutualTLS was used for the connection between client and MTA this holds the subject of the validated client certificate. + CertIssuer string // If MutualTLS was used for the connection between client and MTA this holds the subject of the issuer of the client certificate (CA or Sub-CA). +} + +// Transaction can be used to examine the data of the current mail transaction and +// also send changes to the message back to the MTA. +type Transaction struct { + // Connect holds the [Connect] information of this transaction. + Connect Connect + + // Helo holds the [Helo] information of this transaction. + // + // Only populated if [WithDecisionAt] is bigger than [DecisionAtConnect]. + Helo Helo + + // MailFrom holds the [MailFrom] of this transaction. + // You can change this and your changes get send back to the MTA. + // + // Only populated if [WithDecisionAt] is bigger than [DecisionAtHelo]. + MailFrom MailFrom + + // RcptTos holds the [RcptTo] recipient slice of this transaction. + // You can change this and your changes get send back to the MTA. + // + // Only populated if [WithDecisionAt] is bigger than [DecisionAtMailFrom]. + RcptTos []RcptTo + + // QueueId is the queue ID the MTA assigned for this transaction. + // You cannot change this value. + // + // Only populated if [WithDecisionAt] is bigger than [DecisionAtMailFrom]. + QueueId string + + // Headers are the [Header] fields of this message. + // You can use methods of this to change the header fields of the current message. + // + // Do not replace this variable. Always use the modification methods of [Header] and [Header.Fields]. + // The mail filter might panic if you do replace Headers. + // + // Only populated if [WithDecisionAt] is bigger than [DecisionAtData]. + Headers *Header + + hasDecision bool + decision Decision + decisionErr error + headers *Header + body *os.File + mailFrom MailFrom + rcptTos []RcptTo + replacementBody io.Reader +} + +func (t *Transaction) cleanup() { + t.Headers = nil + t.headers = nil + t.RcptTos = nil + t.rcptTos = nil + if t.replacementBody != nil { + if closer, ok := t.replacementBody.(io.Closer); ok { + if err := closer.Close(); err != nil { + milter.LogWarning("error while closing replacement body: %s", err) + } + } + t.replacementBody = nil + } + if t.body != nil { + _ = t.body.Close() + _ = os.Remove(t.body.Name()) + t.body = nil + } +} + +func (t *Transaction) response() *milter.Response { + switch t.decision { + case Accept: + return milter.RespAccept + case TempFail: + return milter.RespTempFail + case Reject: + return milter.RespReject + case Discard: + return milter.RespDiscard + default: + resp, err := milter.RejectWithCodeAndReason(t.decision.getCode(), t.decision.getReason()) + if err != nil { + milter.LogWarning("milter: reject with custom reason failed, temp-fail instead: %s", err) + return milter.RespTempFail + } + return resp + } +} + +func (t *Transaction) makeDecision(ctx context.Context, decide DecisionModificationFunc) { + if t.hasDecision { + panic("calling makeDecision on a Transaction that already has made a decision") + } + // make copies of data that user can change + t.MailFrom = t.mailFrom + t.RcptTos = make([]RcptTo, len(t.rcptTos)) + for i, r := range t.rcptTos { + t.RcptTos[i] = r + } + if t.headers != nil { + t.Headers = t.headers.copy() + } else { + t.headers = &Header{} + t.Headers = &Header{} + } + // call the decider + d, err := decide(ctx, t) + // save decision + t.hasDecision = true + t.decision = d + t.decisionErr = err +} + +// hasModifications checks quickly if there are any modifications - it does not actually compute them +func (t *Transaction) hasModifications() bool { + if !t.hasDecision { + return false + } + if t.mailFrom.Addr != t.MailFrom.Addr || t.mailFrom.Args != t.MailFrom.Args { + return true + } + if t.replacementBody != nil { + return true + } + if len(t.rcptTos) != len(t.RcptTos) { + return true + } + for i, r := range t.rcptTos { // might give false positives because order does not matter + if r.Addr != t.RcptTos[i].Addr || r.Args != t.RcptTos[i].Args { + return true + } + } + origFields := t.headers.Fields() + changedFields := t.Headers.Fields() + if origFields.Len() != changedFields.Len() { + return true + } + for origFields.Next() && changedFields.Next() { + if origFields.raw() != changedFields.raw() { + return true + } + } + return false +} + +func (t *Transaction) sendModifications(m *milter.Modifier) error { + if t.mailFrom.Addr != t.MailFrom.Addr || t.mailFrom.Args != t.MailFrom.Args { + if err := m.ChangeFrom(t.MailFrom.Addr, t.MailFrom.Args); err != nil { + return err + } + } + deletions, additions := calculateRcptToDiff(t.rcptTos, t.RcptTos) + for _, r := range deletions { + if err := m.DeleteRecipient(r.Addr); err != nil { + return err + } + } + for _, r := range additions { + if err := m.AddRecipient(r.Addr, r.Args); err != nil { + return err + } + } + changeOps, insertOps := calculateHeaderModifications(t.headers, t.Headers) + for _, op := range changeOps { + if err := m.ChangeHeader(op.Index, op.Name, op.Value); err != nil { + return err + } + } + // apply insert operations in reverse for the indexes to be correct + if len(insertOps) > 0 { + for i := len(insertOps) - 1; i > -1; i-- { + op := insertOps[i] + if err := m.InsertHeader(op.Index, op.Name, op.Value); err != nil { + return err + } + } + } + if t.replacementBody != nil { + defer func() { + if closer, ok := t.replacementBody.(io.Closer); ok { + if err := closer.Close(); err != nil { + milter.LogWarning("error while closing replacement body: %s", err) + } + } + t.replacementBody = nil + }() + if err := m.ReplaceBody(t.replacementBody); err != nil { + return err + } + } + return nil +} + +func (t *Transaction) addHeader(key string, raw string) { + if t.headers == nil { + t.headers = &Header{} + } + t.headers.addRaw(key, raw) +} + +func (t *Transaction) addBodyChunk(chunk []byte) (err error) { + if t.body == nil { + t.body, err = os.CreateTemp("", "body-*") + if err != nil { + return + } + } + _, err = t.body.Write(chunk) + return +} + +// HasRcptTo returns true when rcptTo is in the list of recipients. +// +// rcptTo gets compared to the existing recipients IDNA address aware. +func (t *Transaction) HasRcptTo(rcptTo string) bool { + findR := RcptTo{ + addr: addr{Addr: rcptTo, Args: ""}, + transport: "", + } + findLocal, findDomain := findR.Local(), findR.AsciiDomain() + for _, r := range t.RcptTos { + if r.Local() == findLocal && r.AsciiDomain() == findDomain { + return true + } + } + return false +} + +// AddRcptTo adds the rcptTo (without angles) to the list of recipients with the ESMTP arguments esmtpArgs. +// If rcptTo is already in the list of recipients only the esmtpArgs of this recipient get updated. +// +// rcptTo gets compared to the existing recipients IDNA address aware. +func (t *Transaction) AddRcptTo(rcptTo string, esmtpArgs string) { + addR := RcptTo{ + addr: addr{Addr: rcptTo, Args: esmtpArgs}, + transport: "smtp", + } + findLocal, findDomain := addR.Local(), addR.AsciiDomain() + for i, r := range t.RcptTos { + if r.Local() == findLocal && r.AsciiDomain() == findDomain { + t.RcptTos[i].Args = esmtpArgs + return + } + } + t.RcptTos = append(t.RcptTos, addR) +} + +// DelRcptTo deletes the rcptTo (without angles) from the list of recipients. +// +// rcptTo gets compared to the existing recipients IDNA address aware. +func (t *Transaction) DelRcptTo(rcptTo string) { + findR := RcptTo{ + addr: addr{Addr: rcptTo, Args: ""}, + transport: "", + } + findLocal, findDomain := findR.Local(), findR.AsciiDomain() + for i, r := range t.RcptTos { + if r.Local() == findLocal && r.AsciiDomain() == findDomain { + t.RcptTos = append(t.RcptTos[:i], t.RcptTos[i+1:]...) + return + } + } +} + +// Body gets you a [io.ReadSeeker] of the body. The reader seeked to the start of the body. +// +// This method returns nil when you used [WithDecisionAt] with anything other than [DecisionAtEndOfMessage] +// or you used [WithoutBody]. +func (t *Transaction) Body() io.ReadSeeker { + if t.body == nil { + return nil + } + _, _ = t.body.Seek(0, io.SeekStart) + return t.body +} + +// ReplaceBody replaces the body of the current message with the contents +// of the [io.Reader] r. +func (t *Transaction) ReplaceBody(r io.Reader) { + t.replacementBody = r +} diff --git a/mailfilter/transaction_test.go b/mailfilter/transaction_test.go new file mode 100644 index 0000000..8918226 --- /dev/null +++ b/mailfilter/transaction_test.go @@ -0,0 +1,284 @@ +package mailfilter + +import ( + "context" + "errors" + "fmt" + "io" + "reflect" + "strings" + "testing" + + "github.com/d--j/go-milter/internal/wire" + "github.com/emersion/go-message/mail" +) + +func rcptFromAddr(in []addr) []RcptTo { + if in == nil { + return nil + } + var out = []RcptTo{} + for _, i := range in { + out = append(out, RcptTo{addr: i}) + } + return out +} +func addrFromRcp(in []RcptTo) []addr { + if in == nil { + return nil + } + var out = []addr{} + for _, i := range in { + out = append(out, addr{Addr: i.addr.Addr, Args: i.addr.Args}) + } + return out +} + +func TestTransaction_AddRcptTo(t1 *testing.T) { + type args struct { + rcptTo string + esmtpArgs string + } + tests := []struct { + name string + existing []addr + args args + want []addr + }{ + {"nil", nil, args{"", ""}, []addr{{}}}, + {"empty", []addr{}, args{"", ""}, []addr{{}}}, + {"set-esmtp-args", []addr{{Args: ""}}, args{"", "A=B"}, []addr{{Args: "A=B"}}}, + {"add", []addr{{}}, args{"root@localhost", "A=B"}, []addr{{}, {Addr: "root@localhost", Args: "A=B"}}}, + {"idna-utf8", []addr{{Addr: "root@スパム.example.com"}}, args{"root@xn--zck5b2b.example.com", "A=B"}, []addr{{Addr: "root@スパム.example.com", Args: "A=B"}}}, + {"idna-ascii", []addr{{Addr: "root@xn--zck5b2b.example.com"}}, args{"root@スパム.example.com", "A=B"}, []addr{{Addr: "root@xn--zck5b2b.example.com", Args: "A=B"}}}, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + + t := &Transaction{ + RcptTos: rcptFromAddr(tt.existing), + } + t.AddRcptTo(tt.args.rcptTo, tt.args.esmtpArgs) + got := addrFromRcp(t.RcptTos) + if !reflect.DeepEqual(got, tt.want) { + t1.Fatalf("RcptTos = %+v, want %+v", got, tt.want) + } + }) + } +} + +func TestTransaction_DelRcptTo(t1 *testing.T) { + type args struct { + rcptTo string + } + tests := []struct { + name string + existing []addr + args args + want []addr + }{ + {"nil", nil, args{""}, nil}, + {"empty", []addr{}, args{""}, []addr{}}, + {"del", []addr{{Addr: "root@localhost"}}, args{"root@localhost"}, []addr{}}, + {"idna-utf8", []addr{{Addr: "root@スパム.example.com"}}, args{"root@xn--zck5b2b.example.com"}, []addr{}}, + {"idna-ascii", []addr{{Addr: "root@xn--zck5b2b.example.com"}}, args{"root@スパム.example.com"}, []addr{}}, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + t := &Transaction{ + RcptTos: rcptFromAddr(tt.existing), + } + t.DelRcptTo(tt.args.rcptTo) + got := addrFromRcp(t.RcptTos) + if !reflect.DeepEqual(got, tt.want) { + t1.Fatalf("RcptTos = %v, want %v", got, tt.want) + } + }) + } +} + +func TestTransaction_HasRcptTo(t1 *testing.T) { + type args struct { + rcptTo string + } + tests := []struct { + name string + existing []addr + args args + want bool + }{ + {"nil", nil, args{""}, false}, + {"empty", []addr{}, args{""}, false}, + {"no", []addr{{Addr: "root@localhost"}}, args{""}, false}, + {"yes", []addr{{Addr: "root@localhost"}}, args{"root@localhost"}, true}, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + t := &Transaction{ + RcptTos: rcptFromAddr(tt.existing), + } + if got := t.HasRcptTo(tt.args.rcptTo); got != tt.want { + t1.Errorf("HasRcptTo() = %v, want %v", got, tt.want) + } + }) + } +} + +func outputMessages(messages []*wire.Message) string { + b := strings.Builder{} + for i, msg := range messages { + b.WriteString(fmt.Sprintf("%02d %c %q\n", i, msg.Code, msg.Data)) + } + return b.String() +} + +func TestTransaction_sendModifications(t1 *testing.T) { + expectErr := errors.New("error") + writeErr := func(_ *wire.Message) error { + return expectErr + } + mod := func(act wire.ModifyActCode, data []byte) *wire.Message { + return &wire.Message{Code: wire.Code(act), Data: data} + } + tests := []struct { + name string + decider DecisionModificationFunc + want []*wire.Message + wantErr bool + }{ + {"noop", func(_ context.Context, trx *Transaction) (Decision, error) { + return Accept, nil + }, nil, false}, + {"mail-from", func(_ context.Context, trx *Transaction) (Decision, error) { + trx.MailFrom.Addr = "root@localhost" + trx.MailFrom.Args = "A=B" + return Accept, nil + }, []*wire.Message{mod(wire.ActChangeFrom, []byte("\u0000A=B\u0000"))}, false}, + {"mail-from-err", func(ctx context.Context, trx *Transaction) (Decision, error) { + trx.MailFrom.Addr = "root@localhost" + ctx.Value("s").(*mockSession).WritePacket = writeErr + return Accept, nil + }, nil, true}, + {"del-rcpt", func(_ context.Context, trx *Transaction) (Decision, error) { + trx.DelRcptTo("root@localhost") + return Accept, nil + }, []*wire.Message{mod(wire.ActDelRcpt, []byte("\u0000"))}, false}, + {"del-rcpt-noop", func(_ context.Context, trx *Transaction) (Decision, error) { + trx.DelRcptTo("someone@localhost") + return Accept, nil + }, nil, false}, + {"del-rcpt-err", func(ctx context.Context, trx *Transaction) (Decision, error) { + trx.DelRcptTo("root@localhost") + ctx.Value("s").(*mockSession).WritePacket = writeErr + return Accept, nil + }, nil, true}, + {"add-rcpt", func(_ context.Context, trx *Transaction) (Decision, error) { + trx.AddRcptTo("someone@localhost", "") + return Accept, nil + }, []*wire.Message{mod(wire.ActAddRcpt, []byte("\u0000"))}, false}, + {"add-rcpt-par", func(_ context.Context, trx *Transaction) (Decision, error) { + trx.AddRcptTo("someone@localhost", "A=B") + return Accept, nil + }, []*wire.Message{mod(wire.ActAddRcptPar, []byte("\u0000A=B\u0000"))}, false}, + {"add-rcpt-noop", func(_ context.Context, trx *Transaction) (Decision, error) { + trx.AddRcptTo("root@localhost", "") + return Accept, nil + }, nil, false}, + {"add-rcpt-err", func(ctx context.Context, trx *Transaction) (Decision, error) { + trx.AddRcptTo("someone@localhost", "") + ctx.Value("s").(*mockSession).WritePacket = writeErr + return Accept, nil + }, nil, true}, + {"replace-rcpt", func(_ context.Context, trx *Transaction) (Decision, error) { + trx.RcptTos[0].Addr = "someone@localhost" + return Accept, nil + }, []*wire.Message{ + mod(wire.ActDelRcpt, []byte("\u0000")), + mod(wire.ActAddRcpt, []byte("\u0000")), + }, false}, + {"replace-body", func(_ context.Context, trx *Transaction) (Decision, error) { + got, _ := io.ReadAll(trx.Body()) + if string(got) != "body" { + t1.Fatalf("wrong body %q", got) + } + trx.ReplaceBody(io.NopCloser(strings.NewReader("test"))) + return Accept, nil + }, []*wire.Message{ + mod(wire.ActReplBody, []byte("test")), + }, false}, + {"replace-body-err", func(ctx context.Context, trx *Transaction) (Decision, error) { + trx.ReplaceBody(io.NopCloser(strings.NewReader("test"))) + ctx.Value("s").(*mockSession).WritePacket = writeErr + return Accept, nil + }, nil, true}, + {"add-header", func(_ context.Context, trx *Transaction) (Decision, error) { + trx.Headers.Add("X-Test", "1") + return Accept, nil + }, []*wire.Message{ + mod(wire.ActInsertHeader, []byte("\u0000\u0000\u0000\u0003X-Test\u0000 1\u0000")), + }, false}, + {"prepend-header", func(_ context.Context, trx *Transaction) (Decision, error) { + f := trx.Headers.Fields() + f.Next() + f.InsertBefore("X-Test", "1") + return Accept, nil + }, []*wire.Message{ + mod(wire.ActInsertHeader, []byte("\u0000\u0000\u0000\u0000X-Test\u0000 1\u0000")), + }, false}, + {"prepend-header-err", func(ctx context.Context, trx *Transaction) (Decision, error) { + f := trx.Headers.Fields() + f.Next() + f.InsertBefore("X-Test", "1") + ctx.Value("s").(*mockSession).WritePacket = writeErr + return Accept, nil + }, nil, true}, + {"change-header", func(_ context.Context, trx *Transaction) (Decision, error) { + f := trx.Headers.Fields() + f.Next() + f.SetAddressList([]*mail.Address{{Address: "root@localhost", Name: "root"}}) + return Accept, nil + }, []*wire.Message{ + mod(wire.ActChangeHeader, []byte("\u0000\u0000\u0000\u0001From\u0000 \"root\" \u0000")), + }, false}, + {"change-header-err", func(ctx context.Context, trx *Transaction) (Decision, error) { + f := trx.Headers.Fields() + f.Next() + f.SetAddressList([]*mail.Address{{Address: "root@localhost", Name: "root"}}) + ctx.Value("s").(*mockSession).WritePacket = writeErr + return Accept, nil + }, nil, true}, + } + for _, tt := range tests { + t1.Run(tt.name, func(t1 *testing.T) { + b, s := newMockBackend() + t1.Cleanup(b.transaction.cleanup) + _, _ = b.MailFrom("", "", s.newModifier()) + _, _ = b.RcptTo("root@localhost", "", s.newModifier()) + _, _ = b.Header("From", " <>", s.newModifier()) + _, _ = b.Header("To", " ", s.newModifier()) + _, _ = b.Header("Subject", " test", s.newModifier()) + _, _ = b.BodyChunk([]byte("body"), s.newModifier()) + b.transaction.makeDecision(context.WithValue(context.Background(), "s", s), tt.decider) + if b.transaction.decisionErr != nil { + t1.Fatal(b.transaction.decisionErr) + } + if tt.wantErr == false { + gotHas := b.transaction.hasModifications() + expectHas := false + if len(tt.want) > 0 { + expectHas = true + } + if gotHas != expectHas { + t1.Errorf("hasModifications() = %v, want %v", gotHas, expectHas) + } + } + if err := b.transaction.sendModifications(s.newModifier()); (err != nil) != tt.wantErr { + t1.Errorf("sendModifications() error = %v, wantErr %v", err, tt.wantErr) + } + got := s.modifications + if !reflect.DeepEqual(got, tt.want) { + t1.Errorf("sendModifications() sent %v, want %v", outputMessages(got), outputMessages(tt.want)) + } + }) + } +} diff --git a/milter.go b/milter.go index ac47771..f85ec36 100644 --- a/milter.go +++ b/milter.go @@ -45,7 +45,7 @@ const ( OptNoEOHReply OptProtocol = 1 << 18 // Milter does not send a reply to end of header event. SMFIP_NR_EOH [v6] OptNoBodyReply OptProtocol = 1 << 19 // Milter does not send a reply to body chunk event. SMFIP_NR_BODY [v6] - // OptHeaderLeadingSpace lets the Milter request that the MTA does not swallow a leading space + // OptHeaderLeadingSpace lets the [Milter] request that the MTA does not swallow a leading space // when passing the header value to the milter. // Sendmail by default eats one space (not tab) after the colon. So the header line (spaces replaced with _): // Subject:__Test @@ -53,9 +53,9 @@ const ( // sends OptHeaderLeadingSpace to the MTA it requests that it wants the header value as is. // So the MTA should send HeaderName "Subject" and HeaderValue "__Test". // - // Milter that do e.g. DKIM signing may need the additional space to create valid DKIM signatures. + // [Milter] that do e.g. DKIM signing may need the additional space to create valid DKIM signatures. // - // The Client and ClientSession does not handle this option. It is the responsibility of the MTA to check if the milter + // The [Client] and [ClientSession] does not handle this option. It is the responsibility of the MTA to check if the milter // asked for this and obey this request. In the simplest case just never swallow the space. // // SMFIP_HDR_LEADSPC [v6] diff --git a/modifier.go b/modifier.go index fedb3d9..bf9dd4b 100644 --- a/modifier.go +++ b/modifier.go @@ -35,7 +35,7 @@ type Action struct { } // StopProcessing returns true when the milter wants to immediately stop this SMTP connection. -// You can use SMTPReply to send as reply to the current SMTP command. +// You can use [Action.SMTPReply] to send as reply to the current SMTP command. func (a Action) StopProcessing() bool { return a.SMTPCode > 0 } @@ -111,10 +111,22 @@ type ModifyAction struct { Body []byte // Index of the header field to be changed if Type = ActionChangeHeader or Type = ActionInsertHeader. - // Index is 1-based and is per value of HdrName. - // E.g. HeaderIndex = 3 and HdrName = "DKIM-Signature" mean "change third - // DKIM-Signature field". Order is the same as of HeaderField calls. - // A HeaderIndex of 0 for Type = ActionInsertHeader has the special meaning "at the very beginning". + // Index is 1-based. + // + // If Type = ActionChangeHeader the index is per canonical value of HdrName. + // E.g. HeaderIndex = 3 and HdrName = "DKIM-Signature" mean "change third DKIM-Signature field". + // Order is the same as of HeaderField calls. + // + // If Type = ActionInsertHeader the index is global to all headers, 1-based and means "insert after the HeaderIndex header". + // A HeaderIndex of 0 has the special meaning "at the very beginning". + // + // Deleted headers (Type = ActionChangeHeader and HeaderValue == "") do not change the indexes of the other headers. + // They will be skipped by the MTA but still occupy their place in the header list of the MTA. + // + // In both cases when you provide an index that is bigger than allowed the header gets added at the very end of the header list. + // This is NOT always semantically equal to Type == ActionAddHeader. + // Type == ActionAddHeader may actually replace an existing header instead of adding a new one. + // This will only happen with MTA generated headers besides "Received", "X400-Received", "Via" and "Mail-From". HeaderIndex uint32 // Header field name to be added/changed if Type == ActionAddHeader or @@ -205,11 +217,10 @@ func parseModifyAct(msg *wire.Message) (*ModifyAction, error) { return act, nil } -// Modifier provides access to Macros and Headers to callback handlers. It also defines a +// Modifier provides access to [Macros] to callback handlers. It also defines a // number of functions that can be used by callback handlers to modify processing of the email message. -// Besides Progress() they can only be called in the EndOfMessage callback. +// Besides [Modifier.Progress] they can only be called in the EndOfMessage callback. type Modifier struct { - Headers textproto.MIMEHeader Macros Macros writeProgressPacket func(*wire.Message) error writePacket func(*wire.Message) error @@ -252,7 +263,7 @@ func (m *Modifier) AddRecipient(r string, esmtpArgs string) error { var buffer bytes.Buffer buffer.WriteString(addHats(r)) buffer.WriteByte(0) - // send wire.ActAddRcptPar when that is the only allowed action or we need to send it because esmptArgs ist set + // send wire.ActAddRcptPar when that is the only allowed action, or we need to send it because esmptArgs ist set if (esmtpArgs != "" && m.actions&OptAddRcptWithArgs != 0) || (esmtpArgs == "" && m.actions&OptAddRcptWithArgs != 0) { buffer.WriteString(esmtpArgs) buffer.WriteByte(0) @@ -298,7 +309,7 @@ func (m *Modifier) ReplaceBodyRawChunk(chunk []byte) error { // wrappedR := transform.NewReader(r, t) // m.ReplaceBody(wrappedR) // -// This function tries to use as few calls to [ReplaceBodyRawChunk] as possible. +// This function tries to use as few calls to [Modifier.ReplaceBodyRawChunk] as possible. // // You can call ReplaceBody multiple times. The MTA will combine all those calls into one message. // @@ -398,7 +409,7 @@ func errorWriteReadOnly(m *wire.Message) error { return fmt.Errorf("tried to send action %c in read-only state", m.Code) } -// newModifier creates a new [Modifier] instance from [serverSession] +// newModifier creates a new [Modifier] instance from s. If it is readOnly then all modification actions will throw an error. func newModifier(s *serverSession, readOnly bool) *Modifier { writePacket := s.writePacket if readOnly { @@ -406,10 +417,20 @@ func newModifier(s *serverSession, readOnly bool) *Modifier { } return &Modifier{ Macros: ¯oReader{macrosStages: s.macros}, - Headers: s.headers, writePacket: writePacket, writeProgressPacket: s.writePacket, actions: s.actions, maxDataSize: s.maxDataSize, } } + +// NewTestModifier is only exported for unit-tests. It can only be use internally since it uses the internal package [wire]. +func NewTestModifier(macros Macros, writePacket, writeProgress func(msg *wire.Message) error, actions OptAction, maxDataSize DataSize) *Modifier { + return &Modifier{ + Macros: macros, + writePacket: writePacket, + writeProgressPacket: writeProgress, + actions: actions, + maxDataSize: maxDataSize, + } +} diff --git a/options.go b/options.go index bfe5617..282293e 100644 --- a/options.go +++ b/options.go @@ -1,6 +1,8 @@ package milter -import "time" +import ( + "time" +) // NewMilterFunc is the signature of a function that can be used with [WithDynamicMilter] to configure the [Milter] backend. // The parameters version, action, protocol and maxData are the negotiated values. diff --git a/response.go b/response.go index b6e667a..5c984d4 100644 --- a/response.go +++ b/response.go @@ -22,7 +22,7 @@ func (c *Response) Response() *wire.Message { } // Continue returns false if the MTA should stop sending events for this transaction, true otherwise. -// A RespDiscard Response will return false because the MTA should end sending events for the current +// A [RespDiscard] Response will return false because the MTA should end sending events for the current // SMTP transaction to this milter. func (c *Response) Continue() bool { switch wire.ActionCode(c.code) { @@ -33,12 +33,12 @@ func (c *Response) Continue() bool { } } -// newResponse generates a new Response suitable for wire.WritePacket +// newResponse generates a new Response suitable for [wire.WritePacket] func newResponse(code wire.Code, data []byte) *Response { return &Response{code, data} } -// newResponseStr generates a new Response with string payload (null-byte terminated) +// newResponseStr generates a new [Response] with string payload (null-byte terminated) func newResponseStr(code wire.Code, data string) (*Response, error) { if len(data) > int(DataSize64K)-1 { // space for null-bytes return nil, fmt.Errorf("milter: invalid data length: %d > %d", len(data), int(DataSize64K)-1) @@ -102,7 +102,7 @@ var ( // RespSkip signals to the MTA that transaction should continue and that the MTA // does not need to send more events of the same type. This response one makes sense/is possible as - // return value of Milter.RcptTo, Milter.Header and Milter.BodyChunk. + // return value of [Milter.RcptTo], [Milter.Header] and [Milter.BodyChunk]. // No more events get send to the milter after this response. RespSkip = &Response{code: wire.Code(wire.ActSkip)} ) diff --git a/server.go b/server.go index ce171cd..e6d7296 100644 --- a/server.go +++ b/server.go @@ -9,66 +9,89 @@ import ( // MaxServerProtocolVersion is the maximum Milter protocol version implemented by the server. const MaxServerProtocolVersion uint32 = 6 -// ErrServerClosed is returned by the Server's Serve method after a call to Close. +// ErrServerClosed is returned by the [Server]'s [Server.Serve] method after a call to [Server.Close]. var ErrServerClosed = errors.New("milter: server closed") // Milter is an interface for milter callback handlers. type Milter interface { // Connect is called to provide SMTP connection data for incoming message. // Suppress with OptNoConnect. + // + // If this method returns an error the error will be logged and the connection will be closed. + // If there is a [Response] (and we did not negotiate [OptNoConnReply]) this response will be sent before closing the connection. Connect(host string, family string, port uint16, addr string, m *Modifier) (*Response, error) - // Helo is called to process any HELO/EHLO related filters. Suppress with - // OptNoHelo. + // Helo is called to process any HELO/EHLO related filters. Suppress with [OptNoHelo]. + // + // If this method returns an error the error will be logged and the connection will be closed. + // If there is a [Response] (and we did not negotiate [OptNoHeloReply]) this response will be sent before closing the connection. Helo(name string, m *Modifier) (*Response, error) - // MailFrom is called to process filters on envelope FROM address. Suppress - // with OptNoMailFrom. + // MailFrom is called to process filters on envelope FROM address. Suppress with [OptNoMailFrom]. + // + // If this method returns an error the error will be logged and the connection will be closed. + // If there is a [Response] (and we did not negotiate [OptNoMailReply]) this response will be sent before closing the connection. MailFrom(from string, esmtpArgs string, m *Modifier) (*Response, error) - // RcptTo is called to process filters on envelope TO address. Suppress with - // OptNoRcptTo. + // RcptTo is called to process filters on envelope TO address. Suppress with [OptNoRcptTo]. + // + // If this method returns an error the error will be logged and the connection will be closed. + // If there is a [Response] (and we did not negotiate [OptNoRcptReply]) this response will be sent before closing the connection. RcptTo(rcptTo string, esmtpArgs string, m *Modifier) (*Response, error) - // Data is called at the beginning of the DATA command (after all RCPT TO commands). + // Data is called at the beginning of the DATA command (after all RCPT TO commands). Suppress with [OptNoData]. + // + // If this method returns an error the error will be logged and the connection will be closed. + // If there is a [Response] (and we did not negotiate [OptNoDataReply]) this response will be sent before closing the connection. Data(m *Modifier) (*Response, error) - // Header is called once for each header in incoming message. Suppress with - // OptNoHeaders. + // Header is called once for each header in incoming message. Suppress with [OptNoHeaders]. + // + // If this method returns an error the error will be logged and the connection will be closed. + // If there is a [Response] (and we did not negotiate [OptNoHeaderReply]) this response will be sent before closing the connection. Header(name string, value string, m *Modifier) (*Response, error) - // Headers gets called when all message headers have been processed. - // Suppress with OptNoEOH. - // h is a textproto.MIMEHeader of all collected headers. - // If you specified OptNoHeaders this will be of course empty. + // Headers gets called when all message headers have been processed. Suppress with [OptNoEOH]. + // + // If this method returns an error the error will be logged and the connection will be closed. + // If there is a [Response] (and we did not negotiate [OptNoEOHReply]) this response will be sent before closing the connection. Headers(m *Modifier) (*Response, error) // BodyChunk is called to process next message body chunk data (up to 64KB - // in size). Suppress with OptNoBody. If you return RespSkip the MTA will stop + // in size). Suppress with [OptNoBody]. If you return [RespSkip] the MTA will stop // sending more body chunks. But older MTAs do not support this and in this case there are more calls to BodyChunk. // Your code should be able to handle this. + // + // If this method returns an error the error will be logged and the connection will be closed. + // If there is a [Response] (and we did not negotiate [OptNoBodyReply]) this response will be sent before closing the connection. BodyChunk(chunk []byte, m *Modifier) (*Response, error) // EndOfMessage is called at the end of each message. All changes to message's // content & attributes must be done here. // The MTA can start over with another message in the same connection but that is handled in a new Milter instance. + // + // If this method returns an error the error will be logged and the connection will be closed. + // If there is a [Response] this response will be sent before closing the connection. EndOfMessage(m *Modifier) (*Response, error) // Abort is called if the current message has been aborted. All message data - // should be reset prior to the MailFrom callback. Connection data should be - // preserved. Cleanup is not called before or after Abort. + // should be reset prior to the [Milter.MailFrom] callback. Connection data should be + // preserved. [Milter.Cleanup] is not called before or after Abort. Abort(m *Modifier) error // Unknown is called when the MTA got an unknown command in the SMTP connection. + // + // If this method returns an error the error will be logged and the connection will be closed. + // If there is a [Response] (and we did not negotiate [OptNoUnknownReply]) this response will be sent before closing the connection. Unknown(cmd string, m *Modifier) (*Response, error) - // Cleanup always gets called when the Milter is about to be discarded. + // Cleanup always gets called when the [Milter] is about to be discarded. // E.g. because the MTA closed the connection, one SMTP message was successful or there was an error. - // May be called more than once for a single Milter. + // May be called more than once for a single [Milter]. Cleanup() } -// NoOpMilter is a dummy Milter implementation that does nothing. +// NoOpMilter is a dummy [Milter] implementation that does nothing. type NoOpMilter struct{} var _ Milter = NoOpMilter{} @@ -129,9 +152,9 @@ type Server struct { // NewServer creates a new milter server. // -// You need to at least specify the used Milter with the option WithMilter. -// You should also specify the actions your Milter will do. Otherwise, you cannot do any message modifications. -// For performance reasons you should disable protocol stages that you do not need with WithProtocol. +// You need to at least specify the used [Milter] with the option [WithMilter]. +// You should also specify the actions your [Milter] will do. Otherwise, you cannot do any message modifications. +// For performance reasons you should disable protocol stages that you do not need with [WithProtocol]. // // This function will panic when you provide invalid options. func NewServer(opts ...Option) *Server { @@ -171,11 +194,13 @@ func NewServer(opts ...Option) *Server { // Serve starts the server. func (s *Server) Serve(ln net.Listener) error { - defer func(ln net.Listener) { - _ = ln.Close() - }(ln) - s.listeners = append(s.listeners, ln) + defer func(ln net.Listener, len int) { + if s.listeners[len-1] != nil { + _ = ln.Close() + s.listeners[len-1] = nil + } + }(ln, len(s.listeners)) for { conn, err := ln.Accept() @@ -204,8 +229,10 @@ func (s *Server) Close() error { } s.closed = true for _, ln := range s.listeners { - if err := ln.Close(); err != nil { - return err + if ln != nil { + if err := ln.Close(); err != nil { + return err + } } } return nil diff --git a/session.go b/session.go index 91245d6..20d57f9 100644 --- a/session.go +++ b/session.go @@ -7,7 +7,6 @@ import ( "fmt" "io" "net" - "net/textproto" "strings" "github.com/d--j/go-milter/internal/wire" @@ -23,7 +22,6 @@ type serverSession struct { protocol OptProtocol maxDataSize DataSize conn net.Conn - headers textproto.MIMEHeader macros *macrosStages backend Milter } @@ -229,16 +227,11 @@ func (m *serverSession) Process(msg *wire.Message) (*Response, error) { if len(msg.Data) < 2 { return nil, fmt.Errorf("milter: header: unexpected data size: %d", len(msg.Data)) } - // make sure headers is initialized - if m.headers == nil { - m.headers = make(textproto.MIMEHeader) - } // add new header to headers map headerData := wire.DecodeCStrings(msg.Data) if len(headerData) != 2 { return nil, fmt.Errorf("milter: header: unexpected number of strings: %d", len(headerData)) } - m.headers.Add(headerData[0], headerData[1]) // call and return milter handler resp, err := m.backend.Header(headerData[0], headerData[1], newModifier(m, true)) m.macros.DelStageAndAbove(StageEndMarker) @@ -304,14 +297,12 @@ func (m *serverSession) Process(msg *wire.Message) (*Response, error) { case wire.CodeAbort: // abort current message and start over err := m.backend.Abort(newModifier(m, true)) - m.headers = nil m.macros.DelStageAndAbove(StageHelo) return nil, err case wire.CodeQuitNewConn: // abort current connection and start over m.backend.Cleanup() - m.headers = nil m.macros.DelStageAndAbove(StageConnect) m.backend = m.newBackend() // do not send response @@ -374,6 +365,9 @@ func (m *serverSession) HandleMilterCommands() { if err != errCloseSession { // log error condition LogWarning("Error performing milter command: %v", err) + if resp != nil && !m.skipResponse(msg.Code) { + _ = m.writePacket(resp.Response()) + } } return } diff --git a/session_test.go b/session_test.go index 8f8e341..09d6a9c 100644 --- a/session_test.go +++ b/session_test.go @@ -68,7 +68,6 @@ func (p *processTestMilter) Header(name string, value string, m *Modifier) (*Res } func (p *processTestMilter) Headers(m *Modifier) (*Response, error) { - p.headers = m.Headers p.headersCalled = true return RespContinue, nil } From 8f54a285ddef197b84cfa1aa3fa5cece30bafc60 Mon Sep 17 00:00:00 2001 From: Daniel Jagszent Date: Fri, 3 Mar 2023 19:05:49 +0100 Subject: [PATCH 2/4] fix: sending wrong modification when adding recipient without ESMTP args. --- modifier.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/modifier.go b/modifier.go index bf9dd4b..bd81ab1 100644 --- a/modifier.go +++ b/modifier.go @@ -264,7 +264,7 @@ func (m *Modifier) AddRecipient(r string, esmtpArgs string) error { buffer.WriteString(addHats(r)) buffer.WriteByte(0) // send wire.ActAddRcptPar when that is the only allowed action, or we need to send it because esmptArgs ist set - if (esmtpArgs != "" && m.actions&OptAddRcptWithArgs != 0) || (esmtpArgs == "" && m.actions&OptAddRcptWithArgs != 0) { + if (esmtpArgs != "" && m.actions&OptAddRcptWithArgs != 0) || (esmtpArgs == "" && m.actions&OptAddRcpt == 0) { buffer.WriteString(esmtpArgs) buffer.WriteByte(0) code = wire.ActAddRcptPar From bae43fb4f10766c0df6cc9c073e667b60fd9dbf7 Mon Sep 17 00:00:00 2001 From: Daniel Jagszent Date: Fri, 3 Mar 2023 19:06:59 +0100 Subject: [PATCH 3/4] fix: client: Helo can be called multiple times (EHLO after STARTTLS). --- client.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/client.go b/client.go index 92fa774..c8191dc 100644 --- a/client.go +++ b/client.go @@ -428,7 +428,7 @@ func (s *ClientSession) readAction(skipOk bool) (*Action, error) { switch act.Type { case ActionSkip: if !skipOk { - return nil, fmt.Errorf("action read: unexpected skip message received (can only be received after SMFIC_BODY when SMFIP_SKIP was negotiated)") + return nil, fmt.Errorf("action read: unexpected skip message received (can only be received after SMFIC_RCPT, SMFIC_HEADER, SMFIC_BODY when SMFIP_SKIP was negotiated)") } case ActionReject: act.SMTPCode = 550 @@ -507,7 +507,7 @@ func (s *ClientSession) Conn(hostname string, family ProtoFamily, port uint16, a // // It should be called once per milter session (from Client.Session to Close). func (s *ClientSession) Helo(helo string) (*Action, error) { - if s.state != clientStateConnectCalled { + if s.state != clientStateConnectCalled && s.state != clientStateHeloCalled { return nil, s.errorOut(fmt.Errorf("milter: in wrong state %d", s.state)) } From 2fe0288f23358d958bf223f8a4317de9f8ce46d9 Mon Sep 17 00:00:00 2001 From: Daniel Jagszent Date: Fri, 3 Mar 2023 19:31:56 +0100 Subject: [PATCH 4/4] chore: documentation --- mailfilter/mailfilter.go | 12 ++++++++++++ mailfilter/option.go | 2 +- 2 files changed, 13 insertions(+), 1 deletion(-) diff --git a/mailfilter/mailfilter.go b/mailfilter/mailfilter.go index 7a39303..3f33204 100644 --- a/mailfilter/mailfilter.go +++ b/mailfilter/mailfilter.go @@ -12,6 +12,18 @@ import ( "github.com/d--j/go-milter" ) +// DecisionModificationFunc is the callback function that you need to implement to create a mail filter. +// +// ctx is a [context.Context] that might get canceled when the connection to the MTA fails while your callback is running. +// If your decision function is running longer than one second the [MailFilter] automatically sends progress notifications +// every second so that MTA does not time out the milter connection. +// +// transaction is the [Transaction] object that you can inspect to see what the [MailFilter] got as information about the current SMTP transaction. +// You can also use transaction to modify the transaction (e.g. change recipients, alter headers). +// +// decision is your [Decision] about this SMTP transaction. Use [Accept], [TempFail], [Reject], [Discard] or [CustomErrorResponse]. +// +// If you return a non-nil error [WithErrorHandling] will determine what happens with the current SMTP transaction. type DecisionModificationFunc func(ctx context.Context, transaction *Transaction) (decision Decision, err error) type MailFilter struct { diff --git a/mailfilter/option.go b/mailfilter/option.go index b1dcdb3..b085aae 100644 --- a/mailfilter/option.go +++ b/mailfilter/option.go @@ -27,7 +27,7 @@ const ( type ErrorHandling int const ( - // Error just throws the error + // Error just throws the error. The connection to the MTA will break and the MTA will decide what happens to the SMTP transaction. Error ErrorHandling = iota // AcceptWhenError accepts the transaction despite the error (it gets logged). AcceptWhenError