diff --git a/.build.yml b/.build.yml deleted file mode 100644 index 214a164..0000000 --- a/.build.yml +++ /dev/null @@ -1,19 +0,0 @@ -image: alpine/edge -packages: - - go - # Required by codecov - - bash - - findutils -sources: - - https://github.com/emersion/go-milter -tasks: - - build: | - cd go-milter - go build -v ./... - - test: | - cd go-milter - go test -coverprofile=coverage.txt -covermode=atomic ./... - - upload-coverage: | - cd go-milter - export CODECOV_TOKEN=8c0f7014-fcfa-4ed9-8972-542eb5958fb3 - curl -s https://codecov.io/bash | bash diff --git a/.github/workflows/go.yml b/.github/workflows/go.yml index 3e9a76d..9d7b40a 100644 --- a/.github/workflows/go.yml +++ b/.github/workflows/go.yml @@ -14,15 +14,15 @@ jobs: build: runs-on: ubuntu-latest steps: - - uses: actions/checkout@v3 + - uses: actions/checkout@v3 - - name: Set up Go - uses: actions/setup-go@v3 - with: - go-version: 1.19 + - name: Set up Go + uses: actions/setup-go@v3 + with: + go-version: 1.19 - - name: Build - run: go build -v ./... + - name: Build + run: go build -v ./... - - name: Test - run: go test -v ./... + - name: Test + run: go test -v ./... diff --git a/.gitignore b/.gitignore index daf913b..25712e4 100644 --- a/.gitignore +++ b/.gitignore @@ -22,3 +22,5 @@ _testmain.go *.exe *.test *.prof + +/.idea diff --git a/LICENSE b/LICENSE index fb3a567..5258dbc 100644 --- a/LICENSE +++ b/LICENSE @@ -1,6 +1,8 @@ BSD 2-Clause License Copyright (c) 2017 Bozhin Zafirov +Copyright (c) 2019 Simon Ser +Copyright (c) 2023 Daniel Jagszent All rights reserved. Redistribution and use in source and binary forms, with or without diff --git a/README.md b/README.md index c2cfc4a..8bd9001 100644 --- a/README.md +++ b/README.md @@ -1,10 +1,90 @@ # go-milter -[![GoDoc](https://godoc.org/github.com/emersion/go-milter?status.svg)](https://godoc.org/github.com/emersion/go-milter) -[![builds.sr.ht status](https://builds.sr.ht/~emersion/go-milter/commits.svg)](https://builds.sr.ht/~emersion/go-milter/commits?) +[![GoDoc](https://godoc.org/github.com/d--j/go-milter?status.svg)](https://godoc.org/github.com/d--j/go-milter) +![Build status](https://github.com/d--j/go-milter/actions/workflows/go.yml/badge.svg?branch=main) A Go library to write mail filters. +With this library you can write both the client (MTA/SMTP-Server) and server (milter filter) +in pure Go without sendmail's libmilter. + +## Features + +* Client & Server support milter protocol version 6 with all features. E.g.: + * all milter events including DATA, UNKNOWN, ABORT and QUIT NEW CONNECTION + * milter can skip e.g. body chunks when it does not need all chunks + * milter can send progress notifications when response can take some time + * milter can automatically instruct the MTA which macros it needs. +* UTF-8 support + +## Usage + +```go +package main + +import ( + "log" + "net" + "sync" + + "github.com/d--j/go-milter" +) + +type ExampleBackend struct { + milter.NoOpMilter +} + +func (b *ExampleBackend) RcptTo(rcptTo string, esmtpArgs string, m *milter.Modifier) (*milter.Response, error) { + // reject the mail when it goes to other-spammer@example.com and is a local delivery + if rcptTo == "other-spammer@example.com" && m.Macros.Get(milter.MacroRcptMailer) == "local" { + return milter.RejectWithCodeAndReason(550, "We do not like you\r\nvery much, please go away") + } + return milter.RespContinue, nil +} + +func main() { + // create socket to listen on + socket, err := net.Listen("tcp4", "127.0.0.1:6785") + if err != nil { + log.Fatal(err) + } + defer socket.Close() + + // define the backend, required actions, protocol options and macros we want + server := milter.NewServer( + milter.WithMilter(func() milter.Milter { + return &ExampleBackend{} + }), + milter.WithProtocol(milter.OptNoConnect|milter.OptNoHelo|milter.OptNoMailFrom|milter.OptNoBody|milter.OptNoHeaders|milter.OptNoEOH|milter.OptNoUnknown|milter.OptNoData), + milter.WithAction(milter.OptChangeFrom|milter.OptAddRcpt|milter.OptRemoveRcpt), + milter.WithMaroRequest(milter.StageRcpt, []milter.MacroName{milter.MacroRcptMailer}), + ) + defer server.Close() + + // start the milter + var wgDone sync.WaitGroup + wgDone.Add(1) + go func(socket net.Listener) { + if err := server.Serve(socket); err != nil { + log.Fatal(err) + } + wgDone.Done() + }(socket) + + log.Printf("Started milter on %s:%s", socket.Addr().Network(), socket.Addr().String()) + + // quit when milter quits + wgDone.Wait() +} +``` + +See [![GoDoc](https://godoc.org/github.com/d--j/go-milter?status.svg)](https://godoc.org/github.com/d--j/go-milter) for more documentation and an example for a milter client. + ## License BSD 2-Clause + +## Credits + +Based on https://github.com/emersion/go-milter by [Simon Ser](https://github.com/emersion) which is based on https://github.com/phalaaxx/milter by +[Bozhin Zafirov](https://github.com/phalaaxx). [Max Mazurov](https://github.com/foxcpp) made major contributions to this code as well. diff --git a/client.go b/client.go index 1602d4e..92fa774 100644 --- a/client.go +++ b/client.go @@ -1,369 +1,691 @@ package milter import ( - "bytes" "encoding/binary" "fmt" "io" "net" - "strconv" + "strings" "time" + "github.com/d--j/go-milter/internal/wire" + "github.com/d--j/go-milter/milterutil" "github.com/emersion/go-message/textproto" ) -// Milter protocol version implemented by the client. -// -// Note: Not exported as we might want to support multiple versions -// transparently in the future. -const clientProtocolVersion uint32 = 6 +// MaxClientProtocolVersion is the maximum Milter protocol version implemented by the client. +const MaxClientProtocolVersion uint32 = 6 -// Client is a wrapper for managing milter connections. -// -// Currently it just creates new connections using provided Dialer. -type Client struct { - opts ClientOptions - network string - address string -} +const allClientSupportedProtocolMasks = OptNoConnect | OptNoHelo | OptNoMailFrom | OptNoRcptTo | OptNoBody | OptNoHeaders | OptNoEOH | OptNoUnknown | OptNoData | OptSkip | OptRcptRej | OptNoHeaderReply | OptNoConnReply | OptNoHeloReply | OptNoMailReply | OptNoRcptReply | OptNoDataReply | OptNoUnknownReply | OptNoEOHReply | OptNoBodyReply | OptHeaderLeadingSpace // SMFI_CURR_PROT +const allClientSupportedProtocolMasksV2 = OptNoConnect | OptNoHelo | OptNoMailFrom | OptNoRcptTo | OptNoBody | OptNoHeaders | OptNoEOH // SMFI_V2_PROT +const allClientSupportedProtocolMasksV3 = allClientSupportedProtocolMasksV2 | OptNoUnknown +const allClientSupportedProtocolMasksV4 = allClientSupportedProtocolMasksV3 | OptNoData + +const AllClientSupportedActionMasks = OptAddHeader | OptChangeBody | OptAddRcpt | OptRemoveRcpt | OptChangeHeader | OptQuarantine | OptChangeFrom | OptAddRcptWithArgs | OptSetMacros +const allClientSupportedActionMasksV2 = OptAddHeader | OptChangeBody | OptAddRcpt | OptRemoveRcpt | OptChangeHeader | OptQuarantine +// Dialer is the interface of the only method we use of a net.Dialer. type Dialer interface { Dial(network string, addr string) (net.Conn, error) } -type ClientOptions struct { - Dialer Dialer - ReadTimeout time.Duration - WriteTimeout time.Duration - ActionMask OptAction - ProtocolMask OptProtocol -} - -var defaultOptions = ClientOptions{ - Dialer: &net.Dialer{ - Timeout: 10 * time.Second, - }, - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - ActionMask: OptAddHeader | OptAddRcpt | OptChangeBody | OptChangeFrom | OptChangeHeader, - ProtocolMask: 0, +// Client is a wrapper for managing milter connections to one milter. +// +// You need to call Session to actually open a connection to the milter. +type Client struct { + options options + network string + address string } -// NewDefaultClient creates a new Client object using default options. +// NewClient creates a new Client object connection to a miter at network / address. +// If you do not specify any opts the defaults are: // // It uses 10 seconds for connection/read/write timeouts and allows milter to // send any actions supported by library. -func NewDefaultClient(network, address string) *Client { - return NewClientWithOptions(network, address, defaultOptions) -} - -// NewClientWithOptions creates a new Client object using provided options. // -// You generally want to use options to restrict ActionMask to what your code -// supports and ProtocolMask to what you intend to submit. +// You generally want to use WithAction to advertise to the milter what modification options your MTA supports. +// A value of 0 is a valid value –then your MTA only supports accepting or rejecting an SMTP transaction. +// +// If WithDialer is not used, a net.Dialer with 10 seconds connection timeout will be used. +// If WithMaximumVersion is not used, MaxClientProtocolVersion will be used. +// If WithProtocol or WithProtocols is not set, it defaults to all protocol features the library can handle for the specified maximum milter version. +// If WithOfferedMaxData is not used, DataSize64K will be used. +// If WithoutDefaultMacros or WithMacroRequest are not used the following default macro stages are used: // -// If opts.Dialer is not set, empty net.Dialer object will be used. -func NewClientWithOptions(network, address string, opts ClientOptions) *Client { - if opts.Dialer == nil { - opts.Dialer = &net.Dialer{} +// WithMacroRequest(StageConnect, []MacroName{MacroMTAFullyQualifiedDomainName, MacroDaemonName, MacroIfName, MacroIfAddr}) +// WithMacroRequest(StageHelo, []MacroName{MacroTlsVersion, MacroCipher, MacroCipherBits, MacroCertSubject, MacroCertIssuer}) +// WithMacroRequest(StageMail, []MacroName{MacroAuthType, MacroAuthAuthen, MacroAuthSsf, MacroAuthAuthor, MacroMailMailer, MacroMailHost, MacroMailAddr}) +// WithMacroRequest(StageRcpt, []MacroName{MacroRcptMailer, MacroRcptHost, MacroRcptAddr}) +// WithMacroRequest(StageEOM, []MacroName{MacroQueueId}) +// +// This function will panic when you provide invalid options. +func NewClient(network, address string, opts ...Option) *Client { + options := options{ + dialer: &net.Dialer{ + Timeout: 10 * time.Second, + }, + readTimeout: 10 * time.Second, + writeTimeout: 10 * time.Second, + maxVersion: MaxClientProtocolVersion, + actions: AllClientSupportedActionMasks, + protocol: allClientSupportedProtocolMasks, + offeredMaxData: DataSize64K, + usedMaxData: DataSize64K, + macrosByStage: [][]MacroName{ + {MacroMTAFullyQualifiedDomainName, MacroDaemonName, MacroIfName, MacroIfAddr}, // StageConnect + {MacroTlsVersion, MacroCipher, MacroCipherBits, MacroCertSubject, MacroCertIssuer}, // StageHelo + {MacroAuthType, MacroAuthAuthen, MacroAuthSsf, MacroAuthAuthor, MacroMailMailer, MacroMailHost, MacroMailAddr}, // StageMail + {MacroRcptMailer, MacroRcptHost, MacroRcptAddr}, // StageRcpt + {}, // StageData + {MacroQueueId}, // StageEOM + {}, // StageEOH + }, + } + if len(opts) > 0 { + for _, o := range opts { + if o != nil { + o(&options) + } + } + } + + if options.dialer == nil { + panic("milter: you cannot pass to WithDialer") + } + if options.maxVersion > MaxClientProtocolVersion || options.maxVersion == 1 { + panic("milter: this library cannot handle this milter version") + } + if options.offeredMaxData != DataSize64K && options.offeredMaxData != DataSize256K && options.offeredMaxData != DataSize1M { + panic("milter: wrong data size passed to WithOfferedMaxData") + } + // ensure we only offer protocol options the version can handel + if options.protocol != 0 { + var all OptProtocol + switch options.maxVersion { + case 2: + all = allClientSupportedProtocolMasksV2 + case 3: + all = allClientSupportedProtocolMasksV3 + case 4: + all = allClientSupportedProtocolMasksV4 + default: + all = allClientSupportedProtocolMasks + } + if options.protocol&^all != 0 { + panic(fmt.Sprintf("Provided invalid protocol options for milter version %d %032b", options.maxVersion, options.protocol)) + } + } + // offering nothing to filters is unlikely, just default to all we can handle + if options.protocol == 0 { + switch options.maxVersion { + case 2: + options.protocol = allClientSupportedProtocolMasksV2 + case 3: + options.protocol = allClientSupportedProtocolMasksV3 + case 4, 5: + options.protocol = allClientSupportedProtocolMasksV4 + default: + options.protocol = allClientSupportedProtocolMasks + } + } + if options.newMilter != nil { + panic("milter: WithMilter/WithDynamicMilter is a server only option") + } + if options.negotiationCallback != nil { + panic("milter: WithNegotiationCallback is a server only option") } return &Client{ - opts: opts, + options: options, network: network, address: address, } } -func (c *Client) Session() (*ClientSession, error) { - s := &ClientSession{ - readTimeout: c.opts.ReadTimeout, - writeTimeout: c.opts.WriteTimeout, - } - - // TODO(foxcpp): Connection pooling. +// String returns the network and address that his Client is configured to connect to. +// This method is go-routine save. +func (c *Client) String() string { + return fmt.Sprintf("%s:%s", c.network, c.address) +} - conn, err := c.opts.Dialer.Dial(c.network, c.address) +// Session opens a new connection to this milter and negotiates protocol features with it. +// +// The macros parameter defines the Macros this ClientSession will use to send to the milter. +// It can be nil then this session will not send any macros to the milter. +// Set macro values as soon as you know them (e.g. the MacroMTAFullyQualifiedDomainName macro can be set before calling Session). +// It is your responsibility to clear command specific macros like MacroRcptMailer after +// the command got executed (on all milters in a list of milters). +// +// This method is go-routine save. +func (c *Client) Session(macros Macros) (*ClientSession, error) { + conn, err := c.options.dialer.Dial(c.network, c.address) if err != nil { return nil, fmt.Errorf("milter: session create: %w", err) } + return c.session(conn, macros) +} + +func (c *Client) session(conn net.Conn, macros Macros) (*ClientSession, error) { + s := &ClientSession{ + readTimeout: c.options.readTimeout, + writeTimeout: c.options.writeTimeout, + state: clientStateClosed, + macros: macros, + macrosByStages: make([][]string, StageEndMarker), + maxBodySize: uint32(c.options.usedMaxData), + } + if c.options.macrosByStage != nil { + copy(s.macrosByStages, c.options.macrosByStage) + } + + s.state = clientStateNegotiated + s.conn = conn - if err := s.negotiate(c.opts.ActionMask, c.opts.ProtocolMask); err != nil { + if err := s.negotiate(c.options.maxVersion, c.options.actions, c.options.protocol, c.options.offeredMaxData); err != nil { return nil, err } return s, nil } -func (c *Client) Close() error { - // Reserved for use in connection pooling. - return nil -} +type clientSessionState uint32 + +const ( + clientStateClosed = iota + clientStateNegotiated + clientStateConnectCalled + clientStateHeloCalled + clientStateMailCalled + clientStateRcptCalled + clientStateDataCalled + clientStateHeaderFieldCalled + clientStateHeaderEndCalled + clientStateBodyChunkCalled + clientStateError +) +// ClientSession is a connection to one Client for one SMTP connection. type ClientSession struct { conn net.Conn + // negotiated version of this session + version uint32 + // Bitmask of negotiated action options. - ActionOpts OptAction + actionOpts OptAction // Bitmask of negotiated protocol options. - ProtocolOpts OptProtocol + protocolOpts OptProtocol - needAbort bool + maxBodySize uint32 + negotiatedBodySize uint32 + + state clientSessionState + skip bool + skipUnknown bool + closedErr error readTimeout time.Duration writeTimeout time.Duration + + macros Macros + macrosByStages [][]MacroName +} + +func (s *ClientSession) errorOut(err error) error { + s.state = clientStateError + // close the connection + if s.conn != nil { + _ = s.conn.Close() + } + // give garbage collector a chance to free space + s.macros = nil + s.macrosByStages = nil + return err } -// negotiate exchanges OPTNEG messages with the milter and sets s.mask to the -// negotiated value. -func (s *ClientSession) negotiate(actionMask OptAction, protoMask OptProtocol) error { +// negotiate exchanges OPTNEG messages with the milter and configures this session to the negotiated values. +func (s *ClientSession) negotiate(maximumVersion uint32, actionMask OptAction, protoMask OptProtocol, requestedMaxBuffer DataSize) error { // Send our mask, get mask from milter.. - msg := &Message{ - Code: byte(CodeOptNeg), // TODO(foxcpp): Get rid of casts by changing msg.Code to have Code type + msg := &wire.Message{ + Code: wire.CodeOptNeg, Data: make([]byte, 4*3), } - binary.BigEndian.PutUint32(msg.Data, clientProtocolVersion) + binary.BigEndian.PutUint32(msg.Data, maximumVersion) binary.BigEndian.PutUint32(msg.Data[4:], uint32(actionMask)) - binary.BigEndian.PutUint32(msg.Data[8:], uint32(protoMask)) + if requestedMaxBuffer == DataSize256K { + binary.BigEndian.PutUint32(msg.Data[8:], uint32(protoMask)|optMds256K) + } else if requestedMaxBuffer == DataSize1M { + binary.BigEndian.PutUint32(msg.Data[8:], uint32(protoMask)|optMds1M) + } else { + binary.BigEndian.PutUint32(msg.Data[8:], uint32(protoMask)) + } - if err := writePacket(s.conn, msg, s.writeTimeout); err != nil { - return fmt.Errorf("milter: negotiate: optneg write: %w", err) + if err := s.writePacket(msg); err != nil { + return s.errorOut(fmt.Errorf("milter: negotiate: optneg write: %w", err)) } - msg, err := readPacket(s.conn, s.readTimeout) + msg, err := wire.ReadPacket(s.conn, s.readTimeout) if err != nil { - return fmt.Errorf("milter: negotiate: optneg read: %w", err) + return s.errorOut(fmt.Errorf("milter: negotiate: optneg read: %w", err)) } - if Code(msg.Code) != CodeOptNeg { - return fmt.Errorf("milter: negotiate: unexpected code: %v", rune(msg.Code)) + if msg.Code != wire.CodeOptNeg { + return s.errorOut(fmt.Errorf("milter: negotiate: unexpected code: %v", rune(msg.Code))) } if len(msg.Data) < 4*3 /* version + action mask + proto mask */ { - return fmt.Errorf("milter: negotiate: unexpected data size: %v", len(msg.Data)) + return s.errorOut(fmt.Errorf("milter: negotiate: unexpected data size: %v", len(msg.Data))) } - milterVersion := binary.BigEndian.Uint32(msg.Data[:4]) + milterVersion := binary.BigEndian.Uint32(msg.Data[0:]) - // Not a strict comparison since we might be able to work correctly with - // milter using a newer protocol as long as masks negotiated are meaningful. - if milterVersion < clientProtocolVersion { - return fmt.Errorf("milter: negotiate: unsupported protocol version: %v", milterVersion) + if milterVersion < 2 || milterVersion > maximumVersion { + return s.errorOut(fmt.Errorf("milter: negotiate: unsupported protocol version: %v", milterVersion)) } - milterActionMask := binary.BigEndian.Uint32(msg.Data[4:]) - s.ActionOpts = OptAction(milterActionMask) - milterProtoMask := binary.BigEndian.Uint32(msg.Data[8:]) - s.ProtocolOpts = OptProtocol(milterProtoMask) + s.version = milterVersion + + milterActionMask := OptAction(binary.BigEndian.Uint32(msg.Data[4:])) + if milterActionMask&actionMask != milterActionMask { + return s.errorOut(fmt.Errorf("milter: negotiate: unsupported actions requested: MTA %032b filter %032b", actionMask, milterActionMask)) + } + s.actionOpts = milterActionMask + milterProtoMask := OptProtocol(binary.BigEndian.Uint32(msg.Data[8:])) + + if uint32(milterProtoMask)&optMds1M == optMds1M { + s.negotiatedBodySize = uint32(DataSize1M) + } else if uint32(milterProtoMask)&optMds256K == optMds256K { + s.negotiatedBodySize = uint32(DataSize256K) + } else { + s.negotiatedBodySize = uint32(DataSize64K) + } + + // mask out the size flags + milterProtoMask = milterProtoMask & (^OptProtocol(optInternal)) + if milterProtoMask&protoMask != milterProtoMask { + return s.errorOut(fmt.Errorf("milter: negotiate: unsupported protocol options requested: MTA %032b filter %032b", protoMask, milterProtoMask)) + } + + // do not send commands that older versions do not understand + if milterVersion <= 2 { + milterProtoMask = milterProtoMask | OptNoUnknown + } + if milterVersion <= 3 { + milterProtoMask = milterProtoMask | OptNoData + } - s.needAbort = true + s.protocolOpts = milterProtoMask + + s.state = clientStateNegotiated + + // The filter defined macros it wants to get we only use them and not the defaults + if len(msg.Data) > 4*4 { + s.macrosByStages = make([][]string, StageEndMarker) + l := len(msg.Data) + offset := 4 * 3 + for l > offset+4 { + stage := binary.BigEndian.Uint32(msg.Data[offset:]) + offset += 4 + requestedMacros := wire.ReadCString(msg.Data[offset:]) + offset += len(requestedMacros) + if l <= offset || msg.Data[offset] != 0 { + LogWarning("macros for stage %d are not null-terminated, skipping rest of list: %s", stage, requestedMacros) + break + } + offset += 1 // skip null byte + if stage < uint32(StageConnect) || stage >= uint32(StageEndMarker) { + LogWarning("got request for unknown stage %d, ignoring this entry", stage) + continue + } + if s.macrosByStages[MacroStage(stage)] != nil { + LogWarning("macros for stage %d were send multiple times: %q is overwriting %q", stage, requestedMacros, strings.Join(s.macrosByStages[MacroStage(stage)], " ")) + } + s.macrosByStages[MacroStage(stage)] = parseRequestedMacros(requestedMacros) + } + } + for i := range s.macrosByStages { + if s.macrosByStages[i] != nil { + s.macrosByStages[i] = removeDuplicates(s.macrosByStages[i]) + } + } return nil } -// ProtocolOption checks whether the option is set in negotiated options, that -// is, requested by both sides. +// ProtocolOption checks whether the option is set in negotiated options. func (s *ClientSession) ProtocolOption(opt OptProtocol) bool { - return s.ProtocolOpts&opt != 0 + return s.protocolOpts&opt != 0 } -// ActionOption checks whether the option is set in negotiated options, that -// is, requested by both sides. +// ActionOption checks whether the option is set in negotiated options. func (s *ClientSession) ActionOption(opt OptAction) bool { - return s.ActionOpts&opt != 0 + return s.actionOpts&opt != 0 } -func (s *ClientSession) Macros(code Code, kv ...string) error { - // Note: kv is ...string with the expectation that the list of macro names - // will be static and not dynamically constructed. - - msg := &Message{ - Code: byte(CodeMacro), +func (s *ClientSession) sendMacros(code wire.Code, names []MacroName) error { + if s.macros == nil { + return nil + } + msg := &wire.Message{ + Code: wire.CodeMacro, Data: []byte{byte(code)}, } - for _, str := range kv { - msg.Data = appendCString(msg.Data, str) + foundMacro := false + for _, name := range names { + // only send macros we actually defined + if val, ok := s.macros.GetEx(name); ok { + foundMacro = true + msg.Data = wire.AppendCString(msg.Data, name) + msg.Data = wire.AppendCString(msg.Data, val) + } + } + // no need to send anything when we have not found a single macro + if !foundMacro { + return nil } - if err := writePacket(s.conn, msg, s.writeTimeout); err != nil { - return fmt.Errorf("milter: macros: %w", err) + if err := s.writePacket(msg); err != nil { + return fmt.Errorf("milter: sendMacros: %w", err) } return nil } -func appendUint16(dest []byte, val uint16) []byte { - dest = append(dest, 0x00, 0x00) - binary.BigEndian.PutUint16(dest[len(dest)-2:], val) - return dest -} +func (s *ClientSession) sendCmdMacros(code wire.Code, macros map[MacroName]string) error { + if len(macros) == 0 { + return nil + } + msg := &wire.Message{ + Code: wire.CodeMacro, + Data: []byte{byte(code)}, + } + for name, val := range macros { + msg.Data = wire.AppendCString(msg.Data, name) + msg.Data = wire.AppendCString(msg.Data, val) + + } -type Action struct { - Code ActionCode + if err := s.writePacket(msg); err != nil { + return fmt.Errorf("milter: sendMacros: %w", err) + } - // SMTP code if Code == ActReplyCode. - SMTPCode int - // Reply text if Code == ActReplyCode. - SMTPText string + return nil } -func (s *ClientSession) readAction() (*Action, error) { +func (s *ClientSession) readAction(skipOk bool) (*Action, error) { for { - msg, err := readPacket(s.conn, s.readTimeout) + msg, err := wire.ReadPacket(s.conn, s.readTimeout) if err != nil { - return nil, fmt.Errorf("action read: %w", err) + return nil, s.errorOut(fmt.Errorf("action read: %w", err)) } - if msg.Code == 'p' /* progress */ { + if wire.ActionCode(msg.Code) == wire.ActProgress /* progress */ { continue } - if ActionCode(msg.Code) != ActContinue { - s.needAbort = false - } - - return parseAction(msg) - } -} - -func parseAction(msg *Message) (*Action, error) { - act := &Action{ - Code: ActionCode(msg.Code), - } - var err error - switch ActionCode(msg.Code) { - case ActAccept, ActContinue, ActDiscard, ActReject, ActTempFail: - case ActReplyCode: - if len(msg.Data) <= 4 { - return nil, fmt.Errorf("action read: unexpected data length: %v", len(msg.Data)) - } - act.SMTPCode, err = strconv.Atoi(string(msg.Data[:3])) + act, err := parseAction(msg) if err != nil { - return nil, fmt.Errorf("action read: malformed SMTP code: %v", msg.Data[:3]) + return nil, err } - // There is 0x20 (' ') in between. - act.SMTPText = readCString(msg.Data[4:]) - default: - return nil, fmt.Errorf("action read: unexpected code: %v", msg.Code) + switch act.Type { + case ActionSkip: + if !skipOk { + return nil, fmt.Errorf("action read: unexpected skip message received (can only be received after SMFIC_BODY when SMFIP_SKIP was negotiated)") + } + case ActionReject: + act.SMTPCode = 550 + act.SMTPReply = "550 5.7.1 Command rejected" + case ActionTempFail: + act.SMTPCode = 451 + act.SMTPReply = "451 4.7.1 Service unavailable - try again later" + } + + return act, err } +} - return act, nil +func (s *ClientSession) writePacket(msg *wire.Message) error { + return wire.WritePacket(s.conn, msg, s.writeTimeout) } // Conn sends the connection information to the milter. // // It should be called once per milter session (from Session to Close). +// Exception: After you called Reset you need to call Conn again. func (s *ClientSession) Conn(hostname string, family ProtoFamily, port uint16, addr string) (*Action, error) { - if s.ProtocolOpts&OptNoConnect != 0 { - return &Action{Code: ActContinue}, nil + if s.state != clientStateNegotiated { + return nil, s.errorOut(fmt.Errorf("milter: in wrong state %d", s.state)) + } + + s.skip = false + s.state = clientStateConnectCalled + + if len(s.macrosByStages) > int(StageConnect) && len(s.macrosByStages[StageConnect]) > 0 { + if err := s.sendMacros(wire.CodeConn, s.macrosByStages[StageConnect]); err != nil { + return nil, err + } + } + + if s.ProtocolOption(OptNoConnect) { + return &Action{Type: ActionContinue}, nil } - msg := &Message{ - Code: byte(CodeConn), + msg := &wire.Message{ + Code: wire.CodeConn, } - msg.Data = appendCString(msg.Data, hostname) + msg.Data = wire.AppendCString(msg.Data, hostname) msg.Data = append(msg.Data, byte(family)) if family != FamilyUnknown { if family == FamilyInet || family == FamilyInet6 { - msg.Data = appendUint16(msg.Data, port) + msg.Data = wire.AppendUint16(msg.Data, port) + } else if family == FamilyUnix { + msg.Data = wire.AppendUint16(msg.Data, 0) } - msg.Data = appendCString(msg.Data, addr) + msg.Data = wire.AppendCString(msg.Data, addr) } - if err := writePacket(s.conn, msg, s.writeTimeout); err != nil { - return nil, fmt.Errorf("milter: conn: %w", err) + if err := s.writePacket(msg); err != nil { + return nil, s.errorOut(fmt.Errorf("milter: conn: %w", err)) } - if !s.ProtocolOption(OptNoConnReply) { - act, err := s.readAction() - if err != nil { - return nil, fmt.Errorf("milter: conn: %w", err) - } - return act, nil + if s.ProtocolOption(OptNoConnReply) { + return &Action{Type: ActionContinue}, nil } - return &Action{Code: ActContinue}, nil + + act, err := s.readAction(false) + if err != nil { + return nil, s.errorOut(fmt.Errorf("milter: conn: %w", err)) + } + + if act.Type == ActionDiscard { + LogWarning("Connect got a discard action, ignoring it") + act.Type = ActionContinue + } + + return act, nil } // Helo sends the HELO hostname to the milter. // -// It should be called once per milter session (from Session to Close). +// It should be called once per milter session (from Client.Session to Close). func (s *ClientSession) Helo(helo string) (*Action, error) { + if s.state != clientStateConnectCalled { + return nil, s.errorOut(fmt.Errorf("milter: in wrong state %d", s.state)) + } + + s.skip = false + s.state = clientStateHeloCalled + + if len(s.macrosByStages) > int(StageHelo) && len(s.macrosByStages[StageHelo]) > 0 { + if err := s.sendMacros(wire.CodeHelo, s.macrosByStages[StageHelo]); err != nil { + return nil, s.errorOut(err) + } + } + // Synthesise response as if server replied "go on" while in fact it does - // not support that message. - if s.ProtocolOpts&OptNoHelo != 0 { - return &Action{Code: ActContinue}, nil + // not want or support that message. + if s.ProtocolOption(OptNoHelo) { + return &Action{Type: ActionContinue}, nil } - msg := &Message{ - Code: byte(CodeHelo), - Data: appendCString(nil, helo), + msg := &wire.Message{ + Code: wire.CodeHelo, + Data: wire.AppendCString(nil, helo), } - if err := writePacket(s.conn, msg, s.writeTimeout); err != nil { - return nil, fmt.Errorf("milter: helo: %w", err) + if err := s.writePacket(msg); err != nil { + return nil, s.errorOut(fmt.Errorf("milter: helo: %w", err)) } - if !s.ProtocolOption(OptNoHeloReply) { - act, err := s.readAction() - if err != nil { - return nil, fmt.Errorf("milter: helo: %w", err) - } - return act, nil + if s.ProtocolOption(OptNoHeloReply) { + return &Action{Type: ActionContinue}, nil } - return &Action{Code: ActContinue}, nil + + act, err := s.readAction(false) + if err != nil { + return nil, s.errorOut(fmt.Errorf("milter: helo: %w", err)) + } + + if act.Type == ActionDiscard { + LogWarning("Helo got a discard action, ignoring it") + act.Type = ActionContinue + } + + return act, nil } -func (s *ClientSession) Mail(sender string, esmtpArgs []string) (*Action, error) { - if s.ProtocolOpts&OptNoMailFrom != 0 { - return &Action{Code: ActContinue}, nil +// Mail sends the sender (with optional esmtpArgs) to the milter. +func (s *ClientSession) Mail(sender string, esmtpArgs string) (*Action, error) { + if s.state != clientStateHeloCalled { + return nil, s.errorOut(fmt.Errorf("milter: in wrong state %d", s.state)) } - msg := &Message{ - Code: byte(CodeMail), + s.skip = false + s.state = clientStateMailCalled + + if len(s.macrosByStages) > int(StageMail) && len(s.macrosByStages[StageMail]) > 0 { + if err := s.sendMacros(wire.CodeMail, s.macrosByStages[StageMail]); err != nil { + return nil, s.errorOut(err) + } } - msg.Data = appendCString(msg.Data, "<"+sender+">") - for _, arg := range esmtpArgs { - msg.Data = appendCString(msg.Data, arg) + if s.ProtocolOption(OptNoMailFrom) { + return &Action{Type: ActionContinue}, nil } - if err := writePacket(s.conn, msg, s.writeTimeout); err != nil { - return nil, fmt.Errorf("milter: mail: %w", err) + msg := &wire.Message{ + Code: wire.CodeMail, } - if !s.ProtocolOption(OptNoMailReply) { - act, err := s.readAction() - if err != nil { - return nil, fmt.Errorf("milter: mail: %w", err) - } - return act, nil + msg.Data = wire.AppendCString(msg.Data, "<"+sender+">") + if len(esmtpArgs) > 0 { + msg.Data = wire.AppendCString(msg.Data, esmtpArgs) + } + + if err := s.writePacket(msg); err != nil { + return nil, s.errorOut(fmt.Errorf("milter: mail: %w", err)) } - return &Action{Code: ActContinue}, nil + + if s.ProtocolOption(OptNoMailReply) { + return &Action{Type: ActionContinue}, nil + } + + act, err := s.readAction(false) + if err != nil { + return nil, s.errorOut(fmt.Errorf("milter: mail: %w", err)) + } + return act, nil } -func (s *ClientSession) Rcpt(rcpt string, esmtpArgs []string) (*Action, error) { - if s.ProtocolOpts&OptNoRcptTo != 0 { - return &Action{Code: ActContinue}, nil +// Rcpt sends the RCPT TO rcpt (with optional esmtpArgs) to the milter. +// If s.ProtocolOption(OptRcptRej) is true the milter wants rejected recipients. +// The default is to only send valid recipients to the milter. +func (s *ClientSession) Rcpt(rcpt string, esmtpArgs string) (*Action, error) { + if s.state != clientStateMailCalled && s.state != clientStateRcptCalled { + return nil, s.errorOut(fmt.Errorf("milter: in wrong state %d", s.state)) } + if s.skip { + return &Action{Type: ActionContinue}, nil + } + + s.state = clientStateRcptCalled - msg := &Message{ - Code: byte(CodeRcpt), + if len(s.macrosByStages) > int(StageRcpt) && len(s.macrosByStages[StageRcpt]) > 0 { + if err := s.sendMacros(wire.CodeRcpt, s.macrosByStages[StageRcpt]); err != nil { + return nil, s.errorOut(err) + } } - msg.Data = appendCString(msg.Data, "<"+rcpt+">") - for _, arg := range esmtpArgs { - msg.Data = appendCString(msg.Data, arg) + if s.ProtocolOption(OptNoRcptTo) { + return &Action{Type: ActionContinue}, nil } - if err := writePacket(s.conn, msg, s.writeTimeout); err != nil { - return nil, fmt.Errorf("milter: rcpt: %w", err) + msg := &wire.Message{ + Code: wire.CodeRcpt, } - if !s.ProtocolOption(OptNoRcptReply) { - act, err := s.readAction() - if err != nil { - return nil, fmt.Errorf("milter: rcpt: %w", err) + msg.Data = wire.AppendCString(msg.Data, "<"+rcpt+">") + if len(esmtpArgs) > 0 { + msg.Data = wire.AppendCString(msg.Data, esmtpArgs) + } + + if err := s.writePacket(msg); err != nil { + return nil, s.errorOut(fmt.Errorf("milter: rcpt: %w", err)) + } + + if s.ProtocolOption(OptNoRcptReply) { + return &Action{Type: ActionContinue}, nil + } + + act, err := s.readAction(s.ProtocolOption(OptSkip)) + if err != nil { + return nil, s.errorOut(fmt.Errorf("milter: rcpt: %w", err)) + } + if act.Type == ActionSkip { + s.skip = true + return &Action{Type: ActionContinue}, nil + } + return act, nil +} + +// DataStart sends the start of the DATA command to the milter. +// DataStart can be automatically called from Header, but you should normally call it explicitly. +// +// When your MTA can handle multiple milter in a chain, DataStart is the last event that is called individually for each milter in the chain. +// After DataStart you need to call the HeaderField/Header and BodyChunk&End/BodyReadFrom calls for the whole message serially to each milter. +// The first milter may alter the message and the next milter should receive the altered message, not the original message. +func (s *ClientSession) DataStart() (*Action, error) { + if s.state != clientStateRcptCalled { + return nil, s.errorOut(fmt.Errorf("milter: in wrong state %d", s.state)) + } + s.skip = false + s.state = clientStateDataCalled + + if s.version > 3 && len(s.macrosByStages) > int(StageData) && len(s.macrosByStages[StageData]) > 0 { + if err := s.sendMacros(wire.CodeData, s.macrosByStages[StageData]); err != nil { + return nil, s.errorOut(err) } - return act, nil } - return &Action{Code: ActContinue}, nil + + if s.ProtocolOption(OptNoData) { + return &Action{Type: ActionContinue}, nil + } + + msg := &wire.Message{ + Code: wire.CodeData, + } + + if err := s.writePacket(msg); err != nil { + return nil, s.errorOut(fmt.Errorf("milter: rcpt: %w", err)) + } + + if s.ProtocolOption(OptNoDataReply) { + return &Action{Type: ActionContinue}, nil + } + + act, err := s.readAction(false) + if err != nil { + return nil, s.errorOut(fmt.Errorf("milter: rcpt: %w", err)) + } + return act, nil } // HeaderField sends a single header field to the milter. @@ -371,66 +693,113 @@ func (s *ClientSession) Rcpt(rcpt string, esmtpArgs []string) (*Action, error) { // Value should be the original field value without any unfolding applied. // // HeaderEnd() must be called after the last field. -func (s *ClientSession) HeaderField(key, value string) (*Action, error) { - if s.ProtocolOpts&OptNoHeaders != 0 { - return &Action{Code: ActContinue}, nil +// +// You can send macros to the milter with macros. They only get send to the milter when it wants header values and it did not send a skip response. +// Thus, the macros you send here should be relevant to this header only. +func (s *ClientSession) HeaderField(key, value string, macros map[MacroName]string) (*Action, error) { + if s.state > clientStateHeaderFieldCalled || s.state < clientStateDataCalled { + return nil, s.errorOut(fmt.Errorf("milter: in wrong state %d", s.state)) } + if s.skip { + return &Action{Type: ActionContinue}, nil + } + + s.state = clientStateHeaderFieldCalled - msg := &Message{ - Code: byte(CodeHeader), + if s.ProtocolOption(OptNoHeaders) { + return &Action{Type: ActionContinue}, nil } - msg.Data = appendCString(msg.Data, key) - msg.Data = appendCString(msg.Data, value) - if err := writePacket(s.conn, msg, s.writeTimeout); err != nil { - return nil, fmt.Errorf("milter: header field: %w", err) + if err := s.sendCmdMacros(wire.CodeHeader, macros); err != nil { + return nil, s.errorOut(err) } - if !s.ProtocolOption(OptNoHeaderReply) { - act, err := s.readAction() - if err != nil { - return nil, fmt.Errorf("milter: header field: %w", err) - } - return act, nil + msg := &wire.Message{ + Code: wire.CodeHeader, + } + msg.Data = wire.AppendCString(msg.Data, key) + msg.Data = wire.AppendCString(msg.Data, value) + + if err := s.writePacket(msg); err != nil { + return nil, s.errorOut(fmt.Errorf("milter: header field: %w", err)) } - return &Action{Code: ActContinue}, nil + + if s.ProtocolOption(OptNoHeaderReply) { + return &Action{Type: ActionContinue}, nil + } + + act, err := s.readAction(s.ProtocolOption(OptSkip)) + if err != nil { + return nil, s.errorOut(fmt.Errorf("milter: header field: %w", err)) + } + if act.Type == ActionSkip { + s.skip = true + return &Action{Type: ActionContinue}, nil + } + return act, nil } // HeaderEnd send the EOH (End-Of-Header) message to the milter. // // No HeaderField calls are allowed after this point. func (s *ClientSession) HeaderEnd() (*Action, error) { - if s.ProtocolOpts&OptNoEOH != 0 { - return &Action{Code: ActContinue}, nil + if s.state > clientStateHeaderFieldCalled || s.state < clientStateDataCalled { + return nil, s.errorOut(fmt.Errorf("milter: in wrong state %d", s.state)) } + s.skip = false + s.state = clientStateHeaderEndCalled - if err := writePacket(s.conn, &Message{ - Code: byte(CodeEOH), - }, s.writeTimeout); err != nil { - return nil, fmt.Errorf("milter: header end: %w", err) + if len(s.macrosByStages) > int(StageEOH) && len(s.macrosByStages[StageEOH]) > 0 { + if err := s.sendMacros(wire.CodeEOH, s.macrosByStages[StageEOH]); err != nil { + return nil, s.errorOut(err) + } } - if !s.ProtocolOption(OptNoEOHReply) { - act, err := s.readAction() - if err != nil { - return nil, fmt.Errorf("milter: header end: %w", err) - } - return act, nil + if s.ProtocolOption(OptNoEOH) { + return &Action{Type: ActionContinue}, nil + } + + if err := s.writePacket(&wire.Message{ + Code: wire.CodeEOH, + }); err != nil { + return nil, s.errorOut(fmt.Errorf("milter: header end: %w", err)) + } + + if s.ProtocolOption(OptNoEOHReply) { + return &Action{Type: ActionContinue}, nil } - return &Action{Code: ActContinue}, nil + + act, err := s.readAction(false) + if err != nil { + return nil, s.errorOut(fmt.Errorf("milter: header end: %w", err)) + } + return act, nil } // Header sends each field from textproto.Header followed by EOH unless // header messages are disabled during negotiation. +// +// You may call HeaderField before calling this method but since it calls HeaderEnd afterwards +// you should call BodyChunk or BodyReadFrom. func (s *ClientSession) Header(hdr textproto.Header) (*Action, error) { - for f := hdr.Fields(); f.Next(); { - act, err := s.HeaderField(f.Key(), f.Value()) - if err != nil { - return nil, err + if s.state < clientStateRcptCalled || s.state > clientStateHeaderFieldCalled { + return nil, s.errorOut(fmt.Errorf("milter: in wrong state %d", s.state)) + } + if s.state == clientStateRcptCalled { + act, err := s.DataStart() + if err != nil || act.Type != ActionContinue { + return act, err } - - if act.Code != ActContinue { - return act, nil + } + if !s.ProtocolOption(OptNoHeaders) || s.skip { + for f := hdr.Fields(); f.Next(); { + act, err := s.HeaderField(f.Key(), f.Value(), nil) + if err != nil || (act.Type != ActionContinue) { + return act, err + } + if s.skip { + break + } } } @@ -440,164 +809,108 @@ func (s *ClientSession) Header(hdr textproto.Header) (*Action, error) { // BodyChunk sends a single body chunk to the milter. // // It is callers responsibility to ensure every chunk is not bigger than -// MaxBodyChunk. +// defined in WithUsedMaxData. // -// If OptSkip was specified during negotiation, caller should be ready to -// handle return ActSkip and stop sending body chunks if it is returned. +// BodyChunk can be called even after the milter responded with ActSkip. +// This method translates a ActSkip milter response into a ActContinue response +// but after a successful ActSkip response Skip will return true. func (s *ClientSession) BodyChunk(chunk []byte) (*Action, error) { - if s.ProtocolOpts&OptNoBody != 0 { - return &Action{Code: ActContinue}, nil + if s.state < clientStateHeaderEndCalled || s.state > clientStateBodyChunkCalled { + return nil, s.errorOut(fmt.Errorf("milter: body: in wrong state %d", s.state)) + } + if s.skip { + return &Action{Type: ActionContinue}, nil + } + + s.state = clientStateBodyChunkCalled + + if s.ProtocolOption(OptNoBody) { + return &Action{Type: ActionContinue}, nil } - // Callers tend to be irresponsible... /s - if len(chunk) > MaxBodyChunk { - return nil, fmt.Errorf("milter: body chunk: too big body chunk: %v", len(chunk)) + if len(chunk) > int(s.maxBodySize) { + return nil, s.errorOut(fmt.Errorf("milter: body: too big body chunk: %d > %d", len(chunk), s.maxBodySize)) } - if err := writePacket(s.conn, &Message{ - Code: byte(CodeBody), + if err := s.writePacket(&wire.Message{ + Code: wire.CodeBody, Data: chunk, - }, s.writeTimeout); err != nil { - return nil, fmt.Errorf("milter: body chunk: %w", err) + }); err != nil { + return nil, s.errorOut(fmt.Errorf("milter: body chunk: %w", err)) } - if !s.ProtocolOption(OptNoBodyReply) { - act, err := s.readAction() - if err != nil { - return nil, fmt.Errorf("milter: body chunk: %w", err) - } - return act, nil + if s.ProtocolOption(OptNoBodyReply) { + return &Action{Type: ActionContinue}, nil + } + + act, err := s.readAction(s.ProtocolOption(OptSkip)) + if err != nil { + return nil, s.errorOut(fmt.Errorf("milter: body chunk: %w", err)) + } + if act.Type == ActionSkip { + s.skip = true + return &Action{Type: ActionContinue}, nil } - return &Action{Code: ActContinue}, nil + return act, nil } -// BodyReadFrom is a helper function that calls BodyChunk repeately to transmit entire +// BodyReadFrom is a helper function that calls BodyChunk repeatedly to transmit entire // body from io.Reader and then calls End. // // See documentation for these functions for details. +// +// You may first call BodyChunk and then call BodyReadFrom but after BodyReadFrom the End method gets +// called automatically. func (s *ClientSession) BodyReadFrom(r io.Reader) ([]ModifyAction, *Action, error) { - // It is problematic to use io.WriteCloser since we may need to report - // action after each write. - - buf := make([]byte, MaxBodyChunk) - for { - n, err := r.Read(buf) - if err != nil { - if err == io.EOF { + if s.state < clientStateHeaderEndCalled || s.state > clientStateBodyChunkCalled { + return nil, nil, s.errorOut(fmt.Errorf("milter: body: in wrong state %d", s.state)) + } + if !s.ProtocolOption(OptNoBody) && !s.skip { + scanner := milterutil.GetFixedBufferScanner(s.maxBodySize, r) + defer scanner.Close() + for scanner.Scan() { + act, err := s.BodyChunk(scanner.Bytes()) + if err != nil { + return nil, nil, err + } + if s.skip { break } - return nil, nil, err - } - if n == 0 { - break - } - - act, err := s.BodyChunk(buf[:n]) - if err != nil { - return nil, nil, err - } - if act.Code == ActSkip { - break + if act.Type != ActionContinue { + if scanner.Err() != nil { + return nil, nil, scanner.Err() + } + return nil, act, nil + } } - if act.Code != ActContinue { - return nil, act, nil + if scanner.Err() != nil { + return nil, nil, scanner.Err() } } return s.End() } -type ModifyAction struct { - Code ModifyActCode - - // Recipient to add/remove if Code == ActAddRcpt or ActDelRcpt. - Rcpt string - - // New envelope sender if Code = ActChangeFrom. - From string - - // ESMTP arguments for envelope sender if Code = ActChangeFrom. - FromArgs []string - - // Portion of body to be replaced if Code == ActReplBody. - Body []byte - - // Index of the header field to be changed if Code = ActChangeHeader or Code = ActInsertHeader. - // Index is 1-based and is per value of HdrName. - // E.g. HeaderIndex = 3 and HdrName = "DKIM-Signature" mean "change third - // DKIM-Signature field". Order is the same as of HeaderField calls. - HeaderIndex uint32 - - // Header field name to be added/changed if Code == ActAddHeader or - // ActChangeHeader or ActInsertHeader. - HeaderName string - - // Header field value to be added/changed if Code == ActAddHeader or - // ActChangeHeader or ActInsertHeader. If set to empty string - the field - // should be removed. - HeaderValue string - - // Quarantine reason if Code == ActQuarantine. - Reason string -} - -func parseModifyAct(msg *Message) (*ModifyAction, error) { - act := &ModifyAction{ - Code: ModifyActCode(msg.Code), - } - - switch ModifyActCode(msg.Code) { - case ActAddRcpt, ActDelRcpt: - act.Rcpt = readCString(msg.Data) - case ActQuarantine: - act.Reason = readCString(msg.Data) - case ActReplBody: - act.Body = msg.Data - case ActChangeFrom: - argv := bytes.Split(msg.Data, []byte{0x00}) - act.From = string(argv[0]) - for _, arg := range argv[1:] { - act.FromArgs = append(act.FromArgs, string(arg)) - } - case ActChangeHeader, ActInsertHeader: - if len(msg.Data) < 4 { - return nil, fmt.Errorf("read modify action: missing header index") - } - act.HeaderIndex = binary.BigEndian.Uint32(msg.Data) - - msg.Data = msg.Data[4:] - fallthrough - case ActAddHeader: - // TODO: Change readCString to return last index. - act.HeaderName = readCString(msg.Data) - nul := bytes.IndexByte(msg.Data, 0x00) - if nul == -1 { - return nil, fmt.Errorf("read modify action: missing NUL delimiter") - } - if nul == len(msg.Data) { - return nil, fmt.Errorf("read modify action: missing header value") - } - act.HeaderValue = readCString(msg.Data[nul+1:]) - default: - return nil, fmt.Errorf("read modify action: unexpected message code: %v", msg.Code) - } - - return act, nil +// Skip can be used after a BodyChunk, HeaderField or Rcpt call to check if the milter indicated to not need any more +// of these events. You can directly skip to the next event class. It is not an error to ignore this +// and just keep sending the same events since ClientSession will handle skipping internally. +func (s *ClientSession) Skip() bool { + return s.skip } func (s *ClientSession) readModifyActs() (modifyActs []ModifyAction, act *Action, err error) { for { - msg, err := readPacket(s.conn, s.readTimeout) + msg, err := wire.ReadPacket(s.conn, s.readTimeout) if err != nil { return nil, nil, fmt.Errorf("action read: %w", err) } - if msg.Code == 'p' /* progress */ { + if msg.Code == wire.Code(wire.ActProgress) /* progress */ { continue } - switch ModifyActCode(msg.Code) { - case ActAddRcpt, ActDelRcpt, ActReplBody, ActChangeHeader, ActInsertHeader, - ActAddHeader, ActChangeFrom, ActQuarantine: + switch wire.ModifyActCode(msg.Code) { + case wire.ActAddRcpt, wire.ActDelRcpt, wire.ActReplBody, wire.ActChangeHeader, wire.ActInsertHeader, + wire.ActAddHeader, wire.ActChangeFrom, wire.ActQuarantine, wire.ActAddRcptPar: modifyAct, err := parseModifyAct(msg) if err != nil { return nil, nil, err @@ -620,42 +933,133 @@ func (s *ClientSession) readModifyActs() (modifyActs []ModifyAction, act *Action // // Close should be called to conclude session. func (s *ClientSession) End() ([]ModifyAction, *Action, error) { - if err := writePacket(s.conn, &Message{ - Code: byte(CodeEOB), - }, s.writeTimeout); err != nil { - return nil, nil, fmt.Errorf("milter: end: %w", err) + if s.state != clientStateBodyChunkCalled { + return nil, nil, s.errorOut(fmt.Errorf("milter: end: in wrong state %d", s.state)) + } + s.state = clientStateHeloCalled + s.skip = false + s.skipUnknown = false + if len(s.macrosByStages) > int(StageEOM) && len(s.macrosByStages[StageEOM]) > 0 { + if err := s.sendMacros(wire.CodeEOB, s.macrosByStages[StageEOM]); err != nil { + return nil, nil, s.errorOut(err) + } + } + if err := s.writePacket(&wire.Message{ + Code: wire.CodeEOB, + }); err != nil { + return nil, nil, s.errorOut(fmt.Errorf("milter: end: %w", err)) } modifyActs, act, err := s.readModifyActs() if err != nil { - return nil, nil, fmt.Errorf("milter: end: %w", err) + return nil, nil, s.errorOut(fmt.Errorf("milter: end: %w", err)) } return modifyActs, act, nil } -// Abort sends Abort to the milter. +// Unknown sends an unknown command to the milter. This can happen at any time in the connection. +// Although you should probably do not call it after DataStart until End was called. // -// This is called for an unexpected end to an email outside the milters -// control. -func (s *ClientSession) Abort() error { - return writePacket(s.conn, &Message{ - Code: byte(CodeAbort), - }, s.writeTimeout) +// You can send macros to the milter with macros. They only get send to the milter when it wants unknown commands. +func (s *ClientSession) Unknown(cmd string, macros map[MacroName]string) (*Action, error) { + if s.state < clientStateNegotiated || s.state == clientStateError { + return nil, s.errorOut(fmt.Errorf("milter: unknown: in wrong state %d", s.state)) + } + + if s.ProtocolOption(OptNoUnknown) || s.skipUnknown { + return &Action{Type: ActionContinue}, nil + } + + if err := s.sendCmdMacros(wire.CodeUnknown, macros); err != nil { + return nil, s.errorOut(err) + } + + msg := &wire.Message{ + Code: wire.CodeUnknown, + } + msg.Data = wire.AppendCString(msg.Data, cmd) + + if err := s.writePacket(msg); err != nil { + return nil, s.errorOut(fmt.Errorf("milter: unknown: %w", err)) + } + + if s.ProtocolOption(OptNoUnknownReply) { + return &Action{Type: ActionContinue}, nil + } + + act, err := s.readAction(false) + if err != nil { + return nil, s.errorOut(fmt.Errorf("milter: unknown: %w", err)) + } + return act, nil } -// Close releases resources associated with the session. +// Abort sends Abort to the milter. You can call Mail in this same session after a successful call to Abort. // -// If there a milter sequence in progress - it is aborted. -func (s *ClientSession) Close() error { - if s.needAbort { - _ = s.Abort() +// This should be called for a premature but valid end of the SMTP session. +// That is when the SMTP client issues a RSET or QUIT command after at least Helo was called. +// +// You can send macros to the milter with macros. They only get send to the milter when it wants unknown commands. +func (s *ClientSession) Abort(macros map[MacroName]string) error { + if s.state == clientStateError || s.state < clientStateHeloCalled { + return s.errorOut(fmt.Errorf("milter: abort: in wrong state %d", s.state)) + } + s.state = clientStateHeloCalled + s.skip = false + s.skipUnknown = false + if err := s.sendCmdMacros(wire.CodeHeader, macros); err != nil { + return s.errorOut(err) + } + if err := s.writePacket(&wire.Message{ + Code: wire.CodeAbort, + }); err != nil { + return s.errorOut(err) } - if err := writePacket(s.conn, &Message{ - Code: byte(CodeQuit), - }, s.writeTimeout); err != nil { - return fmt.Errorf("milter: close: %w", err) + return nil +} + +// Reset sends CodeQuitNewConn to the milter so this session can be used for another connection. +// +// You can use this to do connection pooling - but that could be quite flaky +// since not all milters can handle CodeQuitNewConn +// sendmail or postfix do not use CodeQuitNewConn and never re-use a connection. +// Existing milters might not expect the MTA to use this feature. +func (s *ClientSession) Reset(macros Macros) error { + if s.state == clientStateError || s.state == clientStateClosed { + return s.errorOut(fmt.Errorf("milter: reset: in wrong state %d", s.state)) + } + s.state = clientStateNegotiated + s.skip = false + s.skipUnknown = false + if err := s.writePacket(&wire.Message{ + Code: wire.CodeQuitNewConn, + }); err != nil { + return s.errorOut(err) + } + s.macros = macros + return nil +} + +// Close releases resources associated with the session and closes the connection to the milter. +// +// If there is a milter sequence in progress the CodeQuit command is called to signal closure to the milter. +// +// You can call Close at any time in the session, and you can call Close multiple times without harm. +func (s *ClientSession) Close() error { + if s.state == clientStateClosed || s.state == clientStateError { + return s.closedErr + } + s.state = clientStateClosed + + if err := s.writePacket(&wire.Message{ + Code: wire.CodeQuit, + }); err != nil { + s.closedErr = fmt.Errorf("milter: close: quit: %w", err) + _ = s.conn.Close() + return s.closedErr } - return s.conn.Close() + s.closedErr = s.conn.Close() + return s.closedErr } diff --git a/client_test.go b/client_test.go index 7822374..23317fd 100644 --- a/client_test.go +++ b/client_test.go @@ -2,70 +2,85 @@ package milter import ( "bytes" + "encoding/binary" + "fmt" + "io" "net" nettextproto "net/textproto" "reflect" + "strings" "testing" + "time" + "github.com/d--j/go-milter/internal/wire" "github.com/emersion/go-message/textproto" ) -func init() { - // HACK: claim to support v6 in server for tests - serverProtocolVersion = 6 -} - type MockMilter struct { - ConnResp Response + ConnResp *Response ConnMod func(m *Modifier) ConnErr error - HeloResp Response + HeloResp *Response HeloMod func(m *Modifier) HeloErr error - MailResp Response + MailResp *Response MailMod func(m *Modifier) MailErr error - RcptResp Response + RcptResp *Response RcptMod func(m *Modifier) RcptErr error - HdrResp Response + DataResp *Response + DataMod func(m *Modifier) + DataErr error + + HdrResp *Response HdrMod func(m *Modifier) HdrErr error - HdrsResp Response + HdrsResp *Response HdrsMod func(m *Modifier) HdrsErr error - BodyChunkResp Response + BodyChunkResp *Response BodyChunkMod func(m *Modifier) BodyChunkErr error - BodyResp Response + BodyResp *Response BodyMod func(m *Modifier) BodyErr error AbortMod func(m *Modifier) AbortErr error + UnknownResp *Response + UnknownMod func(m *Modifier) + UnknownErr error + + OnClose func() + // Info collected during calls. Host string Family string Port uint16 - Addr net.IP + Addr string HeloValue string From string + FromEsmtp string Rcpt []string + RcptEsmtp []string Hdr nettextproto.MIMEHeader Chunks [][]byte + + Cmds []string } -func (mm *MockMilter) Connect(host string, family string, port uint16, addr net.IP, m *Modifier) (Response, error) { +func (mm *MockMilter) Connect(host string, family string, port uint16, addr string, m *Modifier) (*Response, error) { if mm.ConnMod != nil { mm.ConnMod(m) } @@ -76,7 +91,7 @@ func (mm *MockMilter) Connect(host string, family string, port uint16, addr net. return mm.ConnResp, mm.ConnErr } -func (mm *MockMilter) Helo(name string, m *Modifier) (Response, error) { +func (mm *MockMilter) Helo(name string, m *Modifier) (*Response, error) { if mm.HeloMod != nil { mm.HeloMod(m) } @@ -84,38 +99,47 @@ func (mm *MockMilter) Helo(name string, m *Modifier) (Response, error) { return mm.HeloResp, mm.HeloErr } -func (mm *MockMilter) MailFrom(from string, m *Modifier) (Response, error) { +func (mm *MockMilter) MailFrom(from string, esmtpArgs string, m *Modifier) (*Response, error) { if mm.MailMod != nil { mm.MailMod(m) } mm.From = from + mm.FromEsmtp = esmtpArgs return mm.MailResp, mm.MailErr } -func (mm *MockMilter) RcptTo(rcptTo string, m *Modifier) (Response, error) { +func (mm *MockMilter) RcptTo(rcptTo string, esmtpArgs string, m *Modifier) (*Response, error) { if mm.RcptMod != nil { mm.RcptMod(m) } mm.Rcpt = append(mm.Rcpt, rcptTo) + mm.RcptEsmtp = append(mm.RcptEsmtp, esmtpArgs) return mm.RcptResp, mm.RcptErr } -func (mm *MockMilter) Header(name string, value string, m *Modifier) (Response, error) { +func (mm *MockMilter) Data(m *Modifier) (*Response, error) { + if mm.DataMod != nil { + mm.DataMod(m) + } + return mm.DataResp, mm.DataErr +} + +func (mm *MockMilter) Header(name string, value string, m *Modifier) (*Response, error) { if mm.HdrMod != nil { mm.HdrMod(m) } return mm.HdrResp, mm.HdrErr } -func (mm *MockMilter) Headers(h nettextproto.MIMEHeader, m *Modifier) (Response, error) { +func (mm *MockMilter) Headers(m *Modifier) (*Response, error) { if mm.HdrsMod != nil { mm.HdrsMod(m) } - mm.Hdr = h + mm.Hdr = m.Headers return mm.HdrsResp, mm.HdrsErr } -func (mm *MockMilter) BodyChunk(chunk []byte, m *Modifier) (Response, error) { +func (mm *MockMilter) BodyChunk(chunk []byte, m *Modifier) (*Response, error) { if mm.BodyChunkMod != nil { mm.BodyChunkMod(m) } @@ -123,7 +147,7 @@ func (mm *MockMilter) BodyChunk(chunk []byte, m *Modifier) (Response, error) { return mm.BodyChunkResp, mm.BodyChunkErr } -func (mm *MockMilter) Body(m *Modifier) (Response, error) { +func (mm *MockMilter) EndOfMessage(m *Modifier) (*Response, error) { if mm.BodyMod != nil { mm.BodyMod(m) } @@ -137,57 +161,103 @@ func (mm *MockMilter) Abort(m *Modifier) error { return mm.AbortErr } -func TestMilterClient_UsualFlow(t *testing.T) { - mm := MockMilter{ - ConnResp: RespContinue, - HeloResp: RespContinue, - MailResp: RespContinue, - RcptResp: RespContinue, - HdrResp: RespContinue, - HdrsResp: RespContinue, - BodyChunkResp: RespContinue, - BodyResp: RespContinue, - BodyMod: func(m *Modifier) { - m.AddHeader("X-Bad", "very") - m.ChangeHeader(1, "Subject", "***SPAM***") - m.Quarantine("very bad message") - }, +func (mm *MockMilter) Unknown(cmd string, m *Modifier) (*Response, error) { + if mm.UnknownMod != nil { + mm.UnknownMod(m) } - s := Server{ - NewMilter: func() Milter { - return &mm - }, - Actions: OptAddHeader | OptChangeHeader, + mm.Cmds = append(mm.Cmds, cmd) + return mm.UnknownResp, mm.UnknownErr +} + +func (mm *MockMilter) Cleanup() { + if mm.OnClose != nil { + mm.OnClose() } - defer s.Close() - local, err := net.Listen("tcp", "127.0.0.1:0") +} + +func assertAction(t *testing.T, act *Action, err error, expectCode ActionType) { + t.Helper() if err != nil { t.Fatal(err) } - go s.Serve(local) + if act.Type != expectCode { + t.Fatalf("Unexpected code %c: %+v", act.Type, act) + } +} - cl := NewClientWithOptions("tcp", local.Addr().String(), ClientOptions{ - ActionMask: OptAddHeader | OptChangeHeader | OptQuarantine, - }) - defer cl.Close() - session, err := cl.Session() +type serverClientWrap struct { + server *Server + client *Client + session *ClientSession + local net.Listener +} + +func newServerClient(t *testing.T, macros Macros, serverOptions []Option, clientOptions []Option) serverClientWrap { + var err error + s := NewServer(serverOptions...) + w := serverClientWrap{server: s} + w.local, err = net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + go func() { + s.Serve(w.local) + }() + w.client = NewClient("tcp", w.local.Addr().String(), clientOptions...) + w.session, err = w.client.Session(macros) if err != nil { + w.server.Close() t.Fatal(err) } - defer session.Close() + return w +} - assertAction := func(act *Action, err error, expectCode ActionCode) { - t.Helper() - if err != nil { - t.Fatal(err) - } - if act.Code != expectCode { - t.Fatal("Unexpectedcode:", act.Code) - } +func (w *serverClientWrap) Cleanup() { + w.session.Close() + w.server.Close() +} + +func TestMilterClient_UsualFlow(t *testing.T) { + t.Parallel() + mm := MockMilter{ + ConnResp: RespContinue, + HeloResp: RespContinue, + MailResp: RespContinue, + RcptResp: RespContinue, + DataResp: RespContinue, + HdrResp: RespContinue, + HdrsResp: RespContinue, + HdrsMod: func(m *Modifier) { + m.Progress() + }, + BodyChunkResp: RespContinue, + BodyResp: RespContinue, + BodyMod: func(m *Modifier) { + m.ChangeFrom("changed@example.com", "") + m.ChangeFrom("changed@example.com", "A=B") + m.AddRecipient("example@example.com", "") + m.AddRecipient("", "A=B") + m.DeleteRecipient("del@example.com") + m.AddHeader("X-Bad", "very") + m.Progress() + m.ChangeHeader(1, "Subject", "***SPAM***") + m.InsertHeader(2, "X-Hdr", "value") + m.Quarantine("very bad message") + m.ReplaceBody(strings.NewReader(strings.Repeat("-", int(DataSize64K)+1))) + }, + UnknownResp: RespContinue, } + macros := NewMacroBag() + w := newServerClient(t, macros, []Option{WithMilter(func() Milter { + return &mm + }), WithActions(OptAddHeader | OptChangeBody | OptAddRcpt | OptRemoveRcpt | OptChangeHeader | OptQuarantine | OptChangeFrom | OptAddRcptWithArgs)}, + []Option{WithActions(OptAddHeader | OptChangeBody | OptAddRcpt | OptRemoveRcpt | OptChangeHeader | OptQuarantine | OptChangeFrom | OptAddRcptWithArgs)}, + ) + defer w.Cleanup() - act, err := session.Conn("host", FamilyInet, 25565, "172.0.0.1") - assertAction(act, err, ActContinue) + macros.Set(MacroTlsVersion, "very old") + act, err := w.session.Conn("host", FamilyInet, 25565, "172.0.0.1") + assertAction(t, act, err, ActionContinue) if mm.Host != "host" { t.Fatal("Wrong host:", mm.Host) } @@ -197,40 +267,39 @@ func TestMilterClient_UsualFlow(t *testing.T) { if mm.Port != 25565 { t.Fatal("Wrong port:", mm.Port) } - if mm.Addr.String() != "172.0.0.1" { + if mm.Addr != "172.0.0.1" { t.Fatal("Wrong IP:", mm.Addr) } - if err := session.Macros(CodeHelo, "tls_version", "very old"); err != nil { - t.Fatal("Unexpected error", err) - } - - act, err = session.Helo("helo_host") - assertAction(act, err, ActContinue) + act, err = w.session.Helo("helo_host") + assertAction(t, act, err, ActionContinue) if mm.HeloValue != "helo_host" { t.Fatal("Wrong helo value:", mm.HeloValue) } - act, err = session.Mail("from@example.org", []string{"A=B"}) - assertAction(act, err, ActContinue) + act, err = w.session.Mail("from@example.org", "A=B") + assertAction(t, act, err, ActionContinue) if mm.From != "from@example.org" { t.Fatal("Wrong MAIL FROM:", mm.From) } - act, err = session.Rcpt("to1@example.org", []string{"A=B"}) - assertAction(act, err, ActContinue) - act, err = session.Rcpt("to2@example.org", []string{"A=B"}) - assertAction(act, err, ActContinue) + act, err = w.session.Rcpt("to1@example.org", "A=B") + assertAction(t, act, err, ActionContinue) + act, err = w.session.Rcpt("to2@example.org", "A=C") + assertAction(t, act, err, ActionContinue) if !reflect.DeepEqual(mm.Rcpt, []string{"to1@example.org", "to2@example.org"}) { t.Fatal("Wrong recipients:", mm.Rcpt) } + if !reflect.DeepEqual(mm.RcptEsmtp, []string{"A=B", "A=C"}) { + t.Fatal("Wrong recipients esmtp args:", mm.RcptEsmtp) + } hdr := textproto.Header{} hdr.Add("From", "from@example.org") hdr.Add("To", "to@example.org") hdr.Add("x-empty-header", "") - act, err = session.Header(hdr) - assertAction(act, err, ActContinue) + act, err = w.session.Header(hdr) // calls DataStart() automatically + assertAction(t, act, err, ActionContinue) if len(mm.Hdr) != 3 { t.Fatal("Unexpected header length:", len(mm.Hdr)) } @@ -244,8 +313,14 @@ func TestMilterClient_UsualFlow(t *testing.T) { t.Fatal("Wrong To header:", val) } - modifyActs, act, err := session.BodyReadFrom(bytes.NewReader(bytes.Repeat([]byte{'A'}, 128000))) - assertAction(act, err, ActContinue) + act, err = w.session.Unknown("INVALID command", map[MacroName]string{MacroHopCount: "2"}) + assertAction(t, act, err, ActionContinue) + if !reflect.DeepEqual(mm.Cmds, []string{"INVALID command"}) { + t.Fatal("Wrong cmds:", mm.Cmds) + } + + modifyActs, act, err := w.session.BodyReadFrom(bytes.NewReader(bytes.Repeat([]byte{'A'}, 128000))) + assertAction(t, act, err, ActionContinue) if len(mm.Chunks) != 2 { t.Fatal("Wrong amount of body chunks received") @@ -257,76 +332,54 @@ func TestMilterClient_UsualFlow(t *testing.T) { t.Fatal("Some body bytes lost:", totalLen) } + firstBodyChunk := []byte(strings.Repeat("-", int(DataSize64K))) expected := []ModifyAction{ - { - Code: ActAddHeader, - HeaderName: "X-Bad", - HeaderValue: "very", - }, - { - Code: ActChangeHeader, - HeaderIndex: 1, - HeaderName: "Subject", - HeaderValue: "***SPAM***", - }, - { - Code: ActQuarantine, - Reason: "very bad message", - }, + {Type: ActionChangeFrom, From: ""}, + {Type: ActionChangeFrom, From: "", FromArgs: "A=B"}, + {Type: ActionAddRcpt, Rcpt: ""}, + {Type: ActionAddRcpt, Rcpt: "", RcptArgs: "A=B"}, + {Type: ActionDelRcpt, Rcpt: ""}, + {Type: ActionAddHeader, HeaderName: "X-Bad", HeaderValue: "very"}, + {Type: ActionChangeHeader, HeaderIndex: 1, HeaderName: "Subject", HeaderValue: "***SPAM***"}, + {Type: ActionInsertHeader, HeaderIndex: 2, HeaderName: "X-Hdr", HeaderValue: "value"}, + {Type: ActionQuarantine, Reason: "very bad message"}, + {Type: ActionReplaceBody, Body: firstBodyChunk}, + {Type: ActionReplaceBody, Body: []byte{'-'}}, } - if !reflect.DeepEqual(modifyActs, expected) { - t.Fatalf("Wrong modify actions, got %+v", modifyActs) + t.Fatalf("Wrong modify actions: got %+v", modifyActs) } } func TestMilterClient_AbortFlow(t *testing.T) { - macros := make(map[string]string) + t.Parallel() + waitChan := make(chan interface{}, 2) + heloTls := "not set" + aborTls := "not set" + mailAuthen := "not set" mm := MockMilter{ ConnResp: RespContinue, HeloResp: RespContinue, HeloMod: func(m *Modifier) { - macros = m.Macros + heloTls = m.Macros.Get(MacroTlsVersion) }, - AbortMod: func(m *Modifier) { - macros = m.Macros + MailResp: RespContinue, + MailMod: func(m *Modifier) { + mailAuthen = m.Macros.Get(MacroAuthAuthen) }, - } - s := Server{ - NewMilter: func() Milter { - return &mm + AbortMod: func(m *Modifier) { + aborTls = m.Macros.Get(MacroTlsVersion) + waitChan <- nil }, - Actions: OptAddHeader | OptChangeHeader, - } - defer s.Close() - local, err := net.Listen("tcp", "127.0.0.1:0") - if err != nil { - t.Fatal(err) - } - go s.Serve(local) - - cl := NewClientWithOptions("tcp", local.Addr().String(), ClientOptions{ - ActionMask: OptAddHeader | OptChangeHeader | OptQuarantine, - }) - defer cl.Close() - session, err := cl.Session() - if err != nil { - t.Fatal(err) - } - defer session.Close() - - assertAction := func(act *Action, err error, expectCode ActionCode) { - t.Helper() - if err != nil { - t.Fatal(err) - } - if act.Code != expectCode { - t.Fatal("Unexpectedcode:", act.Code) - } } + macros := NewMacroBag() + w := newServerClient(t, macros, []Option{WithMilter(func() Milter { + return &mm + }), WithActions(OptAddHeader | OptChangeHeader)}, []Option{WithActions(OptAddHeader | OptChangeHeader | OptQuarantine)}) + defer w.Cleanup() - act, err := session.Conn("host", FamilyInet, 25565, "172.0.0.1") - assertAction(act, err, ActContinue) + act, err := w.session.Conn("host", FamilyInet, 25565, "172.0.0.1") + assertAction(t, act, err, ActionContinue) if mm.Host != "host" { t.Fatal("Wrong host:", mm.Host) } @@ -336,39 +389,1288 @@ func TestMilterClient_AbortFlow(t *testing.T) { if mm.Port != 25565 { t.Fatal("Wrong port:", mm.Port) } - if mm.Addr.String() != "172.0.0.1" { + if mm.Addr != "172.0.0.1" { t.Fatal("Wrong IP:", mm.Addr) } - if err := session.Macros(CodeHelo, "tls_version", "very old"); err != nil { - t.Fatal("Unexpected error", err) - } - - act, err = session.Helo("helo_host") - assertAction(act, err, ActContinue) + macros.Set(MacroTlsVersion, "very old") + act, err = w.session.Helo("helo_host") + assertAction(t, act, err, ActionContinue) if mm.HeloValue != "helo_host" { t.Fatal("Wrong helo value:", mm.HeloValue) } - if v, ok := macros["tls_version"]; !ok || v != "very old" { - t.Fatal("Wrong tls_version macro value:", v) + + if heloTls != "very old" { + t.Fatal("Wrong tls_version macro value:", heloTls) + } + + macros.Set(MacroAuthAuthen, "login-user") + act, err = w.session.Mail("login-user@example.com", "") + assertAction(t, act, err, ActionContinue) + if mm.From != "login-user@example.com" { + t.Fatal("Wrong from value:", mm.From) + } + if mailAuthen != "login-user" { + t.Fatal("Unexpected macro data:", mailAuthen) } - err = session.Abort() + err = w.session.Abort(nil) + <-waitChan // since Abort() does not wait for a response we need to wait for the server to finish on our own if err != nil { t.Fatal(err) } // Validate macro values are preserved for the abort callback - if v, ok := macros["tls_version"]; !ok || v != "very old" { - t.Fatal("Wrong tls_version macro value: ", v) + if aborTls != "very old" { + t.Fatal("Wrong tls_version macro value: ", aborTls) } - act, err = session.Helo("repeated_helo_host") - assertAction(act, err, ActContinue) - if mm.HeloValue != "repeated_helo_host" { - t.Fatal("Wrong helo value:", mm.HeloValue) + macros.Set(MacroAuthAuthen, "") + act, err = w.session.Mail("another-user@example.com", "") + assertAction(t, act, err, ActionContinue) + if mm.From != "another-user@example.com" { + t.Fatal("Wrong from value:", mm.From) + } + if len(mailAuthen) != 0 { + t.Fatal("Unexpected macro data:", mailAuthen) + } +} + +func TestMilterClient_NoWorking(t *testing.T) { + t.Parallel() + mm := MockMilter{ + MailResp: RespReject, + } + w := newServerClient(t, nil, []Option{WithMilter(func() Milter { + return &mm + }), WithActions(OptAddHeader | OptChangeHeader), WithProtocols(OptNoMailFrom)}, + []Option{WithActions(OptAddHeader | OptChangeHeader | OptQuarantine)}, + ) + defer w.Cleanup() + + _, err := w.session.Mail("from@example.org", "A=B") + if err == nil || err.Error() != "milter: in wrong state 1" { + t.Fatal("expected error") + } + w.local.Close() + + cl2 := NewClient(w.local.Addr().Network(), w.local.Addr().String()) + if _, err := cl2.Session(nil); err == nil { + t.Fatal("could start a session to a non-existing server") + } +} + +func TestMilterClient_NegotiationMismatch(t *testing.T) { + t.Parallel() + mm := MockMilter{} + s := NewServer(WithMilter(func() Milter { + return &mm + }), WithActions(OptAddHeader|OptChangeHeader), WithProtocols(OptNoMailFrom)) + local, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + go s.Serve(local) + client := NewClient("tcp", local.Addr().String(), WithActions(OptAddHeader|OptChangeHeader|OptQuarantine), WithProtocols(OptNoEOH)) + session, err := client.Session(nil) + if err == nil { + session.Close() + t.Fatal("negotiation should fail") + } + + client2 := NewClient("tcp", local.Addr().String(), WithActions(OptAddHeader), WithProtocols(OptNoMailFrom)) + session2, err := client2.Session(nil) + if err == nil { + session2.Close() + t.Fatal("negotiation should fail") + } +} + +func TestMilterClient_BogusServerNegotiation(t *testing.T) { + tests := []struct { + name string + opts []Option + negResponse []byte + onlyWarning bool + }{ + {"not even full packet", []Option{WithReadTimeout(time.Second)}, []byte{0}, false}, + {"wrong response code", nil, []byte{0, 0, 0, 1, 'a'}, false}, + {"too few bytes", nil, []byte{0, 0, 0, 2, byte(wire.CodeOptNeg), 0}, false}, + {"milter version 0", nil, []byte{0, 0, 0, 13, byte(wire.CodeOptNeg), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, false}, + {"milter version 1", nil, []byte{0, 0, 0, 13, byte(wire.CodeOptNeg), 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0}, false}, + {"wrong actions", nil, []byte{0, 0, 0, 13, byte(wire.CodeOptNeg), 0, 0, 0, 2, 255, 255, 255, 255, 0, 0, 0, 0}, false}, + {"wrong protocol", nil, []byte{0, 0, 0, 13, byte(wire.CodeOptNeg), 0, 0, 0, 2, 0, 0, 0, 0, 255, 255, 255, 255}, false}, + {"wrong milter stage", nil, []byte{0, 0, 0, 18, byte(wire.CodeOptNeg), 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 255, 255, 255, 255, 0}, true}, + {"wrong milter list", nil, []byte{0, 0, 0, 18, byte(wire.CodeOptNeg), 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 'a'}, true}, + {"repeated milter stage", nil, []byte{0, 0, 0, 25, byte(wire.CodeOptNeg), 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 'a', 0, 0, 0, 0, 0, 'a', 0}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + // t.Parallel() - test cannot be Parallel() because it replaces the global LogWarning + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + // just slurp up everything the client sends + go func() { + buf := make([]byte, 1024) + for { + _ = serverConn.SetReadDeadline(time.Now().Add(time.Minute)) + if _, err := serverConn.Read(buf); err != nil { + if err != io.EOF && err != io.ErrClosedPipe { + t.Logf("server got error: read: %v", err) + } + return + } + } + }() + sErrChan := make(chan error) + go func() { + if _, err := serverConn.Write(ltt.negResponse); err != nil { + sErrChan <- err + return + } + sErrChan <- nil + }() + warningCalled := false + if ltt.onlyWarning { + LogWarning = func(format string, v ...interface{}) { + warningCalled = true + logWarning(format, v...) + } + } + cl := NewClient(clientConn.LocalAddr().Network(), clientConn.LocalAddr().String(), ltt.opts...) + session, err := cl.session(clientConn, nil) + if ltt.onlyWarning { + LogWarning = logWarning + if session == nil { + t.Fatalf("negotiation should succeed but it did not with server response %x", ltt.negResponse) + } + session.Close() + if !warningCalled { + t.Fatal("negotiation should have called a warning") + } + } else { + if err == nil { + session.Close() + t.Fatalf("expected error in negotiation but it succeeded with server response %x", ltt.negResponse) + } + } + + if err := <-sErrChan; err != nil { + t.Fatal(err) + } + }) + } +} + +func TestMilterClient_Negotiation(t *testing.T) { + t.Parallel() + tests := []struct { + name string + opts []Option + serverVersion uint32 + serverActions OptAction + serverProtocol OptProtocol + wantVersion uint32 + wantActions OptAction + wantProtocol OptProtocol + wantBufferSize DataSize + }{ + {"default", nil, MaxClientProtocolVersion, OptAddHeader, 0, MaxClientProtocolVersion, OptAddHeader, 0, DataSize64K}, + {"v6 client v2 server", nil, 2, OptAddHeader, 0, 2, OptAddHeader, OptNoUnknown | OptNoData, DataSize64K}, + {"v2 client v2 server", []Option{WithMaximumVersion(2), WithProtocols(allClientSupportedProtocolMasksV2), WithActions(allClientSupportedActionMasksV2)}, 2, OptAddHeader, 0, 2, OptAddHeader, OptNoUnknown | OptNoData, DataSize64K}, + {"offered 1MB not accepted", []Option{WithActions(AllClientSupportedActionMasks), WithOfferedMaxData(DataSize1M)}, MaxClientProtocolVersion, OptAddHeader, 0, MaxClientProtocolVersion, OptAddHeader, 0, DataSize64K}, + {"offered 256K not accepted", []Option{WithActions(AllClientSupportedActionMasks), WithOfferedMaxData(DataSize256K)}, MaxClientProtocolVersion, OptAddHeader, 0, MaxClientProtocolVersion, OptAddHeader, 0, DataSize64K}, + {"offered 1MB accepted", []Option{WithActions(AllClientSupportedActionMasks), WithOfferedMaxData(DataSize1M)}, MaxClientProtocolVersion, OptAddHeader, OptProtocol(optMds1M), MaxClientProtocolVersion, OptAddHeader, 0, DataSize1M}, + {"offered 256K accepted", []Option{WithActions(AllClientSupportedActionMasks), WithOfferedMaxData(DataSize256K)}, MaxClientProtocolVersion, OptAddHeader, OptProtocol(optMds256K), MaxClientProtocolVersion, OptAddHeader, 0, DataSize256K}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + + // just slurp up everything the client sends + go func() { + buf := make([]byte, 1024) + for { + _ = serverConn.SetReadDeadline(time.Now().Add(time.Minute)) + if _, err := serverConn.Read(buf); err != nil { + if err != io.EOF && err != io.ErrClosedPipe { + t.Logf("server got error: read: %v", err) + } + return + } + } + }() + sErrChan := make(chan error) + go func() { + defer serverConn.Close() + response := []byte{0, 0, 0, 13, byte(wire.CodeOptNeg), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + binary.BigEndian.PutUint32(response[5:], ltt.serverVersion) + binary.BigEndian.PutUint32(response[9:], uint32(ltt.serverActions)) + binary.BigEndian.PutUint32(response[13:], uint32(ltt.serverProtocol)) + if _, err := serverConn.Write(response); err != nil { + sErrChan <- err + return + } + sErrChan <- nil + }() + cl := NewClient(clientConn.LocalAddr().Network(), clientConn.LocalAddr().String(), ltt.opts...) + session, err := cl.session(clientConn, nil) + if err != nil { + t.Fatalf("expected no error in negotiation but got %v, with server version %d actions %x protocol %x", err, ltt.serverVersion, ltt.serverActions, ltt.serverProtocol) + } + if session.version != ltt.wantVersion { + t.Fatalf("version: got %d expected %d", session.version, ltt.wantVersion) + } + if session.actionOpts != ltt.wantActions { + t.Fatalf("actions: got %032b expected %032b", session.actionOpts, ltt.wantActions) + } + if session.protocolOpts != ltt.wantProtocol { + t.Fatalf("protocol: got %032b expected %032b", session.protocolOpts, ltt.wantProtocol) + } + if session.negotiatedBodySize != uint32(ltt.wantBufferSize) { + t.Fatalf("buffer size: got %d expected %d", session.negotiatedBodySize, ltt.wantBufferSize) + } + session.Close() + + if err := <-sErrChan; err != nil { + t.Fatal(err) + } + }) + } +} + +func TestMilterClient_WithMockServer(t *testing.T) { + t.Parallel() + type op struct { + s1 func(*ClientSession) (*Action, error) + s2 func(*ClientSession) error + s3 func(*ClientSession) ([]ModifyAction, *Action, error) + v1 func(*testing.T, *ClientSession, *Action, error) + v2 func(*testing.T, *ClientSession, error) + v3 func(*testing.T, *ClientSession, []ModifyAction, *Action, error) + server []byte + } + type ops []op + + type cfg struct { + Opts []Option + ServerNegotiation []byte + Macros Macros + } + + withProtC := func(prot OptProtocol) cfg { + c := cfg{ + Opts: []Option{WithActions(AllClientSupportedActionMasks), WithReadTimeout(time.Second), WithWriteTimeout(time.Second)}, + ServerNegotiation: []byte{0, 0, 0, 13, byte(wire.CodeOptNeg), 0, 0, 0, 6, 0, 0, 0, 0, 0, 0, 0, 0}, + Macros: nil, + } + binary.BigEndian.PutUint32(c.ServerNegotiation[13:], uint32(prot)) + return c + } + withActC := func(c cfg, act OptAction) cfg { + binary.BigEndian.PutUint32(c.ServerNegotiation[9:], uint32(act)) + return c + } + with256KbC := func(c cfg) cfg { + c.Opts = append(c.Opts, WithOfferedMaxData(DataSize256K)) + binary.BigEndian.PutUint32(c.ServerNegotiation[13:], optMds256K|binary.BigEndian.Uint32(c.ServerNegotiation[13:])) + return c + } + dC := withProtC(0) + + sendConnect := func(s *ClientSession) (*Action, error) { + return s.Conn("localhost", FamilyUnix, 0, "/var/run/sock") + } + sendHelo := func(s *ClientSession) (*Action, error) { + return s.Helo("localhost") + } + sendMail := func(s *ClientSession) (*Action, error) { + return s.Mail("", "") + } + sendRcpt := func(s *ClientSession) (*Action, error) { + return s.Rcpt("", "") + } + sendData := func(s *ClientSession) (*Action, error) { + return s.DataStart() + } + sendHeaderField := func(s *ClientSession) (*Action, error) { + return s.HeaderField("a", "b", nil) + } + sendHeaderEnd := func(s *ClientSession) (*Action, error) { + return s.HeaderEnd() + } + sendBodyChunk := func(s *ClientSession) (*Action, error) { + return s.BodyChunk([]byte("line\n")) + } + sendEnd := func(s *ClientSession) ([]ModifyAction, *Action, error) { + return s.End() + } + + expectErr1 := func(t *testing.T, _ *ClientSession, act *Action, err error) { + t.Helper() + if err == nil { + t.Fatalf("expected err but got act = %+v", act) + } + } + expectAct := func(expectedActCode ActionType, t *testing.T, act *Action, err error) { + t.Helper() + if err != nil { + t.Fatal(err) + } + if act.Type != expectedActCode { + t.Fatalf("expected %c, got %+v", expectedActCode, act) + } + } + expectContinue := func(t *testing.T, _ *ClientSession, act *Action, err error) { + t.Helper() + expectAct(ActionContinue, t, act, err) + } + expectReject := func(t *testing.T, _ *ClientSession, act *Action, err error) { + t.Helper() + expectAct(ActionReject, t, act, err) + } + expectAcceptEmptyMods := func(t *testing.T, _ *ClientSession, mActs []ModifyAction, act *Action, err error) { + t.Helper() + expectAct(ActionAccept, t, act, err) + if len(mActs) > 0 { + t.Fatalf("expected empty modifications, got %+v", mActs) + } + } + + responseContinue := []byte{0, 0, 0, 1, byte(wire.ActContinue)} + + tests := []struct { + name string + cfg cfg + ops ops + }{ + {"bogus response at connect", dC, ops{{s1: sendConnect, v1: expectErr1, server: []byte{0, 0, 0}}}}, + {"double connect", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendConnect, v1: expectErr1, server: responseContinue}, + }}, + {"Progress response working", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: []byte{0, 0, 0, 1, byte(wire.ActProgress), 0, 0, 0, 1, byte(wire.ActProgress), 0, 0, 0, 1, byte(wire.ActContinue)}}, + }}, + {"ActReplyCode response working", dC, ops{ + {s1: sendConnect, v1: func(t *testing.T, s *ClientSession, act *Action, err error) { + expectAct(ActionRejectWithCode, t, act, err) + if act.SMTPCode != 400 { + t.Fatalf("expected code %d, got %d", 400, act.SMTPCode) + } + if act.SMTPReply != "400 T" { + t.Fatalf("expected text %s, got %s", "400 T", act.SMTPReply) + } + }, server: []byte{0, 0, 0, 7, byte(wire.ActReplyCode), '4', '0', '0', ' ', 'T', 0}}, + }}, + {"ActReplyCode parsing error 1", dC, ops{ + {s1: sendConnect, v1: expectErr1, server: []byte{0, 0, 0, 7, byte(wire.ActReplyCode), 'a', '0', '0', ' ', 'T', 0}}, + }}, + {"ActReplyCode parsing error 2", dC, ops{ + {s1: sendConnect, v1: expectErr1, server: []byte{0, 0, 0, 4, byte(wire.ActReplyCode), '4', '0', '0'}}, + }}, + {"OptNoConnect working", withProtC(OptNoConnect), ops{ + {s1: sendConnect, v1: expectContinue, server: nil}, + }}, + {"OptNoConnReply working", withProtC(OptNoConnReply), ops{ + {s1: sendConnect, v1: expectContinue, server: nil}, + }}, + {"premature Helo", dC, + ops{{s1: sendHelo, v1: expectErr1, server: responseContinue}}, + }, + {"bogus response at helo", dC, + ops{{s1: sendConnect, v1: expectContinue, server: responseContinue}, {s1: sendHelo, v1: expectErr1, server: []byte{0, 0, 0}}}, + }, + {"OptNoHelo working", withProtC(OptNoHelo), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: nil}, + }}, + {"OptNoHeloReply working", withProtC(OptNoHeloReply), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: nil}, + }}, + {"premature Mail", withProtC(OptNoMailFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectErr1, server: responseContinue}, + }}, + {"OptNoMailFrom working", withProtC(OptNoMailFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: nil}, + }}, + {"OptNoMailReply working", withProtC(OptNoMailReply), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: nil}, + }}, + {"premature Rcpt", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectErr1, server: responseContinue}, + }}, + {"OptNoRcptTo working", withProtC(OptNoRcptTo), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: nil}, + }}, + {"OptNoRcptReply working", withProtC(OptNoRcptReply), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: nil}, + }}, + {"premature DataStart", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectErr1, server: responseContinue}, + }}, + {"OptNoData working", withProtC(OptNoData), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: nil}, + }}, + {"OptNoDataReply working", withProtC(OptNoDataReply), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: nil}, + }}, + {"premature HeaderField", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectErr1, server: responseContinue}, + }}, + {"OptNoHeaders working", withProtC(OptNoHeaders), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: nil}, + }}, + {"OptNoHeaderReply working", withProtC(OptNoHeaderReply), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: nil}, + }}, + {"premature HeaderEnd", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectErr1, server: responseContinue}, + }}, + {"OptNoEOH working", withProtC(OptNoEOH), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: nil}, + }}, + {"OptNoEOHReply working", withProtC(OptNoEOHReply), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: nil}, + }}, + {"premature BodyChunk", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectErr1, server: responseContinue}, + }}, + {"Skip working", withProtC(OptSkip), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: []byte{0, 0, 0, 1, byte(wire.ActSkip)}}, + {s1: sendBodyChunk, v1: expectContinue, server: nil}, + }}, + {"Skip rejected 1", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectErr1, server: []byte{0, 0, 0, 1, byte(wire.ActSkip)}}, + }}, + {"Skip rejected 2", dC, ops{ + {s1: sendConnect, v1: expectErr1, server: []byte{0, 0, 0, 1, byte(wire.ActSkip)}}, + }}, + {"Reject too much data", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: func(s *ClientSession) (*Action, error) { + return s.BodyChunk(make([]byte, DataSize256K)) + }, v1: expectErr1, server: responseContinue}, + }}, + {"BodyChunk Skip working", withProtC(OptSkip), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: func(s *ClientSession) (*Action, error) { + act, err := s.BodyChunk([]byte("line\n")) + if err != nil || act.Type != ActionContinue { + if err == nil { + err = fmt.Errorf("expected continue response, got %+v", act) + } + return nil, err + } + if !s.Skip() { + return nil, fmt.Errorf("expected Skip to be true") + } + act, err = s.BodyChunk([]byte("line\n")) + if err != nil || act.Type != ActionContinue { + if err == nil { + err = fmt.Errorf("expected continue response, got %+v", act) + } + return nil, err + } + return act, err + }, v1: expectContinue, server: []byte{0, 0, 0, 1, byte(wire.ActSkip)}}, + }}, + {"OptNoBody working", withProtC(OptNoBody), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: nil}, + }}, + {"OptNoBodyReply working", withProtC(OptNoBodyReply), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: nil}, + }}, + {"Header working", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: func(s *ClientSession) (*Action, error) { + hdrs := textproto.Header{} + hdrs.Add("From", "<>") + hdrs.Add("To", "<>") + return s.Header(hdrs) + }, v1: expectContinue, server: []byte{0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActContinue)}}, + }}, + {"Header after HeaderField", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: func(s *ClientSession) (*Action, error) { + hdrs := textproto.Header{} + hdrs.Add("From", "<>") + hdrs.Add("To", "<>") + return s.Header(hdrs) + }, v1: expectContinue, server: []byte{0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActContinue)}}, + }}, + {"Header auto Data", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: func(s *ClientSession) (*Action, error) { + hdrs := textproto.Header{} + hdrs.Add("From", "<>") + hdrs.Add("To", "<>") + return s.Header(hdrs) + }, v1: expectContinue, server: []byte{0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActContinue)}}, + }}, + {"Header auto Data reject", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: func(s *ClientSession) (*Action, error) { + hdrs := textproto.Header{} + hdrs.Add("From", "<>") + hdrs.Add("To", "<>") + return s.Header(hdrs) + }, v1: expectReject, server: []byte{0, 0, 0, 1, byte(wire.ActReject)}}, + }}, + {"Header premature", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: func(s *ClientSession) (*Action, error) { + hdrs := textproto.Header{} + hdrs.Add("From", "<>") + hdrs.Add("To", "<>") + return s.Header(hdrs) + }, v1: expectErr1, server: responseContinue}, + }}, + {"Header Reject second header", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: func(s *ClientSession) (*Action, error) { + hdrs := textproto.Header{} + hdrs.Add("From", "<>") + hdrs.Add("To", "<>") + return s.Header(hdrs) + }, v1: expectReject, server: []byte{0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActReject)}}, + }}, + {"BodyReadFrom working", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s3: func(s *ClientSession) ([]ModifyAction, *Action, error) { + return s.BodyReadFrom(bytes.NewReader(make([]byte, 3*DataSize64K))) + }, v3: expectAcceptEmptyMods, server: []byte{0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"BodyReadFrom skip working", withProtC(OptSkip), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s3: func(s *ClientSession) ([]ModifyAction, *Action, error) { + return s.BodyReadFrom(bytes.NewReader(make([]byte, 3*DataSize64K))) + }, v3: expectAcceptEmptyMods, server: []byte{0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActSkip), 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"BodyReadFrom accept mid-stream", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s3: func(s *ClientSession) ([]ModifyAction, *Action, error) { + return s.BodyReadFrom(bytes.NewReader(make([]byte, 3*DataSize64K))) + }, v3: expectAcceptEmptyMods, server: []byte{0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"BodyReadFrom premature", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s3: func(s *ClientSession) ([]ModifyAction, *Action, error) { + return s.BodyReadFrom(bytes.NewReader(make([]byte, 3*DataSize64K))) + }, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 1, byte(wire.ActContinue), 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActAddRcpt working", withActC(withProtC(0), OptAddRcpt), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, _ *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionAddRcpt, Rcpt: "<>"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 4, byte(wire.ActAddRcpt), '<', '>', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"End with ActProgress working", withActC(withProtC(0), OptAddRcpt), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, _ *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionAddRcpt, Rcpt: "<>"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 1, byte(wire.ActProgress), 0, 0, 0, 1, byte(wire.ActProgress), 0, 0, 0, 4, byte(wire.ActAddRcpt), '<', '>', 0, 0, 0, 0, 1, byte(wire.ActProgress), 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"End premature", dC, ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActAddRcpt error detection 1", withActC(withProtC(0), OptAddRcpt), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 5, byte(wire.ActAddRcpt), '<', '>', 0, 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActAddRcpt error detection 2", withActC(withProtC(0), OptAddRcpt), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 1, byte(wire.ActAddRcpt), 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"OptAddRcptWithArgs working", withActC(withProtC(0), OptAddRcptWithArgs), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, _ *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionAddRcpt, Rcpt: "<>", RcptArgs: "A"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 6, byte(wire.ActAddRcptPar), '<', '>', 0, 'A', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"OptAddRcptWithArgs error detection 1", withActC(withProtC(0), OptAddRcptWithArgs), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 8, byte(wire.ActAddRcpt), '<', '>', 0, 'A', 0, 'B', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"OptAddRcptWithArgs error detection 2", withActC(withProtC(0), OptAddRcptWithArgs), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 1, byte(wire.ActAddRcpt), 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActDelRcpt working", withActC(withProtC(0), OptRemoveRcpt), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, _ *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionDelRcpt, Rcpt: "<>"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 4, byte(wire.ActDelRcpt), '<', '>', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActQuarantine working", withActC(withProtC(0), OptQuarantine), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, _ *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionQuarantine, Reason: "test"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 6, byte(wire.ActQuarantine), 't', 'e', 's', 't', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActReplBody working", withActC(withProtC(0), OptChangeBody), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, _ *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionReplaceBody, Body: []byte("test")}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 5, byte(wire.ActReplBody), 't', 'e', 's', 't', 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActReplBody accept up to max size", with256KbC(withActC(withProtC(0), OptChangeBody)), ops{ + {s1: sendConnect, v1: func(t *testing.T, s *ClientSession, act *Action, err error) { + expectContinue(t, s, act, err) + if s.negotiatedBodySize != uint32(DataSize256K) { + t.Fatalf("buffer size: expected: %d, got %d", DataSize256K, s.negotiatedBodySize) + } + }, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + // expectErr1(t, s, act, err) + }, server: func() []byte { + data := make([]byte, DataSize256K) + r := []byte{0, 0, 0, 0, byte(wire.ActReplBody)} + binary.BigEndian.PutUint32(r, uint32(1+len(data))) + r = append(r, data...) + r = append(r, 0, 0, 0, 1, byte(wire.ActAccept)) + return r + }()}, + }}, + {"ActReplBody enforce max size", withActC(withProtC(0), OptChangeBody), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: func() []byte { + data := make([]byte, DataSize64K+1) + r := []byte{0, 0, 0, 0, byte(wire.ActReplBody)} + binary.BigEndian.PutUint32(r, uint32(1+len(data))) + r = append(r[:], data...) + return r + }()}, + }}, + {"ActChangeFrom working 1", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, _ *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionChangeFrom, From: "<>"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 4, byte(wire.ActChangeFrom), '<', '>', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActChangeFrom working 2", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, _ *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionChangeFrom, From: "<>", FromArgs: "A"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 6, byte(wire.ActChangeFrom), '<', '>', 0, 'A', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActChangeFrom error detection 1", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 3, byte(wire.ActChangeFrom), '<', '>', 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActChangeFrom error detection 2", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 8, byte(wire.ActChangeFrom), '<', '>', 0, 'A', 0, 'B', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActChangeHeader working", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, _ *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionChangeHeader, HeaderIndex: 3, HeaderName: "A", HeaderValue: "B"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 9, byte(wire.ActChangeHeader), 0, 0, 0, 3, 'A', 0, 'B', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActChangeHeader error detection 1", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 4, byte(wire.ActChangeHeader), 0, 0, 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActChangeHeader error detection 2", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 5, byte(wire.ActChangeHeader), 0, 0, 0, 3, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActChangeHeader error detection 3", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 7, byte(wire.ActChangeHeader), 0, 0, 0, 3, 'A', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActInsertHeader working", withActC(withProtC(0), OptAddHeader), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, _ *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionInsertHeader, HeaderIndex: 3, HeaderName: "A", HeaderValue: "B"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 9, byte(wire.ActInsertHeader), 0, 0, 0, 3, 'A', 0, 'B', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"ActInsertHeader working", withActC(withProtC(0), OptAddHeader), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, _ *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionAddHeader, HeaderName: "A", HeaderValue: "B"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 5, byte(wire.ActAddHeader), 'A', 0, 'B', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"End Unknown msg code", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 1, 'O', 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"Two messages in a row", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionChangeHeader, HeaderIndex: 3, HeaderName: "A", HeaderValue: "B"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 9, byte(wire.ActChangeHeader), 0, 0, 0, 3, 'A', 0, 'B', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionChangeHeader, HeaderIndex: 3, HeaderName: "A", HeaderValue: "B"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 9, byte(wire.ActChangeHeader), 0, 0, 0, 3, 'A', 0, 'B', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"Two connections in a row", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionChangeHeader, HeaderIndex: 3, HeaderName: "A", HeaderValue: "B"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 9, byte(wire.ActChangeHeader), 0, 0, 0, 3, 'A', 0, 'B', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + {s2: func(s *ClientSession) error { + return s.Reset(nil) + }, v2: func(t *testing.T, s *ClientSession, err error) { + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + }, server: nil}, + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionChangeHeader, HeaderIndex: 3, HeaderName: "A", HeaderValue: "B"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 9, byte(wire.ActChangeHeader), 0, 0, 0, 3, 'A', 0, 'B', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"no Reset after error", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 1, 'O', 0, 0, 0, 1, byte(wire.ActAccept)}}, + {s2: func(s *ClientSession) error { + return s.Reset(nil) + }, v2: func(t *testing.T, s *ClientSession, err error) { + if err == nil { + t.Fatalf("expected error") + } + }, server: nil}, + }}, + {"Abort after Rcpt", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s2: func(s *ClientSession) error { + return s.Abort(nil) + }, v2: func(t *testing.T, s *ClientSession, err error) { + if err != nil { + t.Fatalf("expected no error, got %v", err) + } + }, server: nil}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderField, v1: expectContinue, server: responseContinue}, + {s1: sendHeaderEnd, v1: expectContinue, server: responseContinue}, + {s1: sendBodyChunk, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectAct(ActionAccept, t, act, err) + exp := []ModifyAction{{Type: ActionChangeHeader, HeaderIndex: 3, HeaderName: "A", HeaderValue: "B"}} + if !reflect.DeepEqual(exp, mActs) { + t.Fatalf("modifications: expect %+v, got %+v", exp, mActs) + } + }, server: []byte{0, 0, 0, 9, byte(wire.ActChangeHeader), 0, 0, 0, 3, 'A', 0, 'B', 0, 0, 0, 0, 1, byte(wire.ActAccept)}}, + }}, + {"no Abort after error", withActC(withProtC(0), OptChangeFrom), ops{ + {s1: sendConnect, v1: expectContinue, server: responseContinue}, + {s1: sendHelo, v1: expectContinue, server: responseContinue}, + {s1: sendMail, v1: expectContinue, server: responseContinue}, + {s1: sendRcpt, v1: expectContinue, server: responseContinue}, + {s1: sendData, v1: expectContinue, server: responseContinue}, + {s3: sendEnd, v3: func(t *testing.T, s *ClientSession, mActs []ModifyAction, act *Action, err error) { + expectErr1(t, s, act, err) + }, server: []byte{0, 0, 0, 1, 'O', 0, 0, 0, 1, byte(wire.ActAccept)}}, + {s2: func(s *ClientSession) error { + return s.Abort(nil) + }, v2: func(t *testing.T, s *ClientSession, err error) { + if err == nil { + t.Fatalf("expected error") + } + }, server: nil}, + }}, } - if len(macros["tls_version"]) != 0 { - t.Fatal("Unexpected macro data:", macros) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + clientConn, serverConn := net.Pipe() + defer clientConn.Close() + defer serverConn.Close() + // just slurp up everything the client sends + go func() { + buf := make([]byte, 1024) + for { + _ = serverConn.SetReadDeadline(time.Now().Add(time.Minute)) + if _, err := serverConn.Read(buf); err != nil { + if err != io.EOF && err != io.ErrClosedPipe { + t.Logf("server got error: read: %v", err) + } + return + } + } + }() + // send pre-defined answers to the client + go func() { + if _, err := serverConn.Write(ltt.cfg.ServerNegotiation); err != nil { + t.Logf("server got error: write: %v", err) + return + } + for _, op := range ltt.ops { + if op.server == nil { + continue + } + if _, err := serverConn.Write(op.server); err != nil { + if err != io.ErrClosedPipe { + t.Logf("server got error: write: %v", err) + } + return + } + } + }() + cl := NewClient(clientConn.LocalAddr().Network(), clientConn.LocalAddr().String(), ltt.cfg.Opts...) + session, err := cl.session(clientConn, ltt.cfg.Macros) + if err != nil { + t.Fatal(err) + } + defer session.Close() + for i, op := range ltt.ops { + t.Logf("%q op %d", ltt.name, i) + if op.s1 != nil { + act, err := op.s1(session) + op.v1(t, session, act, err) + } else if op.s2 != nil { + op.v2(t, session, op.s2(session)) + } else if op.s3 != nil { + mActs, act, err := op.s3(session) + op.v3(t, session, mActs, act, err) + } else { + panic("one of s1, s2 or s3 must be set") + } + } + }) } } diff --git a/cmd/log-milter/main.go b/cmd/log-milter/main.go new file mode 100644 index 0000000..538cc44 --- /dev/null +++ b/cmd/log-milter/main.go @@ -0,0 +1,83 @@ +// Command log-milter is a no-op milter that logs all milter communication +package main + +import ( + "flag" + "log" + "math/rand" + "net" + "os" + "sync" + + "github.com/d--j/go-milter" +) + +//goland:noinspection SpellCheckingInspection +var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ") + +func randSeq(n int) string { + b := make([]rune, n) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} + +func main() { + transport := flag.String("transport", "tcp", "Transport to use for milter connection, One of 'tcp', 'unix', 'tcp4' or 'tcp6'") + address := flag.String("address", "127.0.0.1:0", "Transport address, path for 'unix', address:port for 'tcp'") + + flag.Parse() + + // make sure socket does not exist + if *transport == "unix" { + // ignore os.Remove errors + _ = os.Remove(*address) + } + // bind to listening address + socket, err := net.Listen(*transport, *address) + if err != nil { + log.Fatal(err) + } + defer func(socket net.Listener) { + _ = socket.Close() + }(socket) + + if *transport == "unix" { + // set mode 0660 for unix domain sockets + if err := os.Chmod(*address, 0660); err != nil { + log.Fatal(err) + } + // remove socket on exit + defer func(name string) { + _ = os.Remove(name) + }(*address) + } + + server := milter.NewServer( + milter.WithMilter(func() milter.Milter { + return &LogMilter{logPrefix: randSeq(10)} + }), + milter.WithNegotiationCallback(func(mtaVersion, milterVersion uint32, mtaActions, milterActions milter.OptAction, mtaProtocol, milterProtocol milter.OptProtocol, offeredDataSize milter.DataSize) (version uint32, actions milter.OptAction, protocol milter.OptProtocol, maxDataSize milter.DataSize, err error) { + log.Printf("ACCEPT milter version %d, actions %032b, protocol %032b, data size %d", mtaVersion, mtaActions, mtaProtocol, offeredDataSize) + return mtaVersion, mtaActions, 0, offeredDataSize, nil + }), + ) + + defer func(server *milter.Server) { + _ = server.Close() + }(server) + var wgDone sync.WaitGroup + wgDone.Add(1) + go func(socket net.Listener) { + if err := server.Serve(socket); err != nil { + log.Fatal(err) + } + wgDone.Done() + }(socket) + + log.Printf("Started milter on %s:%s", socket.Addr().Network(), socket.Addr().String()) + + // quit when milter quits + wgDone.Wait() +} diff --git a/cmd/log-milter/milter.go b/cmd/log-milter/milter.go new file mode 100644 index 0000000..011dc07 --- /dev/null +++ b/cmd/log-milter/milter.go @@ -0,0 +1,136 @@ +package main + +import ( + "fmt" + "log" + + "github.com/d--j/go-milter" +) + +type LogMilter struct { + logPrefix string + macroValues map[milter.MacroName]string +} + +func (l *LogMilter) log(format string, v ...interface{}) { + log.Printf(fmt.Sprintf("[%s] %s", l.logPrefix, format), v...) +} + +func (l *LogMilter) Connect(host string, family string, port uint16, addr string, m *milter.Modifier) (*milter.Response, error) { + l.log("CONNECT host = %q, family = %q, port = %d, addr = %q", host, family, port, addr) + l.outputChangedMacros(m) + return milter.RespContinue, nil +} + +func (l *LogMilter) Helo(name string, m *milter.Modifier) (*milter.Response, error) { + l.log("HELO %q", name) + l.outputChangedMacros(m) + return milter.RespContinue, nil +} + +func (l *LogMilter) MailFrom(from string, esmtpArgs string, m *milter.Modifier) (*milter.Response, error) { + l.log("MAIL FROM <%s> %s", from, esmtpArgs) + l.outputChangedMacros(m) + return milter.RespContinue, nil +} + +func (l *LogMilter) RcptTo(rcptTo string, esmtpArgs string, m *milter.Modifier) (*milter.Response, error) { + l.log("RCPT TO <%s> %s", rcptTo, esmtpArgs) + l.outputChangedMacros(m) + return milter.RespContinue, nil +} + +func (l *LogMilter) Data(m *milter.Modifier) (*milter.Response, error) { + l.log("DATA") + l.outputChangedMacros(m) + return milter.RespContinue, nil +} + +func (l *LogMilter) Header(name string, value string, m *milter.Modifier) (*milter.Response, error) { + l.log("HEADER %s: %q", name, value) + l.outputChangedMacros(m) + return milter.RespContinue, nil +} + +func (l *LogMilter) Headers(m *milter.Modifier) (*milter.Response, error) { + l.log("EOH") + l.outputChangedMacros(m) + return milter.RespContinue, nil +} + +func (l *LogMilter) BodyChunk(chunk []byte, m *milter.Modifier) (*milter.Response, error) { + l.log("BODY CHUNK size = %d", len(chunk)) + l.outputChangedMacros(m) + return milter.RespContinue, nil +} + +func (l *LogMilter) EndOfMessage(m *milter.Modifier) (*milter.Response, error) { + l.log("EOM") + l.outputChangedMacros(m) + return milter.RespAccept, nil +} + +func (l *LogMilter) Abort(m *milter.Modifier) error { + l.log("ABORT") + l.outputChangedMacros(m) + return nil +} + +func (l *LogMilter) Unknown(cmd string, m *milter.Modifier) (*milter.Response, error) { + l.log("UNKNOWN %q", cmd) + l.outputChangedMacros(m) + return milter.RespContinue, nil +} + +func (l *LogMilter) Cleanup() { + l.log("cleanup") + l.macroValues = nil +} + +func (l *LogMilter) outputChangedMacros(m *milter.Modifier) { + if l.macroValues == nil { + l.macroValues = make(map[milter.MacroName]string) + } + for _, name := range []milter.MacroName{ + milter.MacroMTAFullyQualifiedDomainName, + milter.MacroDaemonName, + milter.MacroIfName, + milter.MacroIfAddr, + milter.MacroTlsVersion, + milter.MacroCipher, + milter.MacroCipherBits, + milter.MacroCertSubject, + milter.MacroCertIssuer, + milter.MacroQueueId, + milter.MacroAuthType, + milter.MacroAuthAuthen, + milter.MacroAuthSsf, + milter.MacroAuthAuthor, + milter.MacroMailMailer, + milter.MacroMailHost, + milter.MacroMailAddr, + milter.MacroRcptMailer, + milter.MacroRcptHost, + milter.MacroRcptAddr, + milter.MacroRFC1413AuthInfo, + milter.MacroHopCount, + milter.MacroSenderHostName, + milter.MacroProtocolUsed, + milter.MacroMTAPid, + milter.MacroDateRFC822Origin, + milter.MacroDateRFC822Current, + milter.MacroDateANSICCurrent, + milter.MacroDateSecondsCurrent, + } { + oldValue := l.macroValues[name] + newValue := m.Macros.Get(name) + if oldValue != newValue { + if oldValue != "" { + l.log(" macro %s value %q -> %q", name, oldValue, newValue) + } else { + l.log(" macro %s value %q", name, newValue) + } + } + l.macroValues[name] = newValue + } +} diff --git a/cmd/milter-check/main.go b/cmd/milter-check/main.go index a1ba245..cc9eb9b 100644 --- a/cmd/milter-check/main.go +++ b/cmd/milter-check/main.go @@ -1,3 +1,4 @@ +// Command milter-check can be used to send test data to milters. package main import ( @@ -6,46 +7,49 @@ import ( "log" "os" "strings" - "time" + "github.com/d--j/go-milter" + "github.com/d--j/go-milter/milterutil" "github.com/emersion/go-message/textproto" - "github.com/emersion/go-milter" + "golang.org/x/text/transform" ) func printAction(prefix string, act *milter.Action) { - switch act.Code { - case milter.ActAccept: + switch act.Type { + case milter.ActionAccept: log.Println(prefix, "accept") - case milter.ActReject: + case milter.ActionReject: log.Println(prefix, "reject") - case milter.ActDiscard: + case milter.ActionDiscard: log.Println(prefix, "discard") - case milter.ActTempFail: + case milter.ActionTempFail: log.Println(prefix, "temp. fail") - case milter.ActReplyCode: - log.Println(prefix, "reply code:", act.SMTPCode, act.SMTPText) - case milter.ActContinue: + case milter.ActionRejectWithCode: + log.Println(prefix, "reply code:", act.SMTPCode, act.SMTPReply) + case milter.ActionContinue: log.Println(prefix, "continue") + case milter.ActionSkip: + log.Println(prefix, "skip") } } func printModifyAction(act milter.ModifyAction) { - switch act.Code { - case milter.ActAddHeader: + switch act.Type { + case milter.ActionAddHeader: log.Printf("add header: name %s, value %s", act.HeaderName, act.HeaderValue) - case milter.ActInsertHeader: + case milter.ActionInsertHeader: log.Printf("insert header: at %d, name %s, value %s", act.HeaderIndex, act.HeaderName, act.HeaderValue) - case milter.ActChangeFrom: + case milter.ActionChangeFrom: log.Printf("change from: %s %v", act.From, act.FromArgs) - case milter.ActChangeHeader: + case milter.ActionChangeHeader: log.Printf("change header: at %d, name %s, value %s", act.HeaderIndex, act.HeaderName, act.HeaderValue) - case milter.ActReplBody: + case milter.ActionReplaceBody: log.Println("replace body:", string(act.Body)) - case milter.ActAddRcpt: + case milter.ActionAddRcpt: log.Println("add rcpt:", act.Rcpt) - case milter.ActDelRcpt: + case milter.ActionDelRcpt: log.Println("del rcpt:", act.Rcpt) - case milter.ActQuarantine: + case milter.ActionQuarantine: log.Println("quarantine:", act.Reason) } } @@ -61,26 +65,21 @@ func main() { mailFrom := flag.String("from", "foxcpp@example.org", "Value to send in MAIL message") rcptTo := flag.String("rcpt", "foxcpp@example.com", "Comma-separated list of values for RCPT messages") actionMask := flag.Uint("actions", - uint(milter.OptChangeBody|milter.OptChangeFrom|milter.OptChangeHeader| - milter.OptAddHeader|milter.OptAddRcpt|milter.OptChangeFrom), + uint(milter.AllClientSupportedActionMasks), "Bitmask value of actions we allow") disabledMsgs := flag.Uint("disabled-msgs", 0, "Bitmask of disabled protocol messages") flag.Parse() - c := milter.NewClientWithOptions(*transport, *address, milter.ClientOptions{ - ActionMask: milter.OptAction(*actionMask), - ProtocolMask: milter.OptProtocol(*disabledMsgs), - ReadTimeout: 10 * time.Second, - WriteTimeout: 10 * time.Second, - }) - defer c.Close() + c := milter.NewClient(*transport, *address, milter.WithActions(milter.OptAction(*actionMask)), milter.WithProtocols(milter.OptProtocol(*disabledMsgs))) - s, err := c.Session() + s, err := c.Session(nil) if err != nil { log.Println(err) return } - defer s.Close() + defer func(s *milter.ClientSession) { + _ = s.Close() + }(s) act, err := s.Conn(*hostname, milter.ProtoFamily((*family)[0]), uint16(*port), *connAddr) if err != nil { @@ -88,7 +87,7 @@ func main() { return } printAction("CONNECT:", act) - if act.Code != milter.ActContinue { + if act.StopProcessing() { return } @@ -98,33 +97,43 @@ func main() { return } printAction("HELO:", act) - if act.Code != milter.ActContinue { + if act.StopProcessing() { return } - act, err = s.Mail(*mailFrom, nil) + act, err = s.Mail(*mailFrom, "") if err != nil { log.Println(err) return } printAction("MAIL:", act) - if act.Code != milter.ActContinue { + if act.StopProcessing() { return } for _, rcpt := range strings.Split(*rcptTo, ",") { - act, err = s.Rcpt(rcpt, nil) + act, err = s.Rcpt(rcpt, "") if err != nil { log.Println(err) return } printAction("RCPT:", act) - if act.Code != milter.ActContinue { + if act.StopProcessing() { return } } - bufR := bufio.NewReader(os.Stdin) + act, err = s.DataStart() + if err != nil { + log.Println(err) + return + } + printAction("DATA:", act) + if act.StopProcessing() { + return + } + + bufR := bufio.NewReader(transform.NewReader(os.Stdin, &milterutil.CrLfCanonicalizationTransformer{})) hdr, err := textproto.ReadHeader(bufR) if err != nil { log.Println("header parse:", err) @@ -137,7 +146,7 @@ func main() { return } printAction("HEADER:", act) - if act.Code != milter.ActContinue { + if act.StopProcessing() { return } diff --git a/cstrings.go b/cstrings.go deleted file mode 100644 index 5ebcd6a..0000000 --- a/cstrings.go +++ /dev/null @@ -1,33 +0,0 @@ -package milter - -import ( - "bytes" - "strings" -) - -// NULL terminator -const null = "\x00" - -// DecodeCStrings splits a C style strings into a Go slice -func decodeCStrings(data []byte) []string { - if len(data) == 0 { - return nil - } - return strings.Split(strings.Trim(string(data), null), null) -} - -// ReadCString reads and returns a C style string from []byte -func readCString(data []byte) string { - pos := bytes.IndexByte(data, 0) - if pos == -1 { - return string(data) - } - return string(data[0:pos]) -} - -// appendCString appends a C style string to the buffer and returns it (like append does). -func appendCString(dest []byte, s string) []byte { - dest = append(dest, []byte(s)...) - dest = append(dest, 0x00) - return dest -} diff --git a/example_client_test.go b/example_client_test.go new file mode 100644 index 0000000..ba1b00d --- /dev/null +++ b/example_client_test.go @@ -0,0 +1,89 @@ +package milter_test + +import ( + "log" + "strings" + "time" + + "github.com/d--j/go-milter" +) + +func ExampleClient() { + // create milter definition once + client := milter.NewClient("tcp", "127.0.0.1:1234") + globalMacros := milter.NewMacroBag() + globalMacros.Set(milter.MacroMTAFullyQualifiedDomainName, "localhost.local") + globalMacros.Set(milter.MacroMTAPid, "123") + + // on each SMTP connection + macros := globalMacros.Copy() + session, err := client.Session(macros) + if err != nil { + panic(err) + } + defer session.Close() + + handleMilterResponse := func(act *milter.Action, err error) { + if err != nil { + // you should disable this milter for this connection or close the SMTP transaction + panic(err) + } + if act.StopProcessing() { + // abort SMTP transaction, you can use act.SMTPReply to send to the SMTP client + panic(act.SMTPReply) + } + if act.Type == milter.ActionDiscard { + // close the milter connection (do not send more SMTP events of this SMTP transaction) + // but keep SMTP connection open and after DATA, silently discard the message + panic(session.Close()) + } + } + + // for each received SMTP command set relevant macros and send it to the milter + macros.Set(milter.MacroIfAddr, "127.0.0.1") + macros.Set(milter.MacroIfName, "eth0") + handleMilterResponse(session.Conn("spammer.example.com", milter.FamilyInet, 0, "127.0.0.15")) + + macros.Set(milter.MacroSenderHostName, "spammer.example.com") + macros.Set(milter.MacroTlsVersion, "SSLv3") + handleMilterResponse(session.Helo("spammer.example.com")) + + macros.Set(milter.MacroMailMailer, "esmtp") + macros.Set(milter.MacroMailHost, "example.com") + macros.Set(milter.MacroMailAddr, "spammer@example.com") + handleMilterResponse(session.Mail("", "")) + + macros.Set(milter.MacroRcptMailer, "local") + macros.Set(milter.MacroRcptHost, "example.com") + macros.Set(milter.MacroRcptAddr, "other-spammer@example.com") + handleMilterResponse(session.Rcpt("", "")) + + macros.Set(milter.MacroRcptMailer, "local") + macros.Set(milter.MacroRcptHost, "example.com") + macros.Set(milter.MacroRcptAddr, "other-spammer2@example.com") + handleMilterResponse(session.Rcpt("", "")) + + // After DataStart you should send the initial SMTP data to the first milter, accept and apply its modifications + // and then send this modified data to the next milter. Before this point all milters could be queried in parallel. + handleMilterResponse(session.DataStart()) + + handleMilterResponse(session.HeaderField("From", "Your Bank ", nil)) + handleMilterResponse(session.HeaderField("To", "Your ", nil)) + handleMilterResponse(session.HeaderField("Subject", "Your money", nil)) + macros.SetHeaderDate(time.Date(2023, time.January, 1, 1, 1, 1, 0, time.UTC)) + handleMilterResponse(session.HeaderField("Date", "Sun, 1 Jan 2023 00:00:00 +0000", nil)) + + handleMilterResponse(session.HeaderEnd()) + + mActs, act, err := session.BodyReadFrom(strings.NewReader("Hello You,\r\ndo you want money?\r\nYour bank\r\n")) + if err != nil { + panic(err) + } + if act.StopProcessing() { + panic(act.SMTPReply) + } + for _, mAct := range mActs { + // process mAct + log.Print(mAct) + } +} diff --git a/example_server_test.go b/example_server_test.go new file mode 100644 index 0000000..c455929 --- /dev/null +++ b/example_server_test.go @@ -0,0 +1,56 @@ +package milter_test + +import ( + "log" + "net" + "sync" + + "github.com/d--j/go-milter" +) + +type ExampleBackend struct { + milter.NoOpMilter +} + +func (b *ExampleBackend) RcptTo(rcptTo string, esmtpArgs string, m *milter.Modifier) (*milter.Response, error) { + // reject the mail when it goes to other-spammer@example.com and is a local delivery + if rcptTo == "other-spammer@example.com" && m.Macros.Get(milter.MacroRcptMailer) == "local" { + return milter.RejectWithCodeAndReason(550, "We do not like you\r\nvery much, please go away") + } + return milter.RespContinue, nil +} + +func ExampleServer() { + // create socket to listen on + socket, err := net.Listen("tcp4", "127.0.0.1:6785") + if err != nil { + log.Fatal(err) + } + defer socket.Close() + + // define the backend, required actions, protocol options and macros we want + server := milter.NewServer( + milter.WithMilter(func() milter.Milter { + return &ExampleBackend{} + }), + milter.WithProtocol(milter.OptNoConnect|milter.OptNoHelo|milter.OptNoMailFrom|milter.OptNoBody|milter.OptNoHeaders|milter.OptNoEOH|milter.OptNoUnknown|milter.OptNoData), + milter.WithAction(milter.OptChangeFrom|milter.OptAddRcpt|milter.OptRemoveRcpt), + milter.WithMacroRequest(milter.StageRcpt, []milter.MacroName{milter.MacroRcptMailer}), + ) + defer server.Close() + + // start the milter + var wgDone sync.WaitGroup + wgDone.Add(1) + go func(socket net.Listener) { + if err := server.Serve(socket); err != nil { + log.Fatal(err) + } + wgDone.Done() + }(socket) + + log.Printf("Started milter on %s:%s", socket.Addr().Network(), socket.Addr().String()) + + // quit when milter quits + wgDone.Wait() +} diff --git a/go.mod b/go.mod index d9942d2..9af8466 100644 --- a/go.mod +++ b/go.mod @@ -1,5 +1,8 @@ -module github.com/emersion/go-milter +module github.com/d--j/go-milter -go 1.12 +go 1.18 -require github.com/emersion/go-message v0.11.2 +require ( + github.com/emersion/go-message v0.11.2 + golang.org/x/text v0.3.2 +) diff --git a/go.sum b/go.sum index d0cb08b..7e7e386 100644 --- a/go.sum +++ b/go.sum @@ -7,5 +7,6 @@ github.com/martinlindhe/base36 v1.0.0/go.mod h1:+AtEs8xrBpCeYgSLoY/aJ6Wf37jtBuR0 github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +golang.org/x/text v0.3.2 h1:tW2bmiBqwgJj/UpqtC8EpXEZVYOwU0yG4iWbprSVAcs= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/internal/wire/cstrings.go b/internal/wire/cstrings.go new file mode 100644 index 0000000..735530f --- /dev/null +++ b/internal/wire/cstrings.go @@ -0,0 +1,40 @@ +package wire + +import ( + "bytes" + "strings" +) + +// NULL terminator +const null = "\x00" + +// DecodeCStrings splits a C style strings into a Go string slice +// The last C style string in data can optionally not be terminated with a null-byte. +func DecodeCStrings(data []byte) []string { + if len(data) == 0 { + return nil + } + // strip the last null byte + if data[len(data)-1] == 0 { + data = data[0 : len(data)-1] + } + return strings.Split(string(data), null) +} + +// ReadCString reads and returns a C style string from []byte. +// If data does not contain a null-byte the whole data-slice is returned as string +func ReadCString(data []byte) string { + pos := bytes.IndexByte(data, 0) + if pos == -1 { + return string(data) + } + return string(data[0:pos]) +} + +// AppendCString appends a C style string to the buffer and returns it (like append does). +// It is assumed that s does not contain null-bytes. +func AppendCString(dest []byte, s string) []byte { + dest = append(dest, []byte(s)...) + dest = append(dest, 0x00) + return dest +} diff --git a/internal/wire/cstrings_test.go b/internal/wire/cstrings_test.go new file mode 100644 index 0000000..6a5bf86 --- /dev/null +++ b/internal/wire/cstrings_test.go @@ -0,0 +1,80 @@ +package wire + +import ( + "reflect" + "testing" +) + +func TestDecodeCStrings(t *testing.T) { + tests := []struct { + name string + data []byte + want []string + }{ + {"single string", []byte("one\u0000"), []string{"one"}}, + {"two strings", []byte("one\u0000two\u0000"), []string{"one", "two"}}, + {"last empty", []byte("one\u0000\u0000"), []string{"one", ""}}, + {"first empty", []byte("\u0000two\u0000"), []string{"", "two"}}, + {"all empty", []byte("\u0000\u0000"), []string{"", ""}}, + {"nil in nil out", nil, nil}, + {"empty ok", []byte{}, nil}, + {"missing last null", []byte("one"), []string{"one"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + if got := DecodeCStrings(ltt.data); !reflect.DeepEqual(got, ltt.want) { + t.Errorf("DecodeCStrings() = %v, want %v", got, ltt.want) + } + }) + } +} + +func TestAppendCString(t *testing.T) { + type args struct { + dest []byte + s string + } + tests := []struct { + name string + args args + want []byte + }{ + {"append to nil", args{nil, "append"}, []byte("append\u0000")}, + {"append to empty", args{[]byte{}, "append"}, []byte("append\u0000")}, + {"append", args{[]byte("one\u0000"), "append"}, []byte("one\u0000append\u0000")}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + if got := AppendCString(ltt.args.dest, ltt.args.s); !reflect.DeepEqual(got, ltt.want) { + t.Errorf("AppendCString() = %v, want %v", got, ltt.want) + } + }) + } +} + +func TestReadCString(t *testing.T) { + tests := []struct { + name string + data []byte + want string + }{ + {"simple", []byte("simple\u0000"), "simple"}, + {"trailing", []byte("simple\u0000other data"), "simple"}, + {"no null", []byte("simple"), "simple"}, + {"empty", []byte("\u0000"), ""}, + {"nil", nil, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + if got := ReadCString(ltt.data); got != ltt.want { + t.Errorf("ReadCString() = %v, want %v", got, ltt.want) + } + }) + } +} diff --git a/milter-protocol-extras.txt b/internal/wire/milter-protocol-extras.txt similarity index 100% rename from milter-protocol-extras.txt rename to internal/wire/milter-protocol-extras.txt diff --git a/milter-protocol.txt b/internal/wire/milter-protocol.txt similarity index 100% rename from milter-protocol.txt rename to internal/wire/milter-protocol.txt diff --git a/internal/wire/wire.go b/internal/wire/wire.go new file mode 100644 index 0000000..77e2c73 --- /dev/null +++ b/internal/wire/wire.go @@ -0,0 +1,136 @@ +// Package wire includes constants and functions for the raw libmilter protocol +package wire + +import ( + "encoding/binary" + "errors" + "fmt" + "io" + "net" + "time" +) + +type Code byte + +// Message represents a command sent from milter client +type Message struct { + Code Code + Data []byte +} + +type ActionCode byte + +const ( + ActAccept ActionCode = 'a' // SMFIR_ACCEPT + ActContinue ActionCode = 'c' // SMFIR_CONTINUE + ActDiscard ActionCode = 'd' // SMFIR_DISCARD + ActReject ActionCode = 'r' // SMFIR_REJECT + ActTempFail ActionCode = 't' // SMFIR_TEMPFAIL + ActReplyCode ActionCode = 'y' // SMFIR_REPLYCODE + ActSkip ActionCode = 's' // SMFIR_SKIP [v6] + ActProgress ActionCode = 'p' // SMFIR_PROGRESS [v6] +) + +type ModifyActCode byte + +const ( + ActAddRcpt ModifyActCode = '+' // SMFIR_ADDRCPT + ActDelRcpt ModifyActCode = '-' // SMFIR_DELRCPT + ActReplBody ModifyActCode = 'b' // SMFIR_ACCEPT + ActAddHeader ModifyActCode = 'h' // SMFIR_ADDHEADER + ActChangeHeader ModifyActCode = 'm' // SMFIR_CHGHEADER + ActInsertHeader ModifyActCode = 'i' // SMFIR_INSHEADER + ActQuarantine ModifyActCode = 'q' // SMFIR_QUARANTINE + ActChangeFrom ModifyActCode = 'e' // SMFIR_CHGFROM [v6] + ActAddRcptPar ModifyActCode = '2' // SMFIR_ADDRCPT_PAR [v6] +) + +const ( + CodeOptNeg Code = 'O' // SMFIC_OPTNEG + CodeMacro Code = 'D' // SMFIC_MACRO + CodeConn Code = 'C' // SMFIC_CONNECT + CodeQuit Code = 'Q' // SMFIC_QUIT + CodeHelo Code = 'H' // SMFIC_HELO + CodeMail Code = 'M' // SMFIC_MAIL + CodeRcpt Code = 'R' // SMFIC_RCPT + CodeHeader Code = 'L' // SMFIC_HEADER + CodeEOH Code = 'N' // SMFIC_EOH + CodeBody Code = 'B' // SMFIC_BODY + CodeEOB Code = 'E' // SMFIC_BODYEOB + CodeAbort Code = 'A' // SMFIC_ABORT + CodeData Code = 'T' // SMFIC_DATA + CodeQuitNewConn Code = 'K' // SMFIC_QUIT_NC [v6] + CodeUnknown Code = 'U' // SMFIC_UNKNOWN [v6] +) + +// We reject reading/writing messages larger than 512 MB outright. +const maxPacketSize = 512 * 1024 * 1024 + +func ReadPacket(conn net.Conn, timeout time.Duration) (*Message, error) { + if timeout != 0 { + _ = conn.SetReadDeadline(time.Now().Add(timeout)) + defer func(conn net.Conn) { + _ = conn.SetReadDeadline(time.Time{}) + }(conn) + } + + // read packet length + var length uint32 + if err := binary.Read(conn, binary.BigEndian, &length); err != nil { + return nil, err + } + + if length > maxPacketSize { + return nil, fmt.Errorf("milter: reject to read %d bytes in one message", length) + } + + // read packet data + data := make([]byte, length) + if _, err := io.ReadFull(conn, data); err != nil { + return nil, err + } + + // prepare response data + message := Message{ + Code: Code(data[0]), + Data: data[1:], + } + + return &message, nil +} + +func WritePacket(conn net.Conn, msg *Message, timeout time.Duration) error { + if msg == nil { + return errors.New("msg nil pointer") + } + if timeout != 0 { + _ = conn.SetWriteDeadline(time.Now().Add(timeout)) + defer func(conn net.Conn) { + _ = conn.SetWriteDeadline(time.Time{}) + }(conn) + } + + length := len(msg.Data) + 1 + if length > maxPacketSize { + return fmt.Errorf("milter: cannot write %d bytes in one message", length) + } + + header := [5]byte{0, 0, 0, 0, byte(msg.Code)} + binary.BigEndian.PutUint32(header[0:], uint32(length)) + _, err := conn.Write(header[:]) + if err != nil { + return err + } + + if len(msg.Data) == 0 { + return nil + } + _, err = conn.Write(msg.Data) + + return err +} + +// AppendUint16 appends the big endian encoding of val to dest. It returns the new dest like append does. +func AppendUint16(dest []byte, val uint16) []byte { + return append(dest, byte(val>>8), byte(val)) +} diff --git a/internal/wire/wire_test.go b/internal/wire/wire_test.go new file mode 100644 index 0000000..8b298c6 --- /dev/null +++ b/internal/wire/wire_test.go @@ -0,0 +1,177 @@ +package wire + +import ( + "bytes" + "fmt" + "io" + "net" + "testing" + "time" +) + +func TestReadPacket(t *testing.T) { + type packet struct { + data []byte + sleep time.Duration + } + + type packets []packet + + type args struct { + data packets + timeout time.Duration + } + tests := []struct { + name string + args args + want *Message + wantErr bool + }{ + {"Error on bogus data", args{packets{{[]byte("bogus"), 0}}, time.Second}, nil, true}, + {"Simple", args{packets{{[]byte{0, 0, 0, 1}, 0}, {[]byte("b"), 0}}, time.Second}, &Message{Code: 'b'}, false}, + {"Timeout", args{packets{{[]byte{0, 0, 0, 1}, 2 * time.Second}, {[]byte("b"), 0}}, time.Second}, nil, true}, + {"Timeout2", args{packets{{[]byte{}, 2 * time.Second}, {[]byte{0, 0, 0, 1, 'b'}, 0}}, time.Second}, nil, true}, + {"With Data", args{packets{{[]byte{0, 0, 0, 4, 't', 'e', 's', 't'}, 0}}, time.Second}, &Message{Code: 't', Data: []byte{'e', 's', 't'}}, false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + serverChan := make(chan error) + // Acceptor + go func() { + c, err := ln.Accept() + if err != nil { + serverChan <- err + return + } + c.SetDeadline(time.Now().Add(time.Minute)) // Not intended to fire. + for m := 0; m < len(ltt.args.data); m++ { + if n, err := c.Write(ltt.args.data[m].data); err != nil || n != len(ltt.args.data[m].data) { + if err == nil { + err = fmt.Errorf("expected to write %d bytes but only wrote %d bytes", len(ltt.args.data[m].data), n) + } + serverChan <- err + return + } + if ltt.args.data[m].sleep > 0 { + time.Sleep(ltt.args.data[m].sleep) + } + } + serverChan <- nil + }() + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + got, err := ReadPacket(conn, ltt.args.timeout) + if (err != nil) != ltt.wantErr { + t.Errorf("ReadPacket() error = %v, wantErr %v", err, ltt.wantErr) + return + } + if (got == nil && ltt.want != nil) || ((got != nil && ltt.want != nil) && (got.Code != ltt.want.Code || !bytes.Equal(got.Data, ltt.want.Data))) { + t.Errorf("ReadPacket() got = %+v, want %+v", got, ltt.want) + } + if serverErr := <-serverChan; serverErr != nil { + t.Fatal(serverErr) + } + }) + } +} + +func TestWritePacket(t *testing.T) { + type writeOp struct { + msg *Message + onAfter func(ln net.Listener, conn net.Conn) + onBefore func(ln net.Listener, conn net.Conn) + } + type writeOps []writeOp + tests := []struct { + name string + writeOps writeOps + want []byte + wantErr bool + }{ + {"Single", writeOps{{msg: &Message{Code: 'a'}}}, []byte{0, 0, 0, 1, 'a'}, false}, + {"Single2", writeOps{{msg: &Message{Code: 'a', Data: []byte{'a', 0}}}}, []byte{0, 0, 0, 3, 'a', 'a', 0}, false}, + {"Too big", writeOps{{msg: &Message{Code: 'a', Data: make([]byte, 513*(1024*1024))}}}, nil, true}, + {"Nil msg", writeOps{{msg: nil}}, nil, true}, + {"Multiple", writeOps{{msg: &Message{Code: 'a'}}, {msg: &Message{Code: 'b'}}}, []byte{0, 0, 0, 1, 'a', 0, 0, 0, 1, 'b'}, false}, + {"Multiple close in middle", writeOps{{msg: &Message{Code: 'a'}, onAfter: func(ln net.Listener, conn net.Conn) { _ = conn.Close() }}, {msg: &Message{Code: 'b'}}}, []byte{0, 0, 0, 1, 'a'}, true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatal(err) + } + defer ln.Close() + type response struct { + data []byte + err error + } + serverChan := make(chan response) + // Acceptor + go func() { + c, err := ln.Accept() + if err != nil { + serverChan <- response{err: err} + return + } + c.SetDeadline(time.Now().Add(time.Minute)) // Not intended to fire. + data, err := io.ReadAll(c) + if err != nil { + serverChan <- response{err: err} + return + } + serverChan <- response{data: data} + }() + conn, err := net.Dial("tcp", ln.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer func(conn net.Conn) { + _ = conn.Close() + }(conn) + for _, op := range ltt.writeOps { + if op.onBefore != nil { + op.onBefore(ln, conn) + } + err = WritePacket(conn, op.msg, time.Minute) + if err != nil { + break + } + if op.onAfter != nil { + op.onAfter(ln, conn) + } + } + _ = conn.Close() + if (err != nil) != ltt.wantErr { + t.Errorf("WritePacket() error = %v, wantErr %v", err, ltt.wantErr) + resp := <-serverChan + if resp.err != nil { + t.Fatal(resp.err) + } + t.Errorf("read data %v", resp.data) + + return + } + + resp := <-serverChan + if resp.err != nil { + t.Fatal(resp.err) + } + if !bytes.Equal(resp.data, ltt.want) { + t.Errorf("read data mismatch got = %+v, want %+v", resp.data, ltt.want) + } + }) + } +} diff --git a/log.go b/log.go new file mode 100644 index 0000000..08b1873 --- /dev/null +++ b/log.go @@ -0,0 +1,18 @@ +package milter + +import ( + "fmt" + "log" +) + +func logWarning(format string, v ...interface{}) { + log.Printf(fmt.Sprintf("milter: warning: %s", format), v...) +} + +// LogWarning is called by this library when it wants to output a warning. +// Warnings can happen even when the library user did everything right (because the other end did something wrong +// but we recovered from it) +// +// The default implementation uses log.Print to output the warning. +// You can re-assign LogWarning to something more suitable for your application. But do not assign nil to it. +var LogWarning = logWarning diff --git a/macro.go b/macro.go new file mode 100644 index 0000000..23bd8e0 --- /dev/null +++ b/macro.go @@ -0,0 +1,330 @@ +package milter + +import ( + "fmt" + "strings" + "sync" + "time" + "unicode" +) + +type MacroStage = byte + +const ( + StageConnect MacroStage = iota // SMFIM_CONNECT + StageHelo // SMFIM_HELO + StageMail // SMFIM_ENVFROM + StageRcpt // SMFIM_ENVRCPT + StageData // SMFIM_DATA + StageEOM // SMFIM_EOM + StageEOH // SMFIM_EOH + StageEndMarker // is used for command level macros for Abort, Unknown and Header commands + StageNotFoundMarker // identifies that a macro was not found +) + +type MacroName = string + +// Macros that have good support between MTAs like sendmail and Postfix +const ( + MacroMTAFullyQualifiedDomainName MacroName = "j" + MacroDaemonName MacroName = "{daemon_name}" + MacroIfName MacroName = "{if_name}" + MacroIfAddr MacroName = "{if_addr}" + MacroTlsVersion MacroName = "{tls_version}" + MacroCipher MacroName = "{cipher}" + MacroCipherBits MacroName = "{cipher_bits}" + MacroCertSubject MacroName = "{cert_subject}" + MacroCertIssuer MacroName = "{cert_issuer}" + // The queue ID for this message. Some MTAs only assign a Queue ID after the DATA command (Postfix) + MacroQueueId MacroName = "i" + // The used authentication method (LOGIN, DIGEST-MD5, etc) + MacroAuthType MacroName = "{auth_type}" + // The username of the authenticated user + MacroAuthAuthen MacroName = "{auth_authen}" + // The key length (in bits) of the used encryption layer (TLS) – if any + MacroAuthSsf MacroName = "{auth_ssf}" + // The optional overwrite username for this message + MacroAuthAuthor MacroName = "{auth_author}" + // the delivery agent for this MAIL FROM (e.g. esmtp, lmtp) + MacroMailMailer MacroName = "{mail_mailer}" + // the domain part of the MAIL FROM address + MacroMailHost MacroName = "{mail_host}" + // the MAIL FROM address (only the address without <>) + MacroMailAddr MacroName = "{mail_addr}" + // MacroRcptMailer holds the delivery agent for the current RCPT TO address + MacroRcptMailer MacroName = "{rcpt_mailer}" + // The domain part of the RCPT TO address + MacroRcptHost MacroName = "{rcpt_host}" + // the RCPT TO address (only the address without <>) + MacroRcptAddr MacroName = "{rcpt_addr}" +) + +// Macros that do not have good cross-MTA support. Only usable with sendmail as MTA. +const ( + MacroRFC1413AuthInfo MacroName = "_" + MacroHopCount MacroName = "c" + MacroSenderHostName MacroName = "s" + MacroProtocolUsed MacroName = "r" + MacroMTAPid MacroName = "p" + MacroDateRFC822Origin MacroName = "a" + MacroDateRFC822Current MacroName = "b" + MacroDateANSICCurrent MacroName = "d" + MacroDateSecondsCurrent MacroName = "t" +) + +type macroRequests [][]MacroName + +type Macros interface { + Get(name MacroName) string + GetEx(name MacroName) (value string, ok bool) +} + +// MacroBag is a default implementation of the Macros interface. +// A MacroBag is safe for concurrent use by multiple goroutines. +// It has special handling for the date related macros and can be copied. +// +// The zero value of MacroBag is invalid. Use NewMacroBag to create an empty MacroBag. +type MacroBag struct { + macros map[MacroName]string + mutex sync.RWMutex + currentDate, headerDate time.Time +} + +func NewMacroBag() *MacroBag { + return &MacroBag{ + macros: make(map[MacroName]string), + } +} + +func (m *MacroBag) Get(name MacroName) string { + v, _ := m.GetEx(name) + return v +} + +func (m *MacroBag) GetEx(name MacroName) (value string, ok bool) { + m.mutex.RLock() + defer m.mutex.RUnlock() + value, ok = m.macros[name] + if !ok { + switch name { + case MacroDateRFC822Origin: + if !m.headerDate.IsZero() { + ok = true + value = m.headerDate.Format(time.RFC822Z) + } + case MacroDateRFC822Current, MacroDateSecondsCurrent, MacroDateANSICCurrent: + ok = true + current := m.currentDate + if current.IsZero() { + current = time.Now() + } + switch name { + case MacroDateRFC822Current: + value = current.Format(time.RFC822Z) + case MacroDateSecondsCurrent: + value = fmt.Sprintf("%d", current.Unix()) + case MacroDateANSICCurrent: + value = current.Format(time.ANSIC) + } + } + } + return +} + +func (m *MacroBag) Set(name MacroName, value string) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.macros[name] = value +} + +// Copy copies the macros to a new MacroBag. +// The time.Time values set by SetCurrentDate and SetHeaderDate do not get copied. +func (m *MacroBag) Copy() *MacroBag { + m.mutex.Lock() + defer m.mutex.Unlock() + macros := make(map[MacroName]string) + for k, v := range m.macros { + macros[k] = v + } + return &MacroBag{macros: macros} +} + +func (m *MacroBag) SetCurrentDate(date time.Time) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.currentDate = date +} + +func (m *MacroBag) SetHeaderDate(date time.Time) { + m.mutex.Lock() + defer m.mutex.Unlock() + m.headerDate = date +} + +var _ Macros = &MacroBag{} + +type macrosStages struct { + byStages []map[MacroName]string +} + +func newMacroStages() *macrosStages { + return ¯osStages{ + byStages: make([]map[MacroName]string, StageEndMarker+1), + } +} + +func (s *macrosStages) GetMacroEx(name MacroName) (value string, stageFound MacroStage) { + i := StageEndMarker + for { + if s.byStages[i] != nil { + if v, ok := s.byStages[i][name]; ok { + return v, i + } + } + if i == StageConnect { + return "", StageNotFoundMarker + } + i-- + } +} + +func (s *macrosStages) SetMacro(stage MacroStage, name MacroName, value string) { + if len(s.byStages) < int(stage) { + panic(fmt.Sprintf("tried to set macro in invalid stage %v", stage)) + } + if s.byStages[stage] == nil { + s.byStages[stage] = make(map[MacroName]string) + } + s.byStages[stage][name] = value +} + +func (s *macrosStages) SetStageMap(stage MacroStage, kv map[MacroName]string) { + if len(s.byStages) < int(stage) { + panic(fmt.Sprintf("tried to set invalid stage %v", stage)) + } + s.byStages[stage] = make(map[MacroName]string) + for k, v := range kv { + s.byStages[stage][k] = v + } +} + +func (s *macrosStages) SetStage(stage MacroStage, kv ...string) { + if len(kv)%2 != 0 { + panic(fmt.Sprintf("kv needs to have an even amount of entries, not %d", len(kv))) + } + if len(s.byStages) < int(stage) { + panic(fmt.Sprintf("tried to set invalid stage %v", stage)) + } + s.byStages[stage] = make(map[MacroName]string) + k := "" + for i, str := range kv { + if i%2 == 0 { + k = str + } else { + s.byStages[stage][k] = str + } + } +} + +func (s *macrosStages) DelMacro(stage MacroStage, name MacroName) { + if s.byStages[stage] == nil { + return + } + delete(s.byStages[stage], name) + if len(s.byStages[stage]) == 0 { + s.byStages[stage] = nil + } +} + +func (s *macrosStages) DelStage(stage MacroStage) { + s.byStages[stage] = nil +} + +func (s *macrosStages) DelStageAndAbove(stage MacroStage) { + var stages []MacroStage + switch stage { + case StageConnect: + stages = []MacroStage{StageConnect, StageHelo, StageMail, StageRcpt, StageData, StageEOH, StageEOM, StageEndMarker} + case StageHelo: + stages = []MacroStage{StageHelo, StageMail, StageRcpt, StageData, StageEOH, StageEOM, StageEndMarker} + case StageMail: + stages = []MacroStage{StageMail, StageRcpt, StageData, StageEOH, StageEOM, StageEndMarker} + case StageRcpt: + stages = []MacroStage{StageRcpt, StageData, StageEOH, StageEOM, StageEndMarker} + case StageData: + stages = []MacroStage{StageData, StageEOH, StageEOM, StageEndMarker} + case StageEOH: + stages = []MacroStage{StageEOH, StageEOM, StageEndMarker} + case StageEOM: + stages = []MacroStage{StageEOM, StageEndMarker} + case StageEndMarker: + stages = []MacroStage{StageEndMarker} + } + for _, st := range stages { + s.byStages[st] = nil + } +} + +// macroReader is a read-only Macros compatible view of its macroStages +type macroReader struct { + macrosStages *macrosStages +} + +func (r *macroReader) GetEx(name MacroName) (val string, ok bool) { + if r == nil || r.macrosStages == nil { + return "", false + } + var stage MacroStage + val, stage = r.macrosStages.GetMacroEx(name) + ok = stage <= StageEndMarker // StageEndMarker is for command-level macros + return +} + +func (r *macroReader) Get(name MacroName) string { + v, _ := r.GetEx(name) + return v +} + +var _ Macros = ¯oReader{} + +func parseRequestedMacros(str string) []string { + return removeEmpty(strings.FieldsFunc(str, func(r rune) bool { + return unicode.IsSpace(r) || r == ',' + })) +} + +func removeEmpty(str []string) []string { + if len(str) == 0 { + return []string{} + } + indexesToKeep := make([]int, 0, len(str)) + for i, s := range str { + if len(s) > 0 { + indexesToKeep = append(indexesToKeep, i) + } + } + r := make([]string, 0, len(indexesToKeep)) + for _, index := range indexesToKeep { + r = append(r, str[index]) + } + return r +} + +func removeDuplicates(str []string) []string { + if len(str) == 0 { + return []string{} + } + found := make(map[string]bool, len(str)) + indexesToKeep := make([]int, 0, len(str)) + for i, v := range str { + if !found[v] { + indexesToKeep = append(indexesToKeep, i) + found[v] = true + } + } + noDuplicates := make([]string, len(indexesToKeep)) + for i, index := range indexesToKeep { + noDuplicates[i] = str[index] + } + return noDuplicates +} diff --git a/macro_test.go b/macro_test.go new file mode 100644 index 0000000..e469399 --- /dev/null +++ b/macro_test.go @@ -0,0 +1,560 @@ +package milter + +import ( + "reflect" + "testing" + "time" +) + +func TestMacroBag_GetMacro(t *testing.T) { + tests := []struct { + name string + macros map[MacroName]string + arg MacroName + want string + }{ + {"QueueID", map[MacroName]string{MacroQueueId: "123"}, MacroQueueId, "123"}, + {"QueueID empty", map[MacroName]string{MacroAuthAuthen: "123"}, MacroQueueId, ""}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + m := &MacroBag{ + macros: ltt.macros, + } + if got := m.Get(ltt.arg); got != ltt.want { + t.Errorf("Get() = %v, want %v", got, ltt.want) + } + }) + } +} + +func TestMacroBag_GetMacroEx(t *testing.T) { + tests := []struct { + name string + macros map[MacroName]string + arg MacroName + wantValue string + wantOk bool + }{ + {"QueueID", map[MacroName]string{MacroQueueId: "123"}, MacroQueueId, "123", true}, + {"QueueID 2", map[MacroName]string{MacroAuthSsf: "456", MacroQueueId: "123"}, MacroQueueId, "123", true}, + {"QueueID empty", map[MacroName]string{MacroAuthAuthen: "123"}, MacroQueueId, "", false}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + m := &MacroBag{ + macros: ltt.macros, + } + gotValue, gotOk := m.GetEx(ltt.arg) + if gotValue != ltt.wantValue { + t.Errorf("GetEx() gotValue = %v, want %v", gotValue, ltt.wantValue) + } + if gotOk != ltt.wantOk { + t.Errorf("GetEx() gotOk = %v, want %v", gotOk, ltt.wantOk) + } + }) + } +} + +func TestMacroBag_GetMacroEx_Dates(t *testing.T) { + t.Parallel() + type dates struct { + current time.Time + header time.Time + } + date1 := time.Date(2023, time.January, 1, 1, 1, 1, 0, time.UTC) + tests := []struct { + name string + dates dates + macros map[MacroName]string + arg MacroName + wantValue string + wantOk bool + }{ + {"header: force set", dates{header: date1}, map[MacroName]string{MacroDateRFC822Origin: "123"}, MacroDateRFC822Origin, "123", true}, + {"header: set", dates{header: date1}, map[MacroName]string{}, MacroDateRFC822Origin, "01 Jan 23 01:01 +0000", true}, + {"header: not-set", dates{}, map[MacroName]string{}, MacroDateRFC822Origin, "", false}, + {"current: force set", dates{current: date1}, map[MacroName]string{MacroDateRFC822Current: "123"}, MacroDateRFC822Current, "123", true}, + {"current: set", dates{current: date1}, map[MacroName]string{}, MacroDateRFC822Current, "01 Jan 23 01:01 +0000", true}, + {"current: set seconds", dates{current: date1}, map[MacroName]string{}, MacroDateSecondsCurrent, "1672534861", true}, + {"current: set ANSI", dates{current: date1}, map[MacroName]string{}, MacroDateANSICCurrent, "Sun Jan 1 01:01:01 2023", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + m := &MacroBag{ + macros: ltt.macros, + } + m.SetHeaderDate(ltt.dates.header) + m.SetCurrentDate(ltt.dates.current) + gotValue, gotOk := m.GetEx(ltt.arg) + if gotValue != ltt.wantValue { + t.Errorf("GetEx() gotValue = %v, want %v", gotValue, ltt.wantValue) + } + if gotOk != ltt.wantOk { + t.Errorf("GetEx() gotOk = %v, want %v", gotOk, ltt.wantOk) + } + }) + } + t.Run("current: not-set", func(t *testing.T) { + m := &MacroBag{ + macros: map[MacroName]string{}, + } + gotValue, gotOk := m.GetEx(MacroDateRFC822Current) + if gotValue == "" { + t.Errorf("GetEx() gotValue = %v, want not empty", gotValue) + } + if gotOk != true { + t.Errorf("GetEx() gotOk = %v, want %v", gotOk, true) + } + }) +} + +func TestMacroBag_SetMacro(t *testing.T) { + type args struct { + name MacroName + value string + } + tests := []struct { + name string + macros map[MacroName]string + args args + }{ + {"Overwrite", map[MacroName]string{MacroQueueId: "123"}, args{MacroQueueId, "456"}}, + {"Set", map[MacroName]string{}, args{MacroQueueId, "456"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + m := &MacroBag{ + macros: ltt.macros, + } + m.Set(ltt.args.name, ltt.args.value) + if got := m.Get(ltt.args.name); got != ltt.args.value { + t.Errorf("Get() = %v, want %v", got, ltt.args.value) + } + }) + } +} + +func TestMacroBag_Copy(t *testing.T) { + type fields struct { + macros map[MacroName]string + currentDate time.Time + headerDate time.Time + } + tests := []struct { + name string + fields fields + want map[MacroName]string + }{ + {"empty", fields{}, map[MacroName]string{}}, + {"simple", fields{macros: map[MacroName]string{MacroQueueId: "123"}}, map[MacroName]string{MacroQueueId: "123"}}, + {"no-dates", fields{macros: map[MacroName]string{MacroQueueId: "123"}, headerDate: time.Now(), currentDate: time.Now()}, map[MacroName]string{MacroQueueId: "123"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + m := &MacroBag{ + macros: tt.fields.macros, + currentDate: tt.fields.currentDate, + headerDate: tt.fields.headerDate, + } + if got := m.Copy().macros; !reflect.DeepEqual(got, tt.want) { + t.Errorf("Copy() = %+v, want %+v", got, tt.want) + } + }) + } +} + +func TestMacroReader_Get(t *testing.T) { + tests := []struct { + name string + macrosStages *macrosStages + arg MacroName + want string + }{ + {"QueueID last", ¯osStages{[]map[MacroName]string{nil, nil, nil, nil, nil, nil, nil, {MacroQueueId: "123"}}}, MacroQueueId, "123"}, + {"QueueID first", ¯osStages{[]map[MacroName]string{{MacroQueueId: "123"}, nil, nil, nil, nil, nil, nil, nil}}, MacroQueueId, "123"}, + {"QueueID middle", ¯osStages{[]map[MacroName]string{nil, nil, nil, {MacroQueueId: "123"}, nil, nil, nil, nil}}, MacroQueueId, "123"}, + {"QueueID nil", ¯osStages{[]map[MacroName]string{nil, nil, nil, nil, nil, nil, nil, nil}}, MacroQueueId, ""}, + {"QueueID priority", ¯osStages{[]map[MacroName]string{{MacroQueueId: "456"}, nil, nil, nil, nil, nil, {MacroQueueId: "123"}, nil}}, MacroQueueId, "123"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + r := ¯oReader{ + macrosStages: ltt.macrosStages, + } + if got := r.Get(ltt.arg); got != ltt.want { + t.Errorf("Get() = %v, want %v", got, ltt.want) + } + }) + } +} + +func TestMacroReader_GetEx(t *testing.T) { + tests := []struct { + name string + macrosStages *macrosStages + arg MacroName + want string + want1 bool + }{ + {"QueueID last", ¯osStages{[]map[MacroName]string{nil, nil, nil, nil, nil, nil, nil, {MacroQueueId: "123"}}}, MacroQueueId, "123", true}, + {"QueueID first", ¯osStages{[]map[MacroName]string{{MacroQueueId: "123"}, nil, nil, nil, nil, nil, nil, nil}}, MacroQueueId, "123", true}, + {"QueueID middle", ¯osStages{[]map[MacroName]string{nil, nil, nil, {MacroQueueId: "123"}, nil, nil, nil, nil}}, MacroQueueId, "123", true}, + {"QueueID nil", ¯osStages{[]map[MacroName]string{nil, nil, nil, nil, nil, nil, nil, nil}}, MacroQueueId, "", false}, + {"QueueID priority", ¯osStages{[]map[MacroName]string{{MacroQueueId: "456"}, nil, nil, nil, nil, nil, {MacroQueueId: "123"}, nil}}, MacroQueueId, "123", true}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + r := ¯oReader{ + macrosStages: ltt.macrosStages, + } + got, got1 := r.GetEx(ltt.arg) + if got != ltt.want { + t.Errorf("GetEx() got = %v, want %v", got, ltt.want) + } + if got1 != ltt.want1 { + t.Errorf("GetEx() got1 = %v, want %v", got1, ltt.want1) + } + }) + } +} + +func Test_macrosStages_DelMacro(t *testing.T) { + type args struct { + stage MacroStage + name MacroName + } + tests := []struct { + name string + byStages []map[MacroName]string + args args + }{ + {"empty", []map[MacroName]string{nil, nil, nil, nil, nil, nil, nil, nil}, args{StageConnect, MacroQueueId}}, + {"simple", []map[MacroName]string{{MacroQueueId: "123"}, nil, nil, nil, nil, nil, nil, nil}, args{StageConnect, MacroQueueId}}, + {"multiple", []map[MacroName]string{{MacroQueueId: "123"}, {MacroQueueId: "123"}, {MacroQueueId: "123"}, {MacroQueueId: "123"}, nil, nil, nil, nil}, args{StageConnect, MacroQueueId}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + s := ¯osStages{ + byStages: tt.byStages, + } + s.DelMacro(ltt.args.stage, ltt.args.name) + if _, st := s.GetMacroEx(ltt.args.name); st == ltt.args.stage { + t.Errorf("DelMacro() did not delete %v in stage %v", ltt.args.name, ltt.args.stage) + } + }) + } +} + +func Test_macrosStages_DelStage(t *testing.T) { + tests := []struct { + name string + byStages []map[MacroName]string + stage MacroStage + }{ + {"noop", []map[MacroName]string{nil, nil, nil, nil, nil, nil, nil}, StageConnect}, + {"empty", []map[MacroName]string{{}, {}, {}, {}, {}, {}, {}}, StageConnect}, + {"connect", []map[MacroName]string{{MacroQueueId: "123"}, {}, {}, {}, {}, {}, {}}, StageConnect}, + {"helo", []map[MacroName]string{{}, {MacroQueueId: "123"}, {}, {}, {}, {}, {}}, StageHelo}, + {"mail", []map[MacroName]string{{}, {}, {MacroQueueId: "123"}, {}, {}, {}, {}}, StageMail}, + {"rcpt", []map[MacroName]string{{}, {}, {}, {MacroQueueId: "123"}, {}, {}, {}}, StageRcpt}, + {"data", []map[MacroName]string{{}, {}, {}, {}, {MacroQueueId: "123"}, {}, {}}, StageData}, + {"EOM", []map[MacroName]string{{}, {}, {}, {}, {}, {MacroQueueId: "123"}, {}}, StageEOM}, + {"EOH", []map[MacroName]string{{}, {}, {}, {}, {}, {}, {MacroQueueId: "123"}}, StageEOH}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + s := ¯osStages{ + byStages: ltt.byStages, + } + s.DelStage(ltt.stage) + if s.byStages[ltt.stage] != nil { + t.Errorf("DelStage() did not delete stage %v", ltt.stage) + } + }) + } +} + +func Test_macrosStages_DelStageAndAbove(t *testing.T) { + tests := []struct { + name string + byStages []map[MacroName]string + stage MacroStage + }{ + {"noop", []map[MacroName]string{nil, nil, nil, nil, nil, nil, nil, nil}, StageConnect}, + {"empty", []map[MacroName]string{{}, {}, {}, {}, {}, {}, {}, {}}, StageConnect}, + {"connect", []map[MacroName]string{{MacroQueueId: "123"}, {}, {}, {}, {}, {}, {}, {}}, StageConnect}, + {"helo", []map[MacroName]string{{}, {MacroQueueId: "123"}, {}, {}, {}, {}, {}, {}}, StageHelo}, + {"mail", []map[MacroName]string{{}, {}, {MacroQueueId: "123"}, {}, {}, {}, {}, {}}, StageMail}, + {"rcpt", []map[MacroName]string{{}, {}, {}, {MacroQueueId: "123"}, {}, {}, {}, {}}, StageRcpt}, + {"data", []map[MacroName]string{{}, {}, {}, {}, {MacroQueueId: "123"}, {}, {}, {}}, StageData}, + {"EOM", []map[MacroName]string{{}, {}, {}, {}, {}, {MacroQueueId: "123"}, {}, {}}, StageEOM}, + {"EOH", []map[MacroName]string{{}, {}, {}, {}, {}, {}, {MacroQueueId: "123"}, {}}, StageEOH}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + s := ¯osStages{ + byStages: ltt.byStages, + } + s.DelStageAndAbove(ltt.stage) + if ltt.stage == StageEOH { + if s.byStages[StageEOH] != nil { + t.Errorf("DelStageAndAbove() did not delete stage %v", StageEOH) + } + if s.byStages[StageEOM] != nil { + t.Errorf("DelStageAndAbove() did not delete stage %v", StageEOM) + } + } else if ltt.stage == StageEOM { + if s.byStages[StageEOM] != nil { + t.Errorf("DelStageAndAbove() did not delete stage %v", StageEOM) + } + } else { + for st := ltt.stage; st < StageEndMarker; st += 1 { + if s.byStages[st] != nil { + t.Errorf("DelStageAndAbove() did not delete stage %v", st) + } + } + } + }) + } +} + +func Test_macrosStages_GetMacroEx(t *testing.T) { + type fields struct { + byStages []map[MacroName]string + } + type args struct { + name MacroName + } + tests := []struct { + name string + fields fields + args args + wantValue string + wantStageFound MacroStage + }{ + {"empty", fields{[]map[MacroName]string{nil, nil, nil, nil, nil, nil, nil, nil}}, args{MacroQueueId}, "", StageNotFoundMarker}, + {"first", fields{[]map[MacroName]string{{MacroQueueId: "123"}, nil, nil, nil, nil, nil, nil, nil}}, args{MacroQueueId}, "123", StageConnect}, + {"last", fields{[]map[MacroName]string{nil, nil, nil, nil, nil, nil, nil, {MacroQueueId: "123"}}}, args{MacroQueueId}, "123", StageEndMarker}, + {"last1", fields{[]map[MacroName]string{{MacroQueueId: "123"}, nil, nil, nil, nil, nil, nil, {MacroQueueId: "123"}}}, args{MacroQueueId}, "123", StageEndMarker}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + s := ¯osStages{ + byStages: ltt.fields.byStages, + } + gotValue, gotStageFound := s.GetMacroEx(ltt.args.name) + if gotValue != ltt.wantValue { + t.Errorf("GetEx() gotValue = %v, want %v", gotValue, ltt.wantValue) + } + if gotStageFound != ltt.wantStageFound { + t.Errorf("GetEx() gotStageFound = %v, want %v", gotStageFound, ltt.wantStageFound) + } + }) + } +} + +func Test_macrosStages_SetMacro(t *testing.T) { + type fields struct { + byStages []map[MacroName]string + } + type args struct { + stage MacroStage + name MacroName + value string + } + tests := []struct { + name string + fields fields + args args + }{ + {"nil", fields{[]map[MacroName]string{nil, nil, nil, nil, nil, nil, nil}}, args{StageConnect, MacroQueueId, "123"}}, + {"empty", fields{[]map[MacroName]string{{}, nil, nil, nil, nil, nil, nil}}, args{StageConnect, MacroQueueId, "123"}}, + {"overwrite", fields{[]map[MacroName]string{{MacroQueueId: "456"}, nil, nil, nil, nil, nil, nil}}, args{StageConnect, MacroQueueId, "123"}}, + {"last", fields{[]map[MacroName]string{{}, nil, nil, nil, nil, nil, {}}}, args{StageEOM, MacroQueueId, "123"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + s := ¯osStages{ + byStages: ltt.fields.byStages, + } + s.SetMacro(ltt.args.stage, ltt.args.name, ltt.args.value) + if s.byStages[ltt.args.stage][ltt.args.name] != ltt.args.value { + t.Errorf("Set() did not set the correct value = %v, want %v", s.byStages[ltt.args.stage][ltt.args.name], ltt.args.value) + } + }) + } +} + +func Test_macrosStages_SetStage(t *testing.T) { + type fields struct { + byStages []map[MacroName]string + } + type args struct { + stage MacroStage + kv []string + } + tests := []struct { + name string + fields fields + args args + wants map[MacroName]string + }{ + {"empty", fields{[]map[MacroName]string{nil, nil, nil, nil, nil, nil, nil}}, args{StageConnect, []string{}}, map[MacroName]string{}}, + {"simple nil", fields{[]map[MacroName]string{nil, nil, nil, nil, nil, nil, nil}}, args{StageConnect, []string{MacroQueueId, "123"}}, map[MacroName]string{MacroQueueId: "123"}}, + {"simple empty", fields{[]map[MacroName]string{{}, {}, {}, {}, {}, {}, {}}}, args{StageConnect, []string{MacroQueueId, "123"}}, map[MacroName]string{MacroQueueId: "123"}}, + {"multiple", fields{[]map[MacroName]string{{}, {}, {}, {}, {}, {}, {}}}, args{StageConnect, []string{MacroQueueId, "123", MacroAuthAuthen, "123"}}, map[MacroName]string{MacroQueueId: "123", MacroAuthAuthen: "123"}}, + {"overwrite", fields{[]map[MacroName]string{{MacroAuthAuthen: "123"}, {}, {}, {}, {}, {}, {}}}, args{StageConnect, []string{MacroQueueId, "123"}}, map[MacroName]string{MacroQueueId: "123"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + s := ¯osStages{ + byStages: ltt.fields.byStages, + } + s.SetStage(ltt.args.stage, ltt.args.kv...) + if !reflect.DeepEqual(s.byStages[ltt.args.stage], ltt.wants) { + t.Errorf("SetStage() result = %v, want %v", s.byStages[ltt.args.stage], ltt.wants) + } + }) + } +} + +func Test_macrosStages_SetStageMap(t *testing.T) { + type fields struct { + byStages []map[MacroName]string + } + type args struct { + stage MacroStage + kv map[MacroName]string + } + tests := []struct { + name string + fields fields + args args + wants map[MacroName]string + }{ + {"empty", fields{[]map[MacroName]string{nil, nil, nil, nil, nil, nil, nil}}, args{StageConnect, map[MacroName]string{}}, map[MacroName]string{}}, + {"simple nil", fields{[]map[MacroName]string{nil, nil, nil, nil, nil, nil, nil}}, args{StageConnect, map[MacroName]string{MacroQueueId: "123"}}, map[MacroName]string{MacroQueueId: "123"}}, + {"simple empty", fields{[]map[MacroName]string{{}, {}, {}, {}, {}, {}, {}}}, args{StageConnect, map[MacroName]string{MacroQueueId: "123"}}, map[MacroName]string{MacroQueueId: "123"}}, + {"multiple", fields{[]map[MacroName]string{{}, {}, {}, {}, {}, {}, {}}}, args{StageConnect, map[MacroName]string{MacroQueueId: "123", MacroAuthAuthen: "123"}}, map[MacroName]string{MacroQueueId: "123", MacroAuthAuthen: "123"}}, + {"overwrite", fields{[]map[MacroName]string{{MacroAuthAuthen: "123"}, {}, {}, {}, {}, {}, {}}}, args{StageConnect, map[MacroName]string{MacroQueueId: "123"}}, map[MacroName]string{MacroQueueId: "123"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + s := ¯osStages{ + byStages: ltt.fields.byStages, + } + s.SetStageMap(ltt.args.stage, ltt.args.kv) + if !reflect.DeepEqual(s.byStages[ltt.args.stage], ltt.wants) { + t.Errorf("SetStageMap() result = %v, want %v", s.byStages[ltt.args.stage], ltt.wants) + } + }) + } +} + +func Test_newMacroStages(t *testing.T) { + t.Parallel() + got := newMacroStages() + if len(got.byStages) != int(StageEndMarker)+1 { + t.Errorf("newMacroStages() len(byStages) = %d, want %d", len(got.byStages)+1, StageEndMarker) + } +} + +func Test_parseRequestedMacros(t *testing.T) { + tests := []struct { + name string + str string + want []string + }{ + {"empty", "", []string{}}, + {"spaces", "   \t,,", []string{}}, + {"single", "{auth_authen}", []string{"{auth_authen}"}}, + {"single2", " {auth_authen}, ", []string{"{auth_authen}"}}, + {"multiple", " {auth_authen}, {auth_authen} j ", []string{"{auth_authen}", "{auth_authen}", "j"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + if got := parseRequestedMacros(ltt.str); !reflect.DeepEqual(got, ltt.want) { + t.Errorf("parseRequestedMacros() = %v, want %v", got, ltt.want) + } + }) + } +} + +func Test_removeDuplicates(t *testing.T) { + tests := []struct { + name string + str []string + want []string + }{ + {"empty", []string{}, []string{}}, + {"nil", nil, []string{}}, + {"beginning", []string{"a", "a", "b"}, []string{"a", "b"}}, + {"end", []string{"a", "b", "b"}, []string{"a", "b"}}, + {"single", []string{"a"}, []string{"a"}}, + {"multiple", []string{"a", "b", "a", "a"}, []string{"a", "b"}}, + {"multiple2", []string{"b", "a", "b", "a", "a"}, []string{"b", "a"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + if got := removeDuplicates(ltt.str); !reflect.DeepEqual(got, ltt.want) { + t.Errorf("removeDuplicates() = %v, want %v", got, ltt.want) + } + }) + } +} + +func Test_removeEmpty(t *testing.T) { + tests := []struct { + name string + str []string + want []string + }{ + {"empty", []string{}, []string{}}, + {"nil", nil, []string{}}, + {"beginning", []string{"", "a", "b"}, []string{"a", "b"}}, + {"end", []string{"a", "b", ""}, []string{"a", "b"}}, + {"single", []string{""}, []string{}}, + {"multiple", []string{"a", "", "b", ""}, []string{"a", "b"}}, + {"multiple2", []string{"", "", "b", "a", ""}, []string{"b", "a"}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ltt := tt + t.Parallel() + if got := removeEmpty(ltt.str); !reflect.DeepEqual(got, ltt.want) { + t.Errorf("removeEmpty() = %v, want %v", got, ltt.want) + } + }) + } +} diff --git a/message.go b/message.go deleted file mode 100644 index fb3eef6..0000000 --- a/message.go +++ /dev/null @@ -1,68 +0,0 @@ -package milter - -// Message represents a command sent from milter client -type Message struct { - Code byte - Data []byte -} - -type ActionCode byte - -const ( - ActAccept ActionCode = 'a' // SMFIR_ACCEPT - ActContinue ActionCode = 'c' // SMFIR_CONTINUE - ActDiscard ActionCode = 'd' // SMFIR_DISCARD - ActReject ActionCode = 'r' // SMFIR_REJECT - ActTempFail ActionCode = 't' // SMFIR_TEMPFAIL - ActReplyCode ActionCode = 'y' // SMFIR_REPLYCODE - - // [v6] - ActSkip ActionCode = 's' // SMFIR_SKIP -) - -type ModifyActCode byte - -const ( - ActAddRcpt ModifyActCode = '+' // SMFIR_ADDRCPT - ActDelRcpt ModifyActCode = '-' // SMFIR_DELRCPT - ActReplBody ModifyActCode = 'b' // SMFIR_ACCEPT - ActAddHeader ModifyActCode = 'h' // SMFIR_ADDHEADER - ActChangeHeader ModifyActCode = 'm' // SMFIR_CHGHEADER - ActInsertHeader ModifyActCode = 'i' // SMFIR_INSHEADER - ActQuarantine ModifyActCode = 'q' // SMFIR_QUARANTINE - - // [v6] - ActChangeFrom ModifyActCode = 'e' // SMFIR_CHGFROM -) - -type Code byte - -const ( - CodeOptNeg Code = 'O' // SMFIC_OPTNEG - CodeMacro Code = 'D' // SMFIC_MACRO - CodeConn Code = 'C' // SMFIC_CONNECT - CodeQuit Code = 'Q' // SMFIC_QUIT - CodeHelo Code = 'H' // SMFIC_HELO - CodeMail Code = 'M' // SMFIC_MAIL - CodeRcpt Code = 'R' // SMFIC_RCPT - CodeHeader Code = 'L' // SMFIC_HEADER - CodeEOH Code = 'N' // SMFIC_EOH - CodeBody Code = 'B' // SMFIC_BODY - CodeEOB Code = 'E' // SMFIC_BODYEOB - CodeAbort Code = 'A' // SMFIC_ABORT - CodeData Code = 'T' // SMFIC_DATA - - // [v6] - CodeQuitNewConn Code = 'K' // SMFIC_QUIT_NC -) - -const MaxBodyChunk = 65535 - -type ProtoFamily byte - -const ( - FamilyUnknown ProtoFamily = 'U' // SMFIA_UNKNOWN - FamilyUnix ProtoFamily = 'L' // SMFIA_UNIX - FamilyInet ProtoFamily = '4' // SMFIA_INET - FamilyInet6 ProtoFamily = '6' // SMFIA_INET6 -) diff --git a/milter.go b/milter.go index b48c076..ac47771 100644 --- a/milter.go +++ b/milter.go @@ -1,4 +1,4 @@ -// Package milter provides an interface to implement milter mail filters +// Package milter provides an interface to implement milter mail filters and MTAs that can talk to milter programs. package milter // OptAction sets which actions the milter wants to perform. @@ -7,51 +7,95 @@ type OptAction uint32 // Set which actions the milter wants to perform. const ( - OptAddHeader OptAction = 1 << 0 // SMFIF_ADDHDRS - OptChangeBody OptAction = 1 << 1 // SMFIF_CHGBODY - OptAddRcpt OptAction = 1 << 2 // SMFIF_ADDRCPT - OptRemoveRcpt OptAction = 1 << 3 // SMFIF_DELRCPT - OptChangeHeader OptAction = 1 << 4 // SMFIF_CHGHDRS - OptQuarantine OptAction = 1 << 5 // SMFIF_QUARANTINE - - // [v6] - OptChangeFrom OptAction = 1 << 6 // SMFIF_CHGFROM - OptAddRcptWithArgs OptAction = 1 << 7 // SMFIF_ADDRCPT_PAR - OptSetSymList OptAction = 1 << 8 // SMFIF_SETSYMLIST + OptAddHeader OptAction = 1 << 0 // SMFIF_ADDHDRS + OptChangeBody OptAction = 1 << 1 // SMFIF_CHGBODY / SMFIF_MODBODY + OptAddRcpt OptAction = 1 << 2 // SMFIF_ADDRCPT + OptRemoveRcpt OptAction = 1 << 3 // SMFIF_DELRCPT + OptChangeHeader OptAction = 1 << 4 // SMFIF_CHGHDRS + OptQuarantine OptAction = 1 << 5 // SMFIF_QUARANTINE + OptChangeFrom OptAction = 1 << 6 // SMFIF_CHGFROM [v6] + OptAddRcptWithArgs OptAction = 1 << 7 // SMFIF_ADDRCPT_PAR [v6] + OptSetMacros OptAction = 1 << 8 // SMFIF_SETSYMLIST [v6] ) // OptProtocol masks out unwanted parts of the SMTP transaction. // Multiple options can be set using a bitmask. type OptProtocol uint32 +// The options that the milter can send to the MTA during negotiation to tailor the communication. const ( - OptNoConnect OptProtocol = 1 << 0 // SMFIP_NOCONNECT - OptNoHelo OptProtocol = 1 << 1 // SMFIP_NOHELO - OptNoMailFrom OptProtocol = 1 << 2 // SMFIP_NOMAIL - OptNoRcptTo OptProtocol = 1 << 3 // SMFIP_NORCPT - OptNoBody OptProtocol = 1 << 4 // SMFIP_NOBODY - OptNoHeaders OptProtocol = 1 << 5 // SMFIP_NOHDRS - OptNoEOH OptProtocol = 1 << 6 // SMFIP_NOEOH - OptNoUnknown OptProtocol = 1 << 8 // SMFIP_NOUNKNOWN - OptNoData OptProtocol = 1 << 9 // SMFIP_NODATA - - // [v6] MTA supports ActSkip - OptSkip OptProtocol = 1 << 10 // SMFIP_SKIP - // [v6] Filter wants rejected RCPTs - OptRcptRej OptProtocol = 1 << 11 // SMFIP_RCPT_REJ - - // Milter will not send action response for the following MTA messages - OptNoHeaderReply OptProtocol = 1 << 7 // SMFIP_NR_HDR, SMFIP_NOHREPL - // [v6] - OptNoConnReply OptProtocol = 1 << 12 // SMFIP_NR_CONN - OptNoHeloReply OptProtocol = 1 << 13 // SMFIP_NR_HELO - OptNoMailReply OptProtocol = 1 << 14 // SMFIP_NR_MAIL - OptNoRcptReply OptProtocol = 1 << 15 // SMFIP_NR_RCPT - OptNoDataReply OptProtocol = 1 << 16 // SMFIP_NR_DATA - OptNoUnknownReply OptProtocol = 1 << 17 // SMFIP_NR_UNKN - OptNoEOHReply OptProtocol = 1 << 18 // SMFIP_NR_EOH - OptNoBodyReply OptProtocol = 1 << 19 // SMFIP_NR_BODY - - // [v6] - OptHeaderLeadingSpace OptProtocol = 1 << 20 // SMFIP_HDR_LEADSPC + OptNoConnect OptProtocol = 1 << 0 // MTA does not send connect events. SMFIP_NOCONNECT + OptNoHelo OptProtocol = 1 << 1 // MTA does not send HELO/EHLO events. SMFIP_NOHELO + OptNoMailFrom OptProtocol = 1 << 2 // MTA does not send MAIL FROM events. SMFIP_NOMAIL + OptNoRcptTo OptProtocol = 1 << 3 // MTA does not send RCPT TO events. SMFIP_NORCPT + OptNoBody OptProtocol = 1 << 4 // MTA does not send message body data. SMFIP_NOBODY + OptNoHeaders OptProtocol = 1 << 5 // MTA does not send message header data. SMFIP_NOHDRS + OptNoEOH OptProtocol = 1 << 6 // MTA does not send end of header indication event. SMFIP_NOEOH + OptNoHeaderReply OptProtocol = 1 << 7 // Milter does not send a reply to header data. SMFIP_NR_HDR, SMFIP_NOHREPL + OptNoUnknown OptProtocol = 1 << 8 // MTA does not send unknown SMTP command events. SMFIP_NOUNKNOWN + OptNoData OptProtocol = 1 << 9 // MTA does not send the DATA start event. SMFIP_NODATA + OptSkip OptProtocol = 1 << 10 // MTA supports ActSkip. SMFIP_SKIP [v6] + OptRcptRej OptProtocol = 1 << 11 // Filter wants rejected RCPTs. SMFIP_RCPT_REJ [v6] + OptNoConnReply OptProtocol = 1 << 12 // Milter does not send a reply to connection event. SMFIP_NR_CONN [v6] + OptNoHeloReply OptProtocol = 1 << 13 // Milter does not send a reply to HELO/EHLO event. SMFIP_NR_HELO [v6] + OptNoMailReply OptProtocol = 1 << 14 // Milter does not send a reply to MAIL FROM event. SMFIP_NR_MAIL [v6] + OptNoRcptReply OptProtocol = 1 << 15 // Milter does not send a reply to RCPT TO event. SMFIP_NR_RCPT [v6] + OptNoDataReply OptProtocol = 1 << 16 // Milter does not send a reply to DATA start event. SMFIP_NR_DATA [v6] + OptNoUnknownReply OptProtocol = 1 << 17 // Milter does not send a reply to unknown command event. SMFIP_NR_UNKN [v6] + OptNoEOHReply OptProtocol = 1 << 18 // Milter does not send a reply to end of header event. SMFIP_NR_EOH [v6] + OptNoBodyReply OptProtocol = 1 << 19 // Milter does not send a reply to body chunk event. SMFIP_NR_BODY [v6] + + // OptHeaderLeadingSpace lets the Milter request that the MTA does not swallow a leading space + // when passing the header value to the milter. + // Sendmail by default eats one space (not tab) after the colon. So the header line (spaces replaced with _): + // Subject:__Test + // gets transferred as HeaderName "Subject" and HeaderValue "_Test". If the milter + // sends OptHeaderLeadingSpace to the MTA it requests that it wants the header value as is. + // So the MTA should send HeaderName "Subject" and HeaderValue "__Test". + // + // Milter that do e.g. DKIM signing may need the additional space to create valid DKIM signatures. + // + // The Client and ClientSession does not handle this option. It is the responsibility of the MTA to check if the milter + // asked for this and obey this request. In the simplest case just never swallow the space. + // + // SMFIP_HDR_LEADSPC [v6] + OptHeaderLeadingSpace OptProtocol = 1 << 20 +) + +const ( + // OptNoReplies combines all protocol flags that define that your milter does not send a reply + // to the MTA. Use this if your [Milter] only decides at the [Milter.EndOfMessage] handler if the + // email is acceptable or needs to be rejected. + OptNoReplies OptProtocol = OptNoHeaderReply | OptNoConnReply | OptNoHeloReply | OptNoMailReply | OptNoRcptReply | OptNoDataReply | OptNoUnknownReply | OptNoEOHReply | OptNoBodyReply +) + +const ( + optMds256K uint32 = 1 << 28 // SMFIP_MDS_256K + optMds1M uint32 = 1 << 29 // SMFIP_MDS_1M + optInternal = optMds256K | optMds1M | 1<<30 // internal flags: only used between MTA and libmilter (bit 28,29,30). SMFI_INTERNAL + optV2 uint32 = 0x0000007F // All flags that v2 defined (bit 0, 1, 2, 3, 4, 5, 6). SMFI_V2_PROT +) + +// DataSize defines the maximum data size for milter or MTA to use. +// +// The DataSize does not include the one byte for the command byte. +// Only three sizes are defined in the milter protocol. +type DataSize uint32 + +const ( + // DataSize64K is 64KB - 1 byte (command-byte). This is the default buffer size. + DataSize64K DataSize = 1024*64 - 1 + // DataSize256K is 256KB - 1 byte (command-byte) + DataSize256K DataSize = 1024*256 - 1 + // DataSize1M is 1MB - 1 byte (command-byte) + DataSize1M DataSize = 1024*1024 - 1 +) + +type ProtoFamily byte + +const ( + FamilyUnknown ProtoFamily = 'U' // SMFIA_UNKNOWN + FamilyUnix ProtoFamily = 'L' // SMFIA_UNIX + FamilyInet ProtoFamily = '4' // SMFIA_INET + FamilyInet6 ProtoFamily = '6' // SMFIA_INET6 ) diff --git a/milterutil/buffer.go b/milterutil/buffer.go new file mode 100644 index 0000000..c7f7a4b --- /dev/null +++ b/milterutil/buffer.go @@ -0,0 +1,110 @@ +// Package milterutil includes utility functions and types that might be useful for writing milters or MTAs. +package milterutil + +import ( + "bufio" + "io" + "sync" +) + +// FixedBufferScanner is a wrapper around a bufio.Scanner that produces fixed size chunks of data +// given an io.Reader. +type FixedBufferScanner struct { + bufferSize uint32 + buffer []byte + scanner *bufio.Scanner + pool *sync.Pool +} + +func (f *FixedBufferScanner) init(pool *sync.Pool, r io.Reader) { + var bufSize = int(f.bufferSize) + f.pool = pool + f.scanner = bufio.NewScanner(r) + f.scanner.Buffer(f.buffer, bufSize) + f.scanner.Split(func(data []byte, atEOF bool) (advance int, token []byte, err error) { + if atEOF && len(data) == 0 { + return 0, nil, nil + } + // buffer full? Return it. + if len(data) >= bufSize { + return bufSize, data[0:bufSize], nil + } + // If we're at EOF, return the rest even if it is less than bufSize + if atEOF { + return len(data), data, nil + } + // Request more data. + return 0, nil, nil + }) +} + +// Scan returns true when there is new data in Bytes +func (f *FixedBufferScanner) Scan() bool { + return f.scanner.Scan() +} + +// Bytes returns the current chunk of data +func (f *FixedBufferScanner) Bytes() []byte { + return f.scanner.Bytes() +} + +// Err returns the first non-EOF error that was encountered by the FixedBufferScanner. +func (f *FixedBufferScanner) Err() error { + return f.scanner.Err() +} + +// Close need to be called when you are done with the FixedBufferScanner because we maintain a shared pool +// of FixedBufferScanner. +// +// Close does not close the underlying io.Reader. It is the responsibility of the caller to do this. +func (f *FixedBufferScanner) Close() { + f.pool.Put(f) +} + +var fixedBufferPoolsMap map[uint32]*sync.Pool +var fixedBufferPoolsMapMutex sync.RWMutex +var fixedBufferPoolsMapInit sync.Once + +func newFixedBufferScannerPool(bufferSize uint32) *sync.Pool { + return &sync.Pool{New: func() interface{} { + return &FixedBufferScanner{bufferSize: bufferSize, buffer: make([]byte, bufferSize)} + }} +} + +func initFixedBufferPoolsMap() { + fixedBufferPoolsMapMutex.Lock() + fixedBufferPoolsMap = make(map[uint32]*sync.Pool) + // pre-initialize the buffers that the milter library might request + fixedBufferPoolsMap[1024*64-1] = newFixedBufferScannerPool(1024*64 - 1) + fixedBufferPoolsMap[1024*256-1] = newFixedBufferScannerPool(1024*256 - 1) + fixedBufferPoolsMap[1024*1024-1] = newFixedBufferScannerPool(1024*1024 - 1) + fixedBufferPoolsMapMutex.Unlock() +} + +// GetFixedBufferScanner returns a FixedBufferScanner of size bufferSize that is configured to read from r. +// +// It is the responsibility of the caller to close r. +// +// If the caller is done with the returned FixedBufferScanner its Close method should be called to release +// it to the shared pool of FixedBufferScanners. +func GetFixedBufferScanner(bufferSize uint32, r io.Reader) *FixedBufferScanner { + fixedBufferPoolsMapInit.Do(initFixedBufferPoolsMap) + // try with read lock first + fixedBufferPoolsMapMutex.RLock() + pool := fixedBufferPoolsMap[bufferSize] + fixedBufferPoolsMapMutex.RUnlock() + if pool == nil { + // no luck, then get write lock + fixedBufferPoolsMapMutex.Lock() + // re-check the existence of pool + if pool = fixedBufferPoolsMap[bufferSize]; pool == nil { + // create pool in write lock + pool = newFixedBufferScannerPool(bufferSize) + fixedBufferPoolsMap[bufferSize] = pool + } + fixedBufferPoolsMapMutex.Unlock() + } + buffer := pool.Get().(*FixedBufferScanner) + buffer.init(pool, r) + return buffer +} diff --git a/milterutil/buffer_test.go b/milterutil/buffer_test.go new file mode 100644 index 0000000..5c7e6fa --- /dev/null +++ b/milterutil/buffer_test.go @@ -0,0 +1,115 @@ +package milterutil_test + +import ( + "io" + "reflect" + "testing" + + "github.com/d--j/go-milter" + "github.com/d--j/go-milter/milterutil" +) + +func TestFixedBufferScanner(t *testing.T) { + t.Parallel() + type args struct { + bufferSize uint32 + inputs []string + } + tests := []struct { + name string + args args + want []string + wantErr bool + }{ + {"empty", args{uint32(milter.DataSize64K), []string{}}, nil, false}, + {"short", args{10, []string{"12345"}}, []string{"12345"}, false}, + {"two-in-one", args{10, []string{"12345678901234567890"}}, []string{"1234567890", "1234567890"}, false}, + {"two-in-three", args{10, []string{"12345", "678901", "234567890"}}, []string{"1234567890", "1234567890"}, false}, + {"one-and-half", args{10, []string{"12345", "678901", "2345"}}, []string{"1234567890", "12345"}, false}, + } + for _, tt_ := range tests { + t.Run(tt_.name, func(t *testing.T) { + tt := tt_ + t.Parallel() + r, w := io.Pipe() + go func() { + for _, s := range tt.args.inputs { + if _, err := w.Write([]byte(s)); err != nil { + _ = w.CloseWithError(err) + return + } + } + _ = w.Close() + }() + f := milterutil.GetFixedBufferScanner(tt.args.bufferSize, r) + defer f.Close() + var got []string + for f.Scan() { + got = append(got, string(f.Bytes())) + } + if (f.Err() != nil) != tt.wantErr { + t.Fatalf("error = %v, wantErr %v", f.Err(), tt.wantErr) + } + if tt.wantErr { + return + } + if !reflect.DeepEqual(got, tt.want) { + t.Fatalf("got %v, want %v", got, tt.want) + } + }) + } +} + +func doFixedBufferScannerBenchmark(b *testing.B, bufferSize uint32, writeSize int, writeCount int) { + buff := make([]byte, writeSize) + b.ResetTimer() + b.RunParallel(func(pb *testing.PB) { + for pb.Next() { + r, w := io.Pipe() + go func() { + for i := 0; i < writeCount; i++ { + if _, err := w.Write(buff); err != nil { + w.CloseWithError(err) + return + } + } + w.Close() + }() + scanner := milterutil.GetFixedBufferScanner(bufferSize, r) + for scanner.Scan() { + } + if scanner.Err() != nil { + scanner.Close() + b.Fatal(scanner.Err()) + } + scanner.Close() + b.SetBytes(int64(writeSize * writeCount)) + } + }) +} + +func BenchmarkGetFixedBufferScanner_64K_1K_4096(b *testing.B) { + doFixedBufferScannerBenchmark(b, uint32(milter.DataSize64K), 1024, 4096) +} +func BenchmarkGetFixedBufferScanner_64K_4K_1024(b *testing.B) { + doFixedBufferScannerBenchmark(b, uint32(milter.DataSize64K), 4096, 1024) +} +func BenchmarkGetFixedBufferScanner_64K_8K_512(b *testing.B) { + doFixedBufferScannerBenchmark(b, uint32(milter.DataSize64K), 8192, 512) +} +func BenchmarkGetFixedBufferScanner_64K_32K_128(b *testing.B) { + doFixedBufferScannerBenchmark(b, uint32(milter.DataSize64K), 32*1024, 128) +} + +func BenchmarkGetFixedBufferScanner_1M_1K_4096(b *testing.B) { + doFixedBufferScannerBenchmark(b, uint32(milter.DataSize1M), 1024, 4096) +} +func BenchmarkGetFixedBufferScanner_1M_4K_1024(b *testing.B) { + doFixedBufferScannerBenchmark(b, uint32(milter.DataSize1M), 4096, 1024) +} +func BenchmarkGetFixedBufferScanner_1M_8K_512(b *testing.B) { + doFixedBufferScannerBenchmark(b, uint32(milter.DataSize1M), 8192, 512) +} +func BenchmarkGetFixedBufferScanner_1M_32K_128(b *testing.B) { + doFixedBufferScannerBenchmark(b, uint32(milter.DataSize1M), 32*1024, 128) +} diff --git a/milterutil/testdata/fuzz/FuzzSMTPReplyTransformer_Transform/05722b2c67c115953ab46e2031ae1863a8595905a08a9ff47a9fa7e2cd3d7810 b/milterutil/testdata/fuzz/FuzzSMTPReplyTransformer_Transform/05722b2c67c115953ab46e2031ae1863a8595905a08a9ff47a9fa7e2cd3d7810 new file mode 100644 index 0000000..3d3395c --- /dev/null +++ b/milterutil/testdata/fuzz/FuzzSMTPReplyTransformer_Transform/05722b2c67c115953ab46e2031ae1863a8595905a08a9ff47a9fa7e2cd3d7810 @@ -0,0 +1,4 @@ +go test fuzz v1 +[]byte("oneeeee\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc̻\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xbb\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xcc\xccee\rtwo") +[]byte("") +bool(true) diff --git a/milterutil/transformer.go b/milterutil/transformer.go new file mode 100644 index 0000000..f888bce --- /dev/null +++ b/milterutil/transformer.go @@ -0,0 +1,334 @@ +package milterutil + +import ( + "errors" + "fmt" + "unicode/utf8" + + "golang.org/x/text/transform" +) + +const cr = '\r' +const lf = '\n' + +// CrLfToLfTransformer is a [transform.Transformer] that replaces all CR LF and single CR in src to LF in dst. +type CrLfToLfTransformer struct { + prevCR bool +} + +func (t *CrLfToLfTransformer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) { + for nDst < len(dst) && nSrc < len(src) { + c := src[nSrc] + if c == lf { + if t.prevCR { + nSrc++ + dst[nDst-1] = lf + t.prevCR = false + continue + } + } + t.prevCR = c == cr + if t.prevCR { + c = lf + } + dst[nDst] = c + nDst++ + nSrc++ + } + if nSrc < len(src) { // should never happen since we do not add data, but let's be safe + err = transform.ErrShortDst + } + // if the last char in src is cr then there might be a lf coming + if err == nil && !atEOF && len(src) > 0 && src[len(src)-1] == cr { + err = transform.ErrShortSrc + nSrc-- + nDst-- + return + } + return +} + +func (t *CrLfToLfTransformer) Reset() { + t.prevCR = false +} + +var _ transform.Transformer = &CrLfToLfTransformer{} + +// CrLfToLf is a helper that uses [CrLfToLfTransformer] to replace all line endings to only LF. +// +// postfix wants LF lines endings for header values. Using CRLF results in double CR sequences. +func CrLfToLf(s string) string { + dst, _, err := transform.String(&CrLfToLfTransformer{}, s) + if err != nil { + panic(err) + } + return dst +} + +// CrLfCanonicalizationTransformer is a [transform.Transformer] that replaces line endings in src with CR LF line endings in dst. +type CrLfCanonicalizationTransformer struct { + prev byte +} + +func (t *CrLfCanonicalizationTransformer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) { + for nDst < len(dst) && nSrc < len(src) { + c := src[nSrc] + if c == lf { + if t.prev != cr { + if len(dst) <= nDst+1 { + err = transform.ErrShortDst + return + } + dst[nDst] = cr + nDst++ + } + } else if c == cr { + if !atEOF && len(src) <= nSrc+1 { + err = transform.ErrShortSrc + return + } + if (atEOF && len(src) == nSrc+1) || src[nSrc+1] != lf { + if len(dst) <= nDst+1 { + err = transform.ErrShortDst + return + } + dst[nDst] = c + nDst++ + c = lf + } + } + dst[nDst] = c + nDst++ + nSrc++ + t.prev = c + } + if nSrc < len(src) { + err = transform.ErrShortDst + } + return +} + +func (t *CrLfCanonicalizationTransformer) Reset() { + t.prev = 0 +} + +var _ transform.Transformer = &CrLfCanonicalizationTransformer{} + +// DoublePercentTransformer is a [transform.Transformer] that replaces all % in src with %% in dst. +type DoublePercentTransformer struct { + transform.NopResetter +} + +func (t *DoublePercentTransformer) Transform(dst, src []byte, _ bool) (nDst, nSrc int, err error) { + for nDst < len(dst) && nSrc < len(src) { + c := src[nSrc] + if c == '%' { + if len(dst) <= nDst+1 { + err = transform.ErrShortDst + return + } + dst[nDst] = c + nDst++ + } + dst[nDst] = c + nDst++ + nSrc++ + } + if nSrc < len(src) { + err = transform.ErrShortDst + } + return +} + +var _ transform.Transformer = &DoublePercentTransformer{} + +// SkipDoublePercentTransformer is a [transform.Transformer] that replaces all %% in src to % in dst. +// Single % signs are left as-is. +type SkipDoublePercentTransformer struct { + prevPercent bool + prevDoublePercent bool +} + +func (t *SkipDoublePercentTransformer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) { + for nDst < len(dst) && nSrc < len(src) { + c := src[nSrc] + if c == '%' { + if t.prevPercent && !t.prevDoublePercent { + t.prevDoublePercent = true + nSrc++ + continue + } + } + t.prevPercent = c == '%' + t.prevDoublePercent = false + dst[nDst] = c + nDst++ + nSrc++ + } + if nSrc < len(src) { // should never happen since we do not add data, but let's be safe + err = transform.ErrShortDst + } + // if the last char in src is a lonely % then there might be a % coming + if err == nil && !atEOF && len(src) > 0 && t.prevPercent && !t.prevDoublePercent { + err = transform.ErrShortSrc + t.prevPercent = false + nSrc-- + nDst-- + return + } + return +} + +func (t *SkipDoublePercentTransformer) Reset() { + t.prevPercent = false + t.prevDoublePercent = false +} + +var _ transform.Transformer = &SkipDoublePercentTransformer{} + +// SMTPReplyTransformer is a [transform.Transformer] that reads src and produces a valid SMTP response (including multi-line handling) +// +// This transformer does not handle CR LF canonicalization, but it needs src to be properly encoded in this way. +// +// When you combine this Transformer in a [transform.Chain] it can only handle lines with a maximum of 128 bytes. +type SMTPReplyTransformer struct { + Code uint16 + init bool +} + +var errStartWithLF = errors.New("SMTP reply cannot start with LF") + +func (t *SMTPReplyTransformer) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) { + if !t.init && (t.Code < 100 || t.Code > 599) { + return 0, 0, fmt.Errorf("milter: %d is not a valid SMTP code", t.Code) + } + // special case: empty string + if atEOF && !t.init && len(src) == 0 { + if len(dst) <= nDst+4 { + return 0, 0, transform.ErrShortDst + } + nDst += copy(dst[nDst:], fmt.Sprintf("%d ", t.Code)) + return + } + + for nDst < len(dst) && nSrc < len(src) { + c := src[nSrc] + if !t.init || c == lf { + if len(dst) <= nDst+5 { + err = transform.ErrShortDst + return + } + if !t.init && c == lf { + err = errStartWithLF + return + } + // determine if there is a newline following + newline := false + for peek := nSrc + 1; peek < len(src); peek++ { + if src[peek] == lf { + newline = true + break + } + } + // request more data when there might be more data, and we did not find a newline + if !atEOF && !newline { + err = transform.ErrShortSrc + return + } + // insert \n before the SMTP code + if t.init { + dst[nDst] = c + nDst++ + nSrc++ + } + if newline { + nDst += copy(dst[nDst:], fmt.Sprintf("%d-", t.Code)) + } else { + nDst += copy(dst[nDst:], fmt.Sprintf("%d ", t.Code)) + } + // first char is missing + if !t.init { + t.init = true + dst[nDst] = c + nDst++ + nSrc++ + } + } else { + dst[nDst] = c + nDst++ + nSrc++ + } + } + if nSrc < len(src) { + err = transform.ErrShortDst + } + return +} + +func (t *SMTPReplyTransformer) Reset() { + t.init = false +} + +var _ transform.Transformer = &SMTPReplyTransformer{} + +// DefaultMaximumLineLength is the maximum line length (in bytes) that will be used by [MaximumLineLengthTransformer] +// when its MaximumLength value is zero. +// The SMTP protocol theoretically allows up to 1000 bytes. We default to 950 bytes since some MTAs do forceful line +// breaks at lower limits (e.g. 980 bytes). +const DefaultMaximumLineLength = 950 + +var errWrongMaximumLineLength = errors.New("MaximumLength must be 4 or more") + +// MaximumLineLengthTransformer is a [transform.Transformer] that splits src into lines of at most MaximumLength bytes. +// +// CR and LF are considered new line indicators. They do not count to the line length. +// +// This transformer can handle UTF-8 input. +// Because of this we actually start tying to split lines at MaximumLength - 3 bytes. +// This way we can assure that one line is never bigger than MaximumLength bytes. +type MaximumLineLengthTransformer struct { + MaximumLength uint + length uint +} + +func (t *MaximumLineLengthTransformer) Transform(dst, src []byte, _ bool) (nDst, nSrc int, err error) { + if t.MaximumLength == 0 { + t.MaximumLength = DefaultMaximumLineLength + } + if t.MaximumLength < utf8.UTFMax { + return 0, 0, errWrongMaximumLineLength + } + + for nDst < len(dst) && nSrc < len(src) { + c := src[nSrc] + isCrOfLf := c == cr || c == lf + // break when we find a valid UTF8 rune start near the end of the line + // or when we reach the maximum (then the string has invalid UTF-8 anyway) + if !isCrOfLf && ((t.length > t.MaximumLength-utf8.UTFMax && utf8.RuneStart(c)) || (t.length >= t.MaximumLength)) { + if len(dst) <= nDst+2 { + err = transform.ErrShortDst + return + } + nDst += copy(dst[nDst:], "\r\n") + t.length = 0 + } + dst[nDst] = c + nDst++ + nSrc++ + if isCrOfLf { + t.length = 0 + } else { + t.length++ + } + } + if nSrc < len(src) { + err = transform.ErrShortDst + } + return +} + +func (t *MaximumLineLengthTransformer) Reset() { + t.length = 0 +} + +var _ transform.Transformer = &MaximumLineLengthTransformer{} diff --git a/milterutil/transformer_test.go b/milterutil/transformer_test.go new file mode 100644 index 0000000..d598583 --- /dev/null +++ b/milterutil/transformer_test.go @@ -0,0 +1,528 @@ +package milterutil + +import ( + "bufio" + "bytes" + "fmt" + "io" + "net/textproto" + "regexp" + "strings" + "testing" + "unicode/utf8" + + "golang.org/x/text/transform" +) + +type transformerTestCase struct { + inputs []string + expected string +} +type transformerTestCases []transformerTestCase + +func doTransformation(transformer transform.Transformer, inputs []string) ([]byte, error) { + r, w := io.Pipe() + go func() { + for _, s := range inputs { + if _, err := w.Write([]byte(s)); err != nil { + _ = w.CloseWithError(err) + return + } + } + _ = w.Close() + }() + tr := transform.NewReader(r, transformer) + return io.ReadAll(tr) +} + +func doTransformerTest(t *testing.T, getTransformer func() transform.Transformer, extraCheck func(*testing.T, transformerTestCase, string), testCases transformerTestCases) { + runTestCase := func(t *testing.T, tt transformerTestCase, transformer transform.Transformer) { + output, err := doTransformation(transformer, tt.inputs) + if err != nil { + t.Fatal(err) + } + if string(output) != tt.expected { + t.Fatalf("expected %q, got %q", tt.expected, string(output)) + } + output2, _, err := transform.String(transformer, strings.Join(tt.inputs, "")) + if err != nil { + t.Fatal(err) + } + if output2 != tt.expected { + t.Fatalf("expected %q, got %q", tt.expected, output2) + } + if extraCheck != nil { + extraCheck(t, tt, output2) + } + } + for i, tt := range testCases { + prettyName := fmt.Sprintf(":%q", tt.inputs) + if len(prettyName) > 50 { + prettyName = fmt.Sprintf(":(%d inputs with %d bytes total)", len(tt.inputs), len(strings.Join(tt.inputs, ""))) + } + t.Run(fmt.Sprintf("%d%s", i, prettyName), func(t *testing.T) { + ltt := tt + t.Parallel() + runTestCase(t, ltt, getTransformer()) + }) + } + t.Run("Reset", func(t *testing.T) { + t.Parallel() + transformer := getTransformer() + for _, tt := range testCases { + runTestCase(t, tt, transformer) + } + }) +} + +func TestCrLfToLfTransformer(t *testing.T) { + // transform.Transformer uses initial dst buffer size of 4096 bytes + stuffing := strings.Repeat("1234567890", 4090/10) + t.Parallel() + doTransformerTest(t, func() transform.Transformer { + return &CrLfToLfTransformer{} + }, nil, transformerTestCases{ + {[]string{""}, ""}, + {[]string{"\n"}, "\n"}, + {[]string{"\r"}, "\n"}, + {[]string{"\r\n"}, "\n"}, + {[]string{"\r\r\n"}, "\n\n"}, + {[]string{"\r\n\r"}, "\n\n"}, + {[]string{"\r\n\r\n"}, "\n\n"}, + {[]string{"line1\r\nline2\r\n"}, "line1\nline2\n"}, + {[]string{"\r", "\n"}, "\n"}, + {[]string{"\r\r", "\n"}, "\n\n"}, + {[]string{stuffing + "123456\r", "\n"}, stuffing + "123456\n"}, + }) +} + +func TestCrLfCanonicalizationTransformer(t *testing.T) { + // transform.Transformer uses initial dst buffer size of 4096 bytes + stuffing := strings.Repeat("1234567890", 4090/10) + manyCR := strings.Repeat("\r", 4095) + manCRLF := strings.Repeat("\r\n", 4095) + t.Parallel() + doTransformerTest(t, func() transform.Transformer { + return &CrLfCanonicalizationTransformer{} + }, nil, transformerTestCases{ + {[]string{""}, ""}, + {[]string{"\n"}, "\r\n"}, + {[]string{"", "\n"}, "\r\n"}, + {[]string{"\r"}, "\r\n"}, + {[]string{"", "\r"}, "\r\n"}, + {[]string{"\r\n"}, "\r\n"}, + {[]string{"\r\r\n"}, "\r\n\r\n"}, + {[]string{"\r\n\r"}, "\r\n\r\n"}, + {[]string{"\r\n\r\n"}, "\r\n\r\n"}, + {[]string{"line1\nline2\r\nline3\n"}, "line1\r\nline2\r\nline3\r\n"}, + {[]string{"\r", "\n"}, "\r\n"}, + {[]string{"\r\r", "\n"}, "\r\n\r\n"}, + {[]string{"\n\x00\n"}, "\r\n\x00\r\n"}, + {[]string{stuffing + "123456\r", "\n"}, stuffing + "123456\r\n"}, + {[]string{manyCR}, manCRLF}, + }) +} + +func TestDoublePercentTransformer(t *testing.T) { + // transform.Transformer uses initial dst buffer size of 4096 bytes + stuffing := strings.Repeat("1234567890", 4090/10) + manyPercent := strings.Repeat("%", 4096) + t.Parallel() + doTransformerTest(t, func() transform.Transformer { + return &DoublePercentTransformer{} + }, nil, transformerTestCases{ + {[]string{""}, ""}, + {[]string{"%"}, "%%"}, + {[]string{" % "}, " %% "}, + {[]string{"%%"}, "%%%%"}, + {[]string{" ", "%"}, " %%"}, + {[]string{"%", "%"}, "%%%%"}, + {[]string{"%\x00%"}, "%%\x00%%"}, + {[]string{stuffing + "12345%", "%"}, stuffing + "12345%%%%"}, + {[]string{manyPercent}, manyPercent + manyPercent}, + }) +} + +func TestSkipDoublePercentTransformer(t *testing.T) { + // transform.Transformer uses initial dst buffer size of 4096 bytes + stuffing := strings.Repeat("1234567890", 4090/10) + t.Parallel() + doTransformerTest(t, func() transform.Transformer { + return &SkipDoublePercentTransformer{} + }, nil, transformerTestCases{ + {[]string{""}, ""}, + {[]string{"%"}, "%"}, + {[]string{" % "}, " % "}, + {[]string{"%%"}, "%"}, + {[]string{"%", "%"}, "%"}, + {[]string{"%", "%", "%"}, "%%"}, + {[]string{"%%\x00%%"}, "%\x00%"}, + {[]string{stuffing + "12345%", "%"}, stuffing + "12345%"}, + }) +} + +func TestSMTPReplyTransformer(t *testing.T) { + // transform.Transformer uses initial dst buffer size of 4096 bytes + manyLines := strings.Repeat("12\r\n", 786) // 3144 bytes + expectedManyLines := strings.Repeat("499-12\r\n", 786) + "499 " // 6292 bytes + t.Parallel() + doTransformerTest(t, func() transform.Transformer { + return &SMTPReplyTransformer{Code: 499} + }, nil, transformerTestCases{ + {[]string{""}, "499 "}, + {[]string{"", ""}, "499 "}, + {[]string{"4.3.999 testing"}, "499 4.3.999 testing"}, + {[]string{"line1\r\nline2"}, "499-line1\r\n499 line2"}, + {[]string{"line1\r\nline2\r\n"}, "499-line1\r\n499-line2\r\n499 "}, + {[]string{"line1\nline2"}, "499-line1\n499 line2"}, + {[]string{manyLines}, expectedManyLines}, + }) + t.Run("no LF at start", func(t *testing.T) { + t.Parallel() + output, err := doTransformation(&SMTPReplyTransformer{Code: 499}, []string{"\n"}) + if err == nil { + t.Fatalf("expected err, got output = %q", output) + } + }) + t.Run("invalid code", func(t *testing.T) { + t.Parallel() + output, err := doTransformation(&SMTPReplyTransformer{Code: 9999}, []string{"\n"}) + if err == nil { + t.Fatalf("expected err, got output = %q", output) + } + }) +} + +func TestMaximumLineLengthTransformer(t *testing.T) { + t.Parallel() + doTransformerTest(t, func() transform.Transformer { + return &MaximumLineLengthTransformer{MaximumLength: 20} + }, func(t *testing.T, testCase transformerTestCase, output string) { + r := regexp.MustCompile("\r\n|\r|\n") + lines := r.Split(output, -1) + for _, line := range lines { + if len(line) > 20 { + t.Fatalf("output contained line with more than 20 bytes: %q", line) + } + } + }, transformerTestCases{ + {[]string{""}, ""}, + {[]string{"", ""}, ""}, + {[]string{"12345678901234567890123456789012"}, "12345678901234567\r\n890123456789012"}, + {[]string{"1234567890123456789012345678901234567890"}, "12345678901234567\r\n89012345678901234\r\n567890"}, + {[]string{"12345678901234567890\r\n12345678901234567890"}, "12345678901234567\r\n890\r\n12345678901234567\r\n890"}, + {[]string{"12345678901234567\r89012345678901234567890"}, "12345678901234567\r89012345678901234\r\n567890"}, + {[]string{"12345678901234567890\n12345678901234567890"}, "12345678901234567\r\n890\n12345678901234567\r\n890"}, + {[]string{"12345678901234567890", "\r\n12345678901234567890"}, "12345678901234567\r\n890\r\n12345678901234567\r\n890"}, + {[]string{"\r", "\n", "\r", "\n"}, "\r\n\r\n"}, + {[]string{"🚀🚀🚀🚀🚀"}, "🚀🚀🚀🚀🚀"}, + {[]string{"🚀🚀🚀1🚀🚀"}, "🚀🚀🚀1🚀\r\n🚀"}, + {[]string{"🚀🚀🚀12🚀🚀"}, "🚀🚀🚀12🚀\r\n🚀"}, + {[]string{"🚀🚀🚀123🚀🚀"}, "🚀🚀🚀123🚀\r\n🚀"}, + {[]string{"🚀🚀🚀1234🚀🚀"}, "🚀🚀🚀1234🚀\r\n🚀"}, + {[]string{"🚀🚀🚀12345🚀🚀"}, "🚀🚀🚀12345\r\n🚀🚀"}, + }) + t.Run("default line length", func(t *testing.T) { + t.Parallel() + line := strings.Repeat(".", DefaultMaximumLineLength-utf8.UTFMax+1) + output, err := doTransformation(&MaximumLineLengthTransformer{}, []string{line + line}) + if err != nil { + t.Fatalf("not expected err, got %s", err) + } + expected := line + "\r\n" + line + if string(output) != expected { + t.Fatalf("expected %q, got %q", expected, string(output)) + } + }) + t.Run("enforce minimum", func(t *testing.T) { + t.Parallel() + _, err := doTransformation(&MaximumLineLengthTransformer{MaximumLength: 1}, []string{""}) + if err != errWrongMaximumLineLength { + t.Fatalf("err got %s, expected %s", err, errWrongMaximumLineLength) + } + }) + t.Run("work with minimum", func(t *testing.T) { + t.Parallel() + output, err := doTransformation(&MaximumLineLengthTransformer{MaximumLength: 4}, []string{"1234"}) + if err != nil { + t.Fatalf("not expected err, got %s", err) + } + expected := "1\r\n2\r\n3\r\n4" + if string(output) != expected { + t.Fatalf("expected %q, got %q", expected, string(output)) + } + }) +} + +func TestCrLfToLf(t *testing.T) { + tests := []struct { + name string + arg string + want string + }{ + {"empty", "", ""}, + {"simple", "\r\n", "\n"}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := CrLfToLf(tt.arg); got != tt.want { + t.Errorf("CrLfToLf() = %v, want %v", got, tt.want) + } + }) + } +} + +func FuzzCrLfToLfTransformer_Transform(f *testing.F) { + f.Add([]byte("\r\n"), []byte(""), true) + f.Add([]byte("\r"), []byte("\n"), true) + f.Add([]byte("one\r\ntwo"), []byte(""), true) + f.Add([]byte("\r"), []byte(""), true) + f.Add([]byte("one\rtwo"), []byte(""), true) + f.Add([]byte("\n"), []byte(""), true) + f.Add([]byte("one\ntwo"), []byte(""), true) + f.Add([]byte("\r\r\n"), []byte(""), true) + f.Add([]byte("\r\r"), []byte("\n"), true) + f.Fuzz(func(t *testing.T, input1 []byte, input2 []byte, writeEmpty bool) { + r, w := io.Pipe() + go func() { + if len(input1) > 0 || writeEmpty { + if _, err := w.Write(input1); err != nil { + _ = w.CloseWithError(err) + return + } + } + if len(input2) > 0 || writeEmpty { + if _, err := w.Write(input2); err != nil { + _ = w.CloseWithError(err) + return + } + } + _ = w.Close() + }() + output, err := io.ReadAll(transform.NewReader(r, &CrLfToLfTransformer{})) + if err != nil { + t.Fatal(err) + } + if len(output) > len(input1)+len(input2) { + t.Fatalf("output bigger than input %d > %d", len(output), len(input1)+len(input2)) + } + if bytes.Contains(output, []byte("\r\n")) { + t.Fatal("output contains \\r\\n") + } + }) +} + +func FuzzCrLfCanonicalizationTransformer_Transform(f *testing.F) { + lineEndingRegexp := regexp.MustCompile("\r\n|\n\r|\r|\n") + f.Add([]byte("\r\n"), []byte(""), true) + f.Add([]byte("\r"), []byte("\n"), true) + f.Add([]byte("one\r\ntwo"), []byte(""), true) + f.Add([]byte("\r"), []byte(""), true) + f.Add([]byte("one\rtwo"), []byte(""), true) + f.Add([]byte("\n"), []byte(""), true) + f.Add([]byte("one\ntwo"), []byte(""), true) + f.Add([]byte("\r\r\n"), []byte(""), true) + f.Add([]byte("\r\r"), []byte("\n"), true) + f.Fuzz(func(t *testing.T, input1 []byte, input2 []byte, writeEmpty bool) { + r, w := io.Pipe() + go func() { + if len(input1) > 0 || writeEmpty { + if _, err := w.Write(input1); err != nil { + _ = w.CloseWithError(err) + return + } + } + if len(input2) > 0 || writeEmpty { + if _, err := w.Write(input2); err != nil { + _ = w.CloseWithError(err) + return + } + } + _ = w.Close() + }() + output, err := io.ReadAll(transform.NewReader(r, &CrLfCanonicalizationTransformer{})) + if err != nil { + t.Fatal(err) + } + if len(output) < len(input1)+len(input2) { + t.Fatalf("output smaller than input %d < %d", len(output), len(input1)+len(input2)) + } + lineEndings := lineEndingRegexp.FindAll(output, -1) + for _, ending := range lineEndings { + if !bytes.Equal(ending, []byte("\r\n")) { + t.Fatalf("output contained wrong line ending: %q", ending) + } + } + }) +} + +func FuzzMaximumLineLengthTransformer_Transform(f *testing.F) { + lineEndingRegexp := regexp.MustCompile("\r\n|\n\r|\r|\n") + f.Add(uint(20), []byte("\r\n"), []byte(""), true) + f.Add(uint(4), []byte("\r"), []byte("\n"), true) + f.Add(uint(20), []byte("one\r\ntwo"), []byte(""), true) + f.Add(uint(20), []byte("\r"), []byte(""), true) + f.Add(uint(20), []byte("one\rtwo"), []byte(""), true) + f.Add(uint(20), []byte("\n"), []byte(""), true) + f.Add(uint(20), []byte("one\ntwo"), []byte(""), true) + f.Add(uint(20), []byte("\r\r\n"), []byte(""), true) + f.Add(uint(20), []byte("\r\r"), []byte("\n"), true) + f.Fuzz(func(t *testing.T, maxLineLength uint, input1 []byte, input2 []byte, writeEmpty bool) { + if maxLineLength < 4 { + return + } + r, w := io.Pipe() + go func() { + if len(input1) > 0 || writeEmpty { + if _, err := w.Write(input1); err != nil { + _ = w.CloseWithError(err) + return + } + } + if len(input2) > 0 || writeEmpty { + if _, err := w.Write(input2); err != nil { + _ = w.CloseWithError(err) + return + } + } + _ = w.Close() + }() + output, err := io.ReadAll(transform.NewReader(r, &MaximumLineLengthTransformer{MaximumLength: maxLineLength})) + if err != nil { + t.Fatal(err) + } + if len(output) < len(input1)+len(input2) { + t.Fatalf("output smaller than input %d < %d", len(output), len(input1)+len(input2)) + } + lines := lineEndingRegexp.Split(string(output), -1) + for _, line := range lines { + if len(line) > int(maxLineLength) { + t.Fatalf("output contained line with more than %d bytes: %q", maxLineLength, line) + } + } + if utf8.Valid(append(input1, input2...)) && !utf8.Valid(output) { + t.Fatal("input is valid UTF-8 but output is not") + } + }) +} + +func FuzzSkipDoublePercentTransformer_Transform(f *testing.F) { + f.Add([]byte("%"), []byte("%"), true) + f.Add([]byte("%%"), []byte(""), true) + f.Add([]byte(""), []byte("%"), true) + f.Add([]byte(""), []byte("%%"), true) + f.Fuzz(func(t *testing.T, input1 []byte, input2 []byte, writeEmpty bool) { + r, w := io.Pipe() + go func() { + if len(input1) > 0 || writeEmpty { + if _, err := w.Write(input1); err != nil { + _ = w.CloseWithError(err) + return + } + } + if len(input2) > 0 || writeEmpty { + if _, err := w.Write(input2); err != nil { + _ = w.CloseWithError(err) + return + } + } + _ = w.Close() + }() + output, err := io.ReadAll(transform.NewReader(r, &SkipDoublePercentTransformer{})) + if err != nil { + t.Fatal(err) + } + if len(output) > len(input1)+len(input2) { + t.Fatalf("output bigger than input %d > %d", len(output), len(input1)+len(input2)) + } + if bytes.Contains(output, []byte("%%")) { + t.Fatal("output contains %%") + } + }) +} + +func FuzzDoublePercentTransformer_Transform(f *testing.F) { + singlePercentRegexp := regexp.MustCompile("[^%]%|%[^%]") + f.Add([]byte("%"), []byte("%"), true) + f.Add([]byte("%%"), []byte(""), true) + f.Add([]byte(""), []byte("%"), true) + f.Add([]byte(""), []byte("%%"), true) + f.Fuzz(func(t *testing.T, input1 []byte, input2 []byte, writeEmpty bool) { + r, w := io.Pipe() + go func() { + if len(input1) > 0 || writeEmpty { + if _, err := w.Write(input1); err != nil { + _ = w.CloseWithError(err) + return + } + } + if len(input2) > 0 || writeEmpty { + if _, err := w.Write(input2); err != nil { + _ = w.CloseWithError(err) + return + } + } + _ = w.Close() + }() + output, err := io.ReadAll(transform.NewReader(r, &DoublePercentTransformer{})) + if err != nil { + t.Fatal(err) + } + if len(output) < len(input1)+len(input2) { + t.Fatalf("output smaller than input %d < %d", len(output), len(input1)+len(input2)) + } + if singlePercentRegexp.Match(output) { + t.Fatal("output contains single %") + } + }) +} + +func FuzzSMTPReplyTransformer_Transform(f *testing.F) { + f.Add([]byte("\r\n"), []byte(""), true) + f.Add([]byte("\r"), []byte("\n"), true) + f.Add([]byte("one\r\ntwo"), []byte(""), true) + f.Add([]byte("\r"), []byte(""), true) + f.Add([]byte("one\rtwo"), []byte(""), true) + f.Add([]byte("\n"), []byte(""), true) + f.Add([]byte("one\ntwo"), []byte(""), true) + f.Add([]byte("\r\r\n"), []byte(""), true) + f.Add([]byte("\r\r"), []byte("\n"), true) + f.Add([]byte("a long line"), []byte("a long line"), true) + f.Fuzz(func(t *testing.T, input1 []byte, input2 []byte, writeEmpty bool) { + r, w := io.Pipe() + lw := transform.NewWriter(w, &MaximumLineLengthTransformer{MaximumLength: 920}) + go func() { + if len(input1) > 0 || writeEmpty { + if _, err := lw.Write(input1); err != nil { + _ = w.CloseWithError(err) + return + } + } + if len(input2) > 0 || writeEmpty { + if _, err := lw.Write(input2); err != nil { + _ = w.CloseWithError(err) + return + } + } + if err := lw.Close(); err != nil { + _ = w.CloseWithError(err) + } else { + _ = w.Close() + } + }() + output, err := io.ReadAll(transform.NewReader(r, &SMTPReplyTransformer{Code: 300})) + if err != nil { + if err == errStartWithLF { + return + } + t.Fatal(err) + } + if len(output) < len(input1)+len(input2) { + t.Fatalf("output smaller than input %d < %d", len(output), len(input1)+len(input2)) + } + or := textproto.NewReader(bufio.NewReader(bytes.NewReader(output))) + if _, _, err := or.ReadResponse(300); err != nil { + t.Fatalf("not valid SMTP response: %q", output) + } + }) +} diff --git a/modifier.go b/modifier.go index 32e7041..fedb3d9 100644 --- a/modifier.go +++ b/modifier.go @@ -1,96 +1,415 @@ -// Modifier instance is provided to milter handlers to modify email messages - package milter import ( + "bufio" "bytes" "encoding/binary" + "errors" "fmt" + "io" "net/textproto" + + "github.com/d--j/go-milter/internal/wire" + "github.com/d--j/go-milter/milterutil" +) + +type ActionType int + +const ( + ActionAccept ActionType = iota + 1 + ActionContinue + ActionDiscard + ActionReject + ActionTempFail + ActionSkip + ActionRejectWithCode +) + +type Action struct { + Type ActionType + + // SMTP code if milter wants to abort the connection/message. Zero otherwise. + SMTPCode uint16 + // Properly formatted reply text if milter wants to abort the connection/message. Empty string otherwise. + SMTPReply string +} + +// StopProcessing returns true when the milter wants to immediately stop this SMTP connection. +// You can use SMTPReply to send as reply to the current SMTP command. +func (a Action) StopProcessing() bool { + return a.SMTPCode > 0 +} + +func parseAction(msg *wire.Message) (*Action, error) { + act := &Action{} + + switch wire.ActionCode(msg.Code) { + case wire.ActAccept: + act.Type = ActionAccept + case wire.ActContinue: + act.Type = ActionContinue + case wire.ActDiscard: + act.Type = ActionDiscard + case wire.ActReject: + act.Type = ActionReject + case wire.ActTempFail: + act.Type = ActionTempFail + case wire.ActSkip: + act.Type = ActionSkip + case wire.ActReplyCode: + if len(msg.Data) <= 4 { + return nil, fmt.Errorf("action read: unexpected data length: %d", len(msg.Data)) + } + checker := textproto.NewReader(bufio.NewReader(bytes.NewReader(msg.Data))) + // this also accepts FTP style multi-line responses as valid + // It's highly unlikely that milter sends one of those, so we ignore this false positive + code, _, err := checker.ReadResponse(0) + if err != nil { + return nil, fmt.Errorf("action read: malformed SMTP response: %q", msg.Data) + } + act.Type = ActionRejectWithCode + act.SMTPCode = uint16(code) + act.SMTPReply = wire.ReadCString(msg.Data) // use raw response as it was formatted by milter + default: + return nil, fmt.Errorf("action read: unexpected code: %c", msg.Code) + } + + return act, nil +} + +type ModifyActionType int + +const ( + ActionAddRcpt ModifyActionType = iota + 1 + ActionDelRcpt + ActionQuarantine + ActionReplaceBody + ActionChangeFrom + ActionAddHeader + ActionChangeHeader + ActionInsertHeader ) -// postfix wants LF lines endings. Using CRLF results in double CR sequences. -func crlfToLF(b []byte) []byte { - return bytes.ReplaceAll(b, []byte{'\r', '\n'}, []byte{'\n'}) +type ModifyAction struct { + Type ModifyActionType + + // Recipient to add/remove if Type == ActionAddRcpt or ActionDelRcpt. + // This value already includes the necessary <>. + Rcpt string + + // ESMTP arguments for recipient address if Type = ActionAddRcpt. + RcptArgs string + + // New envelope sender if Type = ActionChangeFrom. + // This value already includes the necessary <>. + From string + + // ESMTP arguments for envelope sender if Type = ActionChangeFrom. + FromArgs string + + // Portion of body to be replaced if Type == ActionReplaceBody. + Body []byte + + // Index of the header field to be changed if Type = ActionChangeHeader or Type = ActionInsertHeader. + // Index is 1-based and is per value of HdrName. + // E.g. HeaderIndex = 3 and HdrName = "DKIM-Signature" mean "change third + // DKIM-Signature field". Order is the same as of HeaderField calls. + // A HeaderIndex of 0 for Type = ActionInsertHeader has the special meaning "at the very beginning". + HeaderIndex uint32 + + // Header field name to be added/changed if Type == ActionAddHeader or + // ActionChangeHeader or ActionInsertHeader. + HeaderName string + + // Header field value to be added/changed if Type == ActionAddHeader or + // ActionChangeHeader or ActionInsertHeader. If set to empty string - the field + // should be removed. + HeaderValue string + + // Quarantine reason if Type == ActionQuarantine. + Reason string +} + +func parseModifyAct(msg *wire.Message) (*ModifyAction, error) { + act := &ModifyAction{} + + switch wire.ModifyActCode(msg.Code) { + case wire.ActAddRcpt: + argv := bytes.Split(msg.Data, []byte{0x00}) + if len(argv) != 2 { + return nil, fmt.Errorf("read modify action: wrong number of arguments %d for ActAddRcpt action", len(argv)) + } + act.Type = ActionAddRcpt + act.Rcpt = string(argv[0]) + case wire.ActAddRcptPar: + argv := bytes.Split(msg.Data, []byte{0x00}) + if len(argv) > 3 || len(argv) < 2 { + return nil, fmt.Errorf("read modify action: wrong number of arguments %d for ActAddRcpt action", len(argv)) + } + act.Type = ActionAddRcpt + act.Rcpt = string(argv[0]) + if len(argv) == 3 { + act.RcptArgs = string(argv[1]) + } + case wire.ActDelRcpt: + act.Type = ActionDelRcpt + act.Rcpt = wire.ReadCString(msg.Data) + case wire.ActQuarantine: + act.Type = ActionQuarantine + act.Reason = wire.ReadCString(msg.Data) + case wire.ActReplBody: + act.Type = ActionReplaceBody + act.Body = msg.Data + case wire.ActChangeFrom: + argv := bytes.Split(msg.Data, []byte{0x00}) + if len(argv) > 3 || len(argv) < 2 { + return nil, fmt.Errorf("read modify action: wrong number of arguments %d for ActChangeFrom action", len(argv)) + } + act.Type = ActionChangeFrom + act.From = string(argv[0]) + if len(argv) == 3 { + act.FromArgs = string(argv[1]) + } + case wire.ActChangeHeader, wire.ActInsertHeader: + if len(msg.Data) < 4 { + return nil, fmt.Errorf("read modify action: missing header index") + } + if wire.ModifyActCode(msg.Code) == wire.ActChangeHeader { + act.Type = ActionChangeHeader + } else { + act.Type = ActionInsertHeader + } + act.HeaderIndex = binary.BigEndian.Uint32(msg.Data) + + // Sendmail 8 compatibility + if wire.ModifyActCode(msg.Code) == wire.ActChangeHeader && act.HeaderIndex == 0 { + act.HeaderIndex = 1 + } + + msg.Data = msg.Data[4:] + fallthrough + case wire.ActAddHeader: + argv := bytes.Split(msg.Data, []byte{0x00}) + if len(argv) != 3 { + return nil, fmt.Errorf("read modify action: wrong number of arguments %d for header action: %v", len(argv), argv) + } + if wire.ModifyActCode(msg.Code) == wire.ActAddHeader { + act.Type = ActionAddHeader + } + act.HeaderName = string(argv[0]) + act.HeaderValue = string(argv[1]) + default: + return nil, fmt.Errorf("read modify action: unexpected message code: %v", msg.Code) + } + + return act, nil } -// Modifier provides access to Macros, Headers and Body data to callback handlers. It also defines a -// number of functions that can be used by callback handlers to modify processing of the email message +// Modifier provides access to Macros and Headers to callback handlers. It also defines a +// number of functions that can be used by callback handlers to modify processing of the email message. +// Besides Progress() they can only be called in the EndOfMessage callback. type Modifier struct { - Macros map[string]string - Headers textproto.MIMEHeader + Headers textproto.MIMEHeader + Macros Macros + writeProgressPacket func(*wire.Message) error + writePacket func(*wire.Message) error + actions OptAction + maxDataSize DataSize +} - writePacket func(*Message) error +func hasHats(str string) bool { + return len(str) > 1 && str[0] == '<' && str[len(str)-1] == '>' } -// AddRecipient appends a new envelope recipient for current message -func (m *Modifier) AddRecipient(r string) error { - data := []byte(fmt.Sprintf("<%s>", r) + null) - return m.writePacket(NewResponse('+', data).Response()) +func addHats(str string) string { + if hasHats(str) { + return str + } else { + return fmt.Sprintf("<%s>", str) + } +} + +func removeHats(str string) string { + if hasHats(str) { + return str[1 : len(str)-1] + } else { + return str + } +} + +var ErrModificationNotAllowed = errors.New("milter: modification not allowed via milter protocol negotiation") + +// AddRecipient appends a new envelope recipient for current message. +// You can optionally specify esmtpArgs to pass along. You need to negotiate this via [OptAddRcptWithArgs] with the MTA. +func (m *Modifier) AddRecipient(r string, esmtpArgs string) error { + if m.actions&OptAddRcpt == 0 && m.actions&OptAddRcptWithArgs == 0 { + return ErrModificationNotAllowed + } + if esmtpArgs != "" && m.actions&OptAddRcptWithArgs == 0 { + return ErrModificationNotAllowed + } + code := wire.ActAddRcpt + var buffer bytes.Buffer + buffer.WriteString(addHats(r)) + buffer.WriteByte(0) + // send wire.ActAddRcptPar when that is the only allowed action or we need to send it because esmptArgs ist set + if (esmtpArgs != "" && m.actions&OptAddRcptWithArgs != 0) || (esmtpArgs == "" && m.actions&OptAddRcptWithArgs != 0) { + buffer.WriteString(esmtpArgs) + buffer.WriteByte(0) + code = wire.ActAddRcptPar + } + return m.writePacket(newResponse(wire.Code(code), buffer.Bytes()).Response()) } // DeleteRecipient removes an envelope recipient address from message func (m *Modifier) DeleteRecipient(r string) error { - data := []byte(fmt.Sprintf("<%s>", r) + null) - return m.writePacket(NewResponse('-', data).Response()) + if m.actions&OptRemoveRcpt == 0 { + return ErrModificationNotAllowed + } + resp, err := newResponseStr(wire.Code(wire.ActDelRcpt), addHats(r)) + if err != nil { + return err + } + return m.writePacket(resp.Response()) } -// ReplaceBody substitutes message body with provided body -func (m *Modifier) ReplaceBody(body []byte) error { - body = crlfToLF(body) - return m.writePacket(NewResponse('b', body).Response()) +// ReplaceBodyRawChunk sends one chunk of the body replacement. +// +// The chunk get send as-is. Caller needs to ensure that the chunk does not exceed the maximum configured data size (defaults to [DataSize64K]) +// +// You should do the ReplaceBodyRawChunk calls all in one go without intersecting it with other modification actions. +// MTAs like Postfix do not allow that. +func (m *Modifier) ReplaceBodyRawChunk(chunk []byte) error { + if m.actions&OptChangeBody == 0 { + return ErrModificationNotAllowed + } + if len(chunk) > int(m.maxDataSize) { + return fmt.Errorf("milter: body chunk too large: %d > %d", len(chunk), m.maxDataSize) + } + return m.writePacket(newResponse(wire.Code(wire.ActReplBody), chunk).Response()) } -// AddHeader appends a new email message header the message +// ReplaceBody reads from r and send its contents in the least amount of chunks to the MTA. +// +// This function does not do any CR LF line ending canonicalization or maximum line length enforcements. +// If you need that you can use the various transform.Transformers of this package to wrap your reader. +// +// t := transform.Chain(&milter.CrLfCanonicalizationTransformer{}, &milter.MaximumLineLengthTransformer{}) +// wrappedR := transform.NewReader(r, t) +// m.ReplaceBody(wrappedR) +// +// This function tries to use as few calls to [ReplaceBodyRawChunk] as possible. +// +// You can call ReplaceBody multiple times. The MTA will combine all those calls into one message. +// +// You should do the ReplaceBody calls all in one go without intersecting it with other modification actions. +// MTAs like Postfix do not allow that. +func (m *Modifier) ReplaceBody(r io.Reader) error { + scanner := milterutil.GetFixedBufferScanner(uint32(m.maxDataSize), r) + defer scanner.Close() + for scanner.Scan() { + err := m.ReplaceBodyRawChunk(scanner.Bytes()) + if err != nil { + return err + } + } + return scanner.Err() +} + +// AddHeader appends a new email message header to the message func (m *Modifier) AddHeader(name, value string) error { + if m.actions&OptAddHeader == 0 { + return ErrModificationNotAllowed + } var buffer bytes.Buffer - buffer.WriteString(name + null) - buffer.Write(crlfToLF([]byte(value))) - buffer.WriteString(null) - return m.writePacket(NewResponse('h', buffer.Bytes()).Response()) + buffer.WriteString(name) + buffer.WriteByte(0) + buffer.WriteString(milterutil.CrLfToLf(value)) + buffer.WriteByte(0) + return m.writePacket(newResponse(wire.Code(wire.ActAddHeader), buffer.Bytes()).Response()) } // Quarantine a message by giving a reason to hold it func (m *Modifier) Quarantine(reason string) error { - return m.writePacket(NewResponse('q', []byte(reason+null)).Response()) + if m.actions&OptQuarantine == 0 { + return ErrModificationNotAllowed + } + return m.writePacket(newResponse(wire.Code(wire.ActQuarantine), []byte(reason+"\x00")).Response()) } // ChangeHeader replaces the header at the specified position with a new one. -// The index is per name. +// The index is per name. To delete a header pass an empty value. func (m *Modifier) ChangeHeader(index int, name, value string) error { + if m.actions&OptChangeHeader == 0 { + return ErrModificationNotAllowed + } var buffer bytes.Buffer if err := binary.Write(&buffer, binary.BigEndian, uint32(index)); err != nil { return err } - buffer.WriteString(name + null) - buffer.Write(crlfToLF([]byte(value))) - buffer.WriteString(null) - return m.writePacket(NewResponse('m', buffer.Bytes()).Response()) + buffer.WriteString(name) + buffer.WriteByte(0) + buffer.WriteString(milterutil.CrLfToLf(value)) + buffer.WriteByte(0) + return m.writePacket(newResponse(wire.Code(wire.ActChangeHeader), buffer.Bytes()).Response()) } // InsertHeader inserts the header at the specified position +// index is 1 based. The index 0 means at the very beginning. func (m *Modifier) InsertHeader(index int, name, value string) error { + // Insert header does not have its own action flag + if m.actions&OptChangeHeader == 0 && m.actions&OptAddHeader == 0 { + return ErrModificationNotAllowed + } var buffer bytes.Buffer if err := binary.Write(&buffer, binary.BigEndian, uint32(index)); err != nil { return err } - buffer.WriteString(name + null) - buffer.Write(crlfToLF([]byte(value))) - buffer.WriteString(null) - return m.writePacket(NewResponse('i', buffer.Bytes()).Response()) + buffer.WriteString(name) + buffer.WriteByte(0) + buffer.WriteString(milterutil.CrLfToLf(value)) + buffer.WriteByte(0) + return m.writePacket(newResponse(wire.Code(wire.ActInsertHeader), buffer.Bytes()).Response()) } // ChangeFrom replaces the FROM envelope header with a new one -func (m *Modifier) ChangeFrom(value string) error { - data := []byte(value + null) - return m.writePacket(NewResponse('e', data).Response()) +func (m *Modifier) ChangeFrom(value string, esmtpArgs string) error { + if m.actions&OptChangeFrom == 0 { + return ErrModificationNotAllowed + } + var buffer bytes.Buffer + buffer.WriteString(addHats(value)) + buffer.WriteByte(0) + if esmtpArgs != "" { + buffer.WriteString(esmtpArgs) + buffer.WriteByte(0) + } + return m.writePacket(newResponse(wire.Code(wire.ActChangeFrom), buffer.Bytes()).Response()) +} + +var respProgress = &Response{code: wire.Code(wire.ActProgress)} + +// Progress tells the client that there is progress in a long operation +func (m *Modifier) Progress() error { + return m.writeProgressPacket(respProgress.Response()) +} + +func errorWriteReadOnly(m *wire.Message) error { + return fmt.Errorf("tried to send action %c in read-only state", m.Code) } -// newModifier creates a new Modifier instance from milterSession -func newModifier(s *milterSession) *Modifier { +// newModifier creates a new [Modifier] instance from [serverSession] +func newModifier(s *serverSession, readOnly bool) *Modifier { + writePacket := s.writePacket + if readOnly { + writePacket = errorWriteReadOnly + } return &Modifier{ - Macros: s.macros, - Headers: s.headers, - writePacket: s.WritePacket, + Macros: ¯oReader{macrosStages: s.macros}, + Headers: s.headers, + writePacket: writePacket, + writeProgressPacket: s.writePacket, + actions: s.actions, + maxDataSize: s.maxDataSize, } } diff --git a/options.go b/options.go new file mode 100644 index 0000000..bfe5617 --- /dev/null +++ b/options.go @@ -0,0 +1,192 @@ +package milter + +import "time" + +// NewMilterFunc is the signature of a function that can be used with [WithDynamicMilter] to configure the [Milter] backend. +// The parameters version, action, protocol and maxData are the negotiated values. +type NewMilterFunc func(version uint32, action OptAction, protocol OptProtocol, maxData DataSize) Milter + +// NegotiationCallbackFunc is the signature of a [WithNegotiationCallback] function. +// With this callback function you can override the negotiation process. +type NegotiationCallbackFunc func(mtaVersion, milterVersion uint32, mtaActions, milterActions OptAction, mtaProtocol, milterProtocol OptProtocol, offeredDataSize DataSize) (version uint32, actions OptAction, protocol OptProtocol, maxDataSize DataSize, err error) + +type options struct { + maxVersion uint32 + actions OptAction + protocol OptProtocol + dialer Dialer + readTimeout, writeTimeout time.Duration + offeredMaxData, usedMaxData DataSize + macrosByStage macroRequests + newMilter NewMilterFunc + negotiationCallback NegotiationCallbackFunc +} + +// Option can be used to configure [Client] and [Server]. +type Option func(*options) + +// WithAction adds action to the actions your MTA supports or your [Milter] needs. You need to specify this since this library cannot +// guess what your MTA can handle or your milter needs. +// 0 is a valid value when your MTA does not support any message modification (only rejection) or your milter does not need any message modifications. +func WithAction(action OptAction) Option { + return func(h *options) { + h.actions = h.actions | action + } +} + +// WithoutAction removes action from the list of actions this MTA supports/[Milter] needs. +func WithoutAction(action OptAction) Option { + return func(h *options) { + h.actions = h.actions & ^action + } +} + +// WithActions sets the actions your MTA supports or your [Milter] needs. You need to specify this since this library cannot +// guess what your MTA can handle or your milter needs. +// 0 is a valid value when your MTA does not support any message modification (only rejection) or your milter does not need any message modifications. +func WithActions(actions OptAction) Option { + return func(h *options) { + h.actions = actions + } +} + +// WithProtocol adds protocol to the protocol features your MTA should be able to handle or your [Milter] needs. +// For MTAs you can normally skip setting this option since we then just default to all protocol feature that this library supports. +// [Milter] should specify this option to instruct the MTA to not send any events that your [Milter] does not need or to not expect any response from events that you are not using to accept or reject an SMTP transaction. +func WithProtocol(protocol OptProtocol) Option { + return func(h *options) { + h.protocol = h.protocol | protocol + } +} + +// WithoutProtocol removes protocol from the list of protocol features this MTA supports/[Milter] requests. +func WithoutProtocol(protocol OptProtocol) Option { + return func(h *options) { + h.protocol = h.protocol & ^protocol + } +} + +// WithProtocols sets the protocol features your MTA should be able to handle or your [Milter] needs. +// For MTAs you can normally skip setting this option since we then just default to all protocol feature that this library supports. +// Milter should specify this option to instruct the MTA to not send any events that your [Milter] does not need or to not expect any response from events that you are not using to accept or reject an SMTP transaction. +func WithProtocols(protocol OptProtocol) Option { + return func(h *options) { + h.protocol = protocol + } +} + +// WithMaximumVersion sets the maximum milter version your MTA or milter filter accepts. +// The default is to use the maximum supported version. +func WithMaximumVersion(version uint32) Option { + return func(h *options) { + h.maxVersion = version + } +} + +// WithDialer sets the [net.Dialer] this [Client] will use. You can use this to e.g. set the connection timeout of the client. +// The default is to use a [net.Dialer] with a connection timeout of 10 seconds. +func WithDialer(dialer Dialer) Option { + return func(h *options) { + h.dialer = dialer + } +} + +// WithReadTimeout sets the read-timeout for all read operations of this [Client] or [Server]. +// The default is a read-timeout of 10 seconds. +func WithReadTimeout(timeout time.Duration) Option { + return func(h *options) { + h.readTimeout = timeout + } +} + +// WithWriteTimeout sets the write-timeout for all read operations of this [Client] or [Server]. +// The default is a write-timeout of 10 seconds. +func WithWriteTimeout(timeout time.Duration) Option { + return func(h *options) { + h.writeTimeout = timeout + } +} + +// WithOfferedMaxData sets the [DataSize] that your MTA wants to offer to milters. +// The milter needs to accept this offer in protocol negotiation for it to become effective. +// This is just an indication to the milter that it can send bigger packages. +// This library does not care what value was negotiated and always accept packages of up to 512 MB. +// +// This is a [Client] only [Option]. +func WithOfferedMaxData(offeredMaxData DataSize) Option { + return func(h *options) { + h.offeredMaxData = offeredMaxData + } +} + +// WithUsedMaxData sets the [DataSize] that your MTA or milter uses to send packages to the other party. +// The default value is [DataSize64K] for maximum compatibility. +// If you set this to 0 the [Client] will use the value of [WithOfferedMaxData] and the [Server] will use the dataSize that it +// negotiated with the MTA. +// +// Setting the maximum used data size to something different might trigger the other party to an error. +// MTAs like Postfix/sendmail and newer libmilter versions can handle bigger values without negotiation. +// E.g. Postfix will accept packets of up to 2 GB. This library has a hard maximum packet size of 512 MB. +func WithUsedMaxData(usedMaxData DataSize) Option { + return func(h *options) { + h.usedMaxData = usedMaxData + } +} + +// WithoutDefaultMacros deletes all macro stage definitions that were made before this [Option]. +// Use it in [NewClient] do not use the default. Since [NewServer] does not have a default, it is a no-op in [NewServer]. +func WithoutDefaultMacros() Option { + return func(h *options) { + h.macrosByStage = nil + } +} + +// WithMacroRequest defines the macros that your [Client] intends to send at stage, or it instructs the [Server] to ask for these macros at this stage. +// +// For [Client]: The milter can request other macros at protocol negotiation but if it does not do this (most do not) it will receive these macros at these stages. +// +// For [Server]: MTAs like sendmail and Postfix honor your macro requests and only send you the macros you requested (even if other macros were configured in their configuration). +// If it is possible your milter should gracefully handle the case that the MTA does not honor your macro requests. +// This function automatically sets the action [OptSetMacros] +func WithMacroRequest(stage MacroStage, macros []MacroName) Option { + return func(h *options) { + if h.macrosByStage == nil { + h.macrosByStage = make([][]MacroName, StageEndMarker) + } + h.macrosByStage[stage] = macros + } +} + +// WithMilter sets the [Milter] backend this [Server] uses. +// +// This is a [Server] only [Option]. +func WithMilter(newMilter func() Milter) Option { + return func(h *options) { + h.newMilter = func(uint32, OptAction, OptProtocol, DataSize) Milter { + return newMilter() + } + } +} + +// WithDynamicMilter sets the [Milter] backend this [Server] uses. +// This [Option] sets the milter with the negotiated version, action and protocol. +// You can use this to dynamically configure the [Milter] backend. +// +// This is a [Server] only [Option]. +func WithDynamicMilter(newMilter NewMilterFunc) Option { + return func(h *options) { + h.newMilter = newMilter + } +} + +// WithNegotiationCallback is an expert [Option] with which you can overwrite the negotiation process. +// +// You should not need to use this. You might easily break things. You are responsible to adhere to +// the milter protocol negotiation rules (they unfortunately only exist in sendmail & libmilter source code). +// +// This is a [Server] only [Option]. +func WithNegotiationCallback(negotiationCallback NegotiationCallbackFunc) Option { + return func(h *options) { + h.negotiationCallback = negotiationCallback + } +} diff --git a/options_test.go b/options_test.go new file mode 100644 index 0000000..14e5dcf --- /dev/null +++ b/options_test.go @@ -0,0 +1,161 @@ +package milter + +import ( + "net" + "reflect" + "testing" + "time" +) + +type optionsTestCase struct { + name string + start options + options []Option + want options +} + +func testOptions(t *testing.T, tests []optionsTestCase) { + for _, tt_ := range tests { + t.Run(tt_.name, func(t *testing.T) { + tt := tt_ + t.Parallel() + got := tt.start + for _, f := range tt.options { + f(&got) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("got %+v, want %+v", got, tt.want) + } + }) + } +} + +func TestWithAction(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"set", options{}, []Option{WithAction(OptAddHeader)}, options{actions: OptAddHeader}}, + {"add", options{}, []Option{WithAction(OptAddHeader), WithAction(OptQuarantine)}, options{actions: OptAddHeader | OptQuarantine}}, + {"noop", options{actions: OptChangeHeader}, []Option{WithAction(OptChangeHeader)}, options{actions: OptChangeHeader}}, + {"keep", options{actions: OptChangeHeader}, []Option{WithAction(OptAddHeader)}, options{actions: OptChangeHeader | OptAddHeader}}, + }) +} + +func TestWithoutAction(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"noop", options{}, []Option{WithoutAction(OptAddHeader)}, options{}}, + {"remove", options{actions: OptAddHeader | OptQuarantine}, []Option{WithoutAction(OptAddHeader)}, options{actions: OptQuarantine}}, + }) +} + +func TestWithActions(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"set", options{}, []Option{WithActions(OptAddHeader)}, options{actions: OptAddHeader}}, + {"no-add", options{}, []Option{WithActions(OptAddHeader), WithActions(OptQuarantine)}, options{actions: OptQuarantine}}, + {"noop", options{actions: OptChangeHeader}, []Option{WithActions(OptChangeHeader)}, options{actions: OptChangeHeader}}, + {"remove", options{actions: OptChangeHeader}, []Option{WithActions(OptAddHeader)}, options{actions: OptAddHeader}}, + }) +} + +func TestWithProtocol(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"set", options{}, []Option{WithProtocol(OptNoData)}, options{protocol: OptNoData}}, + {"add", options{}, []Option{WithProtocol(OptNoData), WithProtocol(OptNoMailFrom)}, options{protocol: OptNoData | OptNoMailFrom}}, + {"noop", options{protocol: OptNoData}, []Option{WithProtocol(OptNoData)}, options{protocol: OptNoData}}, + {"keep", options{protocol: OptNoData}, []Option{WithProtocol(OptNoMailFrom)}, options{protocol: OptNoData | OptNoMailFrom}}, + }) +} + +func TestWithoutProtocol(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"noop", options{}, []Option{WithoutProtocol(OptSkip)}, options{}}, + {"remove", options{protocol: OptSkip | OptNoData}, []Option{WithoutProtocol(OptNoData)}, options{protocol: OptSkip}}, + }) +} + +func TestWithProtocols(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"set", options{}, []Option{WithProtocols(OptNoEOH)}, options{protocol: OptNoEOH}}, + {"no-add", options{}, []Option{WithProtocols(OptNoEOH), WithProtocols(OptSkip)}, options{protocol: OptSkip}}, + {"noop", options{protocol: OptNoEOH}, []Option{WithProtocols(OptNoEOH)}, options{protocol: OptNoEOH}}, + {"remove", options{protocol: OptNoEOH}, []Option{WithProtocols(OptSkip)}, options{protocol: OptSkip}}, + }) +} + +func TestWithMaximumVersion(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"set", options{}, []Option{WithMaximumVersion(12)}, options{maxVersion: 12}}, + }) +} + +func TestWithOfferedMaxData(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"set", options{}, []Option{WithOfferedMaxData(12)}, options{offeredMaxData: 12}}, + }) +} + +func TestWithUsedMaxData(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"set", options{}, []Option{WithUsedMaxData(12)}, options{usedMaxData: 12}}, + }) +} +func TestWithReadTimeout(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"set", options{}, []Option{WithReadTimeout(time.Second)}, options{readTimeout: time.Second}}, + }) +} + +func TestWithWriteTimeout(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"set", options{}, []Option{WithWriteTimeout(time.Second)}, options{writeTimeout: time.Second}}, + }) +} + +func TestWithDialer(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"set", options{}, []Option{WithDialer(&net.Dialer{Timeout: time.Second})}, options{dialer: &net.Dialer{Timeout: time.Second}}}, + }) +} + +func TestWithMacroRequest(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"set", options{}, []Option{WithMacroRequest(StageRcpt, []MacroName{MacroRcptAddr})}, options{macrosByStage: macroRequests{nil, nil, nil, []MacroName{MacroRcptAddr}, nil, nil, nil}}}, + }) +} + +func TestWithoutDefaultMacros(t *testing.T) { + testOptions(t, []optionsTestCase{ + {"noop", options{}, []Option{WithoutDefaultMacros()}, options{}}, + {"remove", options{macrosByStage: macroRequests{nil, nil, nil, []MacroName{MacroRcptAddr}, nil, nil, nil}}, []Option{WithoutDefaultMacros()}, options{}}, + }) +} + +func TestWithDynamicMilter(t *testing.T) { + opt := options{} + called := false + WithDynamicMilter(func(uint32, OptAction, OptProtocol, DataSize) Milter { + called = true + return nil + })(&opt) + if opt.newMilter == nil { + t.Fatalf("did not set newMilter") + } + opt.newMilter(0, 0, 0, 0) + if !called { + t.Fatalf("did not set the correct newMilter") + } +} + +func TestWithNegotiationCallback(t *testing.T) { + opt := options{} + called := false + WithNegotiationCallback(func(mtaVersion, milterVersion uint32, mtaActions, milterActions OptAction, mtaProtocol, milterProtocol OptProtocol, offeredDataSize DataSize) (version uint32, actions OptAction, protocol OptProtocol, maxDataSize DataSize, err error) { + called = true + return 0, 0, 0, 0, nil + })(&opt) + if opt.negotiationCallback == nil { + t.Fatalf("did not set negotiationCallback") + } + _, _, _, _, _ = opt.negotiationCallback(0, 0, 0, 0, 0, 0, 0) + if !called { + t.Fatalf("did not set the correct negotiationCallback") + } +} diff --git a/response.go b/response.go index 5e45934..a47ec02 100644 --- a/response.go +++ b/response.go @@ -1,62 +1,105 @@ package milter +import ( + "fmt" + "strings" + + "github.com/d--j/go-milter/internal/wire" + "github.com/d--j/go-milter/milterutil" + "golang.org/x/text/transform" +) + // Response represents a response structure returned by callback // handlers to indicate how the milter server should proceed -type Response interface { - Response() *Message - Continue() bool +type Response struct { + code wire.Code + data []byte } -// SimpleResponse type to define list of pre-defined responses -type SimpleResponse byte - -// Response returns a Message object reference -func (r SimpleResponse) Response() *Message { - return &Message{byte(r), nil} +// Response returns message instance with data +func (c *Response) Response() *wire.Message { + return &wire.Message{Code: c.code, Data: c.data} } -// Continue to process milter messages only if current code is Continue -func (r SimpleResponse) Continue() bool { - return ActionCode(r) == ActContinue +// Continue returns false if the MTA should stop sending events for this transaction, true otherwise. +// A RespDiscard Response will return false because the MTA should end sending events for the current +// SMTP transaction to this milter. +func (c *Response) Continue() bool { + switch wire.ActionCode(c.code) { + case wire.ActAccept, wire.ActDiscard, wire.ActReject, wire.ActTempFail, wire.ActReplyCode: + return false + default: + return true + } } -// Define standard responses with no data -const ( - RespAccept = SimpleResponse(ActAccept) - RespContinue = SimpleResponse(ActContinue) - RespDiscard = SimpleResponse(ActDiscard) - RespReject = SimpleResponse(ActReject) - RespTempFail = SimpleResponse(ActTempFail) -) - -// CustomResponse is a response instance used by callback handlers to indicate -// how the milter should continue processing of current message -type CustomResponse struct { - code byte - data []byte +// newResponse generates a new Response suitable for wire.WritePacket +func newResponse(code wire.Code, data []byte) *Response { + return &Response{code, data} } -// Response returns message instance with data -func (c *CustomResponse) Response() *Message { - return &Message{c.code, c.data} +// newResponseStr generates a new Response with string payload (null-byte terminated) +func newResponseStr(code wire.Code, data string) (*Response, error) { + if len(data) > int(DataSize64K)-1 { // space for null-bytes + return nil, fmt.Errorf("milter: invalid data length: %d > %d", len(data), int(DataSize64K)-1) + } + if strings.ContainsRune(data, 0) { + return nil, fmt.Errorf("milter: invalid data: cannot contain null-bytes") + } + return newResponse(code, []byte(data+"\x00")), nil } -// Continue returns false if milter chain should be stopped, true otherwise -func (c *CustomResponse) Continue() bool { - for _, q := range []ActionCode{ActAccept, ActDiscard, ActReject, ActTempFail} { - if c.code == byte(q) { - return false - } +// RejectWithCodeAndReason stops processing and tells client the error code and reason to sent +// +// smtpCode must be between 400 and 599, otherwise this method will return an error. +// +// The reason can contain new-lines. Line ending canonicalization is done automatically. +// This function returns an error when the resulting SMTP text has a length of more than [DataSize64K] - 1 +func RejectWithCodeAndReason(smtpCode uint16, reason string) (*Response, error) { + if smtpCode < 400 || smtpCode > 599 { + return nil, fmt.Errorf("milter: invalid code %d", smtpCode) + } + if len(reason) > int(DataSize64K)-5 { + return nil, fmt.Errorf("milter: reason too long: %d > %d", len(reason), int(DataSize64K)-5) + } + escapeAndNormalize := transform.Chain(&milterutil.DoublePercentTransformer{}, &milterutil.CrLfCanonicalizationTransformer{}) + data, _, err := transform.String(escapeAndNormalize, strings.TrimRight(reason, "\r\n")) + if err != nil { + return nil, err } - return true + data, _, err = transform.String(&milterutil.MaximumLineLengthTransformer{}, data) + if err != nil { + return nil, err + } + data, _, err = transform.String(&milterutil.SMTPReplyTransformer{Code: smtpCode}, data) + return newResponseStr(wire.Code(wire.ActReplyCode), data) } -// NewResponse generates a new CustomResponse suitable for WritePacket -func NewResponse(code byte, data []byte) *CustomResponse { - return &CustomResponse{code, data} -} +// Define standard responses with no data +var ( + // RespAccept signals to the MTA that the current transaction should be accepted. + // No more events get send to the milter after this response. + RespAccept = &Response{code: wire.Code(wire.ActAccept)} -// NewResponseStr generates a new CustomResponse with string payload -func NewResponseStr(code byte, data string) *CustomResponse { - return NewResponse(code, []byte(data+null)) -} + // RespContinue signals to the MTA that the current transaction should continue + RespContinue = &Response{code: wire.Code(wire.ActContinue)} + + // RespDiscard signals to the MTA that the current transaction should be silently discarded. + // No more events get send to the milter after this response. + RespDiscard = &Response{code: wire.Code(wire.ActDiscard)} + + // RespReject signals to the MTA that the current transaction should be rejected with a hard rejection. + // No more events get send to the milter after this response. + RespReject = &Response{code: wire.Code(wire.ActReject)} + + // RespTempFail signals to the MTA that the current transaction should be rejected with a temporary error code. + // The sending MTA might try to deliver the same message again at a later time. + // No more events get send to the milter after this response. + RespTempFail = &Response{code: wire.Code(wire.ActTempFail)} + + // RespSkip signals to the MTA that transaction should continue and that the MTA + // does not need to send more events of the same type. This response one makes sense/is possible as + // return value of Milter.RcptTo, Milter.Header and Milter.BodyChunk. + // No more events get send to the milter after this response. + RespSkip = &Response{code: wire.Code(wire.ActSkip)} +) diff --git a/response_test.go b/response_test.go new file mode 100644 index 0000000..d1dbdd6 --- /dev/null +++ b/response_test.go @@ -0,0 +1,88 @@ +package milter + +import ( + "reflect" + "strings" + "testing" + + "github.com/d--j/go-milter/internal/wire" +) + +func TestRejectWithCodeAndReason(t *testing.T) { + t.Parallel() + tooBig := strings.Repeat("%%%%%%%%%%%%%%%%", 3000) + type args struct { + smtpCode uint16 + reason string + } + tests := []struct { + name string + args args + want string + wantErr bool + }{ + {"Simple", args{400, "go away"}, "400 go away", false}, + {"Multi", args{400, "go away\r\nreally!"}, "400-go away\r\n400 really!", false}, + {"Trailing CRLF", args{400, "go away\r\nreally!\r\n"}, "400-go away\r\n400 really!", false}, + {"Empty", args{400, ""}, "400 ", false}, + {"Newline1", args{400, "\n"}, "400 ", false}, + {"Newline2", args{400, "\r"}, "400 ", false}, + {"Newline3", args{400, "\r\n"}, "400 ", false}, + {"Newline4", args{400, "\n\r"}, "400 ", false}, + {"%", args{400, "%"}, "400 %%", false}, + {"null-bytes", args{400, "bogus\x00reason"}, "", true}, + {"invalid-code1", args{200, ""}, "", true}, + {"invalid-code2", args{999, ""}, "", true}, + {"too-big", args{400, tooBig}, "", true}, + {"too-big", args{400, tooBig + tooBig}, "", true}, + } + for _, tt_ := range tests { + t.Run(tt_.name, func(t *testing.T) { + tt := tt_ + t.Parallel() + response, err := RejectWithCodeAndReason(tt.args.smtpCode, tt.args.reason) + if (err != nil) != tt.wantErr { + t.Fatalf("RejectWithCodeAndReason() error = %v, wantErr %v", err, tt.wantErr) + } + if tt.wantErr { + return + } + if response == nil { + t.Fatalf("response ") + } + if response.code != wire.Code(wire.ActReplyCode) { + t.Fatalf("response.code got %c, want %c", response.code, wire.ActReplyCode) + } + got := string(response.data[0 : len(response.data)-1]) + if got != tt.want { + t.Errorf("RejectWithCodeAndReason() got = %q, want %q", got, tt.want) + } + }) + } +} + +func TestCustomResponseDefaultResponse(t *testing.T) { + tests := []struct { + name string + r *Response + wantContinue bool + wantMsg *wire.Message + }{ + {"RespContinue", RespContinue, true, &wire.Message{Code: wire.Code(wire.ActContinue)}}, + {"RespSkip", RespSkip, true, &wire.Message{Code: wire.Code(wire.ActSkip)}}, + {"RespAccept", RespAccept, false, &wire.Message{Code: wire.Code(wire.ActAccept)}}, + {"RespDiscard", RespDiscard, false, &wire.Message{Code: wire.Code(wire.ActDiscard)}}, + {"RespReject", RespReject, false, &wire.Message{Code: wire.Code(wire.ActReject)}}, + {"RespTempFail", RespTempFail, false, &wire.Message{Code: wire.Code(wire.ActTempFail)}}, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if gotContinue := tt.r.Continue(); gotContinue != tt.wantContinue { + t.Errorf("Continue() = %v, want %v", gotContinue, tt.wantContinue) + } + if gotResponse := tt.r.Response(); !reflect.DeepEqual(gotResponse, tt.wantMsg) { + t.Errorf("Response() = %v, want %v", gotResponse, tt.wantMsg) + } + }) + } +} diff --git a/server.go b/server.go index 378a1d3..ce171cd 100644 --- a/server.go +++ b/server.go @@ -3,57 +3,69 @@ package milter import ( "errors" "net" - "net/textproto" + "time" ) -// Milter protocol version implemented by the server. -// -// Note: Not exported as we might want to support multiple versions -// transparently in the future. -var serverProtocolVersion uint32 = 2 +// MaxServerProtocolVersion is the maximum Milter protocol version implemented by the server. +const MaxServerProtocolVersion uint32 = 6 -// ErrServerClosed is returned by the Server's Serve method after a call to -// Close. +// ErrServerClosed is returned by the Server's Serve method after a call to Close. var ErrServerClosed = errors.New("milter: server closed") // Milter is an interface for milter callback handlers. type Milter interface { // Connect is called to provide SMTP connection data for incoming message. // Suppress with OptNoConnect. - Connect(host string, family string, port uint16, addr net.IP, m *Modifier) (Response, error) + Connect(host string, family string, port uint16, addr string, m *Modifier) (*Response, error) // Helo is called to process any HELO/EHLO related filters. Suppress with // OptNoHelo. - Helo(name string, m *Modifier) (Response, error) + Helo(name string, m *Modifier) (*Response, error) // MailFrom is called to process filters on envelope FROM address. Suppress // with OptNoMailFrom. - MailFrom(from string, m *Modifier) (Response, error) + MailFrom(from string, esmtpArgs string, m *Modifier) (*Response, error) // RcptTo is called to process filters on envelope TO address. Suppress with // OptNoRcptTo. - RcptTo(rcptTo string, m *Modifier) (Response, error) + RcptTo(rcptTo string, esmtpArgs string, m *Modifier) (*Response, error) + + // Data is called at the beginning of the DATA command (after all RCPT TO commands). + Data(m *Modifier) (*Response, error) // Header is called once for each header in incoming message. Suppress with // OptNoHeaders. - Header(name string, value string, m *Modifier) (Response, error) + Header(name string, value string, m *Modifier) (*Response, error) - // Headers is called when all message headers have been processed. Suppress - // with OptNoEOH. - Headers(h textproto.MIMEHeader, m *Modifier) (Response, error) + // Headers gets called when all message headers have been processed. + // Suppress with OptNoEOH. + // h is a textproto.MIMEHeader of all collected headers. + // If you specified OptNoHeaders this will be of course empty. + Headers(m *Modifier) (*Response, error) // BodyChunk is called to process next message body chunk data (up to 64KB - // in size). Suppress with OptNoBody. - BodyChunk(chunk []byte, m *Modifier) (Response, error) + // in size). Suppress with OptNoBody. If you return RespSkip the MTA will stop + // sending more body chunks. But older MTAs do not support this and in this case there are more calls to BodyChunk. + // Your code should be able to handle this. + BodyChunk(chunk []byte, m *Modifier) (*Response, error) - // Body is called at the end of each message. All changes to message's + // EndOfMessage is called at the end of each message. All changes to message's // content & attributes must be done here. - Body(m *Modifier) (Response, error) + // The MTA can start over with another message in the same connection but that is handled in a new Milter instance. + EndOfMessage(m *Modifier) (*Response, error) - // Abort is called is the current message has been aborted. All message data - // should be reset to prior to the Helo callback. Connection data should be - // preserved. + // Abort is called if the current message has been aborted. All message data + // should be reset prior to the MailFrom callback. Connection data should be + // preserved. Cleanup is not called before or after Abort. Abort(m *Modifier) error + + // Unknown is called when the MTA got an unknown command in the SMTP connection. + Unknown(cmd string, m *Modifier) (*Response, error) + + // Cleanup always gets called when the Milter is about to be discarded. + // E.g. because the MTA closed the connection, one SMTP message was successful or there was an error. + // May be called more than once for a single Milter. + Cleanup() } // NoOpMilter is a dummy Milter implementation that does nothing. @@ -61,55 +73,107 @@ type NoOpMilter struct{} var _ Milter = NoOpMilter{} -func (NoOpMilter) Connect(host string, family string, port uint16, addr net.IP, m *Modifier) (Response, error) { +func (NoOpMilter) Connect(host string, family string, port uint16, addr string, m *Modifier) (*Response, error) { return RespContinue, nil } -func (NoOpMilter) Helo(name string, m *Modifier) (Response, error) { +func (NoOpMilter) Helo(name string, m *Modifier) (*Response, error) { return RespContinue, nil } -func (NoOpMilter) MailFrom(from string, m *Modifier) (Response, error) { +func (NoOpMilter) MailFrom(from string, esmtpArgs string, m *Modifier) (*Response, error) { return RespContinue, nil } -func (NoOpMilter) RcptTo(rcptTo string, m *Modifier) (Response, error) { +func (NoOpMilter) RcptTo(rcptTo string, esmtpArgs string, m *Modifier) (*Response, error) { return RespContinue, nil } -func (NoOpMilter) Header(name string, value string, m *Modifier) (Response, error) { +func (NoOpMilter) Data(m *Modifier) (*Response, error) { return RespContinue, nil } -func (NoOpMilter) Headers(h textproto.MIMEHeader, m *Modifier) (Response, error) { +func (NoOpMilter) Header(name string, value string, m *Modifier) (*Response, error) { return RespContinue, nil } -func (NoOpMilter) BodyChunk(chunk []byte, m *Modifier) (Response, error) { +func (NoOpMilter) Headers(m *Modifier) (*Response, error) { return RespContinue, nil } -func (NoOpMilter) Body(m *Modifier) (Response, error) { +func (NoOpMilter) BodyChunk(chunk []byte, m *Modifier) (*Response, error) { + return RespContinue, nil +} + +func (NoOpMilter) EndOfMessage(m *Modifier) (*Response, error) { return RespAccept, nil } -func (NoOpMilter) Abort(m *Modifier) error { +func (NoOpMilter) Abort(_ *Modifier) error { return nil } +func (NoOpMilter) Unknown(cmd string, m *Modifier) (*Response, error) { + return RespContinue, nil +} + +func (NoOpMilter) Cleanup() { +} + // Server is a milter server. type Server struct { - NewMilter func() Milter - Actions OptAction - Protocol OptProtocol - + options options listeners []net.Listener closed bool } +// NewServer creates a new milter server. +// +// You need to at least specify the used Milter with the option WithMilter. +// You should also specify the actions your Milter will do. Otherwise, you cannot do any message modifications. +// For performance reasons you should disable protocol stages that you do not need with WithProtocol. +// +// This function will panic when you provide invalid options. +func NewServer(opts ...Option) *Server { + options := options{ + maxVersion: MaxServerProtocolVersion, + actions: 0, + protocol: 0, + readTimeout: 10 * time.Second, + writeTimeout: 10 * time.Second, + } + if len(opts) > 0 { + for _, o := range opts { + if o != nil { + o(&options) + } + } + } + + if options.newMilter == nil { + panic("milter: you need to use WithMilter in NewServer call") + } + if options.maxVersion > MaxServerProtocolVersion || options.maxVersion == 1 { + panic("milter: this library cannot handle this milter version") + } + if options.dialer != nil { + panic("milter: WithDialer is a client only option") + } + if options.offeredMaxData > 0 { + panic("milter: WithOfferedMaxData is a client only option") + } + if options.macrosByStage != nil { + options.actions = options.actions | OptSetMacros + } + + return &Server{options: options} +} + // Serve starts the server. func (s *Server) Serve(ln net.Listener) error { - defer ln.Close() + defer func(ln net.Listener) { + _ = ln.Close() + }(ln) s.listeners = append(s.listeners, ln) @@ -122,18 +186,22 @@ func (s *Server) Serve(ln net.Listener) error { return err } - session := milterSession{ + session := serverSession{ server: s, - actions: s.Actions, - protocol: s.Protocol, + version: s.options.maxVersion, + actions: s.options.actions, + protocol: s.options.protocol, conn: conn, - backend: s.NewMilter(), + macros: newMacroStages(), } go session.HandleMilterCommands() } } func (s *Server) Close() error { + if s.closed { + return ErrServerClosed + } s.closed = true for _, ln := range s.listeners { if err := ln.Close(); err != nil { diff --git a/server_test.go b/server_test.go new file mode 100644 index 0000000..1ac56b1 --- /dev/null +++ b/server_test.go @@ -0,0 +1,115 @@ +package milter + +import ( + "bytes" + "testing" + + "github.com/d--j/go-milter/internal/wire" + "github.com/emersion/go-message/textproto" +) + +func TestNoOpMilter(t *testing.T) { + t.Parallel() + asset := func(resp *Response, err error, act wire.ActionCode) { + t.Helper() + if resp.Response().Code != wire.Code(act) { + t.Fatalf("NoOpMilter response is not %c: %+v", act, resp) + } + if err != nil { + t.Fatal(err) + } + } + assetContinue := func(resp *Response, err error) { + t.Helper() + asset(resp, err, wire.ActContinue) + } + assetAccept := func(resp *Response, err error) { + t.Helper() + asset(resp, err, wire.ActAccept) + } + m := NoOpMilter{} + assetContinue(m.Connect("", "", 0, "", nil)) + assetContinue(m.Helo("", nil)) + assetContinue(m.MailFrom("", "", nil)) + assetContinue(m.RcptTo("", "", nil)) + assetContinue(m.Data(nil)) + assetContinue(m.Header("", "", nil)) + assetContinue(m.Headers(nil)) + assetContinue(m.BodyChunk(nil, nil)) + assetAccept(m.EndOfMessage(nil)) + assetContinue(m.Unknown("", nil)) + if err := m.Abort(nil); err != nil { + t.Fatal(err) + } +} + +func TestServer_NoOpMilter(t *testing.T) { + t.Parallel() + assert := func(act *Action, err error, expectedCode ActionType) { + t.Helper() + if err != nil { + t.Fatalf("got err: %v", err) + } + if act == nil { + t.Fatal("act is nil") + } + if act.Type != expectedCode { + t.Fatalf("got action: %+v expected action code %c", act, expectedCode) + } + } + assertContinue := func(act *Action, err error) { + t.Helper() + assert(act, err, ActionContinue) + } + assertEnd := func(mActions []ModifyAction, act *Action, err error) { + t.Helper() + assert(act, err, ActionAccept) + if len(mActions) > 0 { + t.Fatalf("milter returned ModifyActions: %+v", mActions) + } + } + macros := NewMacroBag() + w := newServerClient(t, macros, []Option{WithMilter(func() Milter { + return NoOpMilter{} + })}, nil) + defer w.Cleanup() + macros.Set(MacroMTAFullyQualifiedDomainName, "localhost.local") + macros.Set(MacroTlsVersion, "TLS1.3") + macros.Set(MacroAuthType, "plain") + macros.Set(MacroRcptMailer, "smtp") + macros.Set(MacroQueueId, "123") + assertContinue(w.session.Conn("localhost", FamilyInet, 2525, "127.0.0.1")) + assertContinue(w.session.Helo("localhost")) + assertContinue(w.session.Mail("", "")) + assertContinue(w.session.Rcpt("", "")) + assertContinue(w.session.Rcpt("", "")) + if err := w.session.Abort(nil); err != nil { + t.Fatal(err) + } + assertContinue(w.session.Mail("", "")) + assertContinue(w.session.Rcpt("", "")) + assertContinue(w.session.Rcpt("", "")) + hdrs := textproto.Header{} + hdrs.Add("From", "Mailer Daemon <>") + assertContinue(w.session.Header(hdrs)) + assertEnd(w.session.BodyReadFrom(bytes.NewReader([]byte("test\ntest\n")))) + + if err := w.session.Reset(nil); err != nil { + t.Fatal(err) + } + + assertContinue(w.session.Conn("localhost", FamilyInet, 2525, "127.0.0.1")) + assertContinue(w.session.Helo("localhost")) + assertContinue(w.session.Mail("", "")) + assertContinue(w.session.Rcpt("", "")) + assertContinue(w.session.DataStart()) + assertContinue(w.session.HeaderField("From", "<>", nil)) + assertContinue(w.session.HeaderField("To", "<>", nil)) + assertContinue(w.session.HeaderEnd()) + assertContinue(w.session.BodyChunk([]byte("test\n"))) + assertContinue(w.session.BodyChunk([]byte("test\n"))) + assertEnd(w.session.End()) + if err := w.server.Close(); err != nil { + t.Fatal(err) + } +} diff --git a/session.go b/session.go index 3d232cb..91245d6 100644 --- a/session.go +++ b/session.go @@ -1,244 +1,370 @@ package milter import ( - "bufio" "bytes" "encoding/binary" "errors" + "fmt" "io" - "log" "net" "net/textproto" "strings" - "time" + + "github.com/d--j/go-milter/internal/wire" ) -var errCloseSession = errors.New("Stop current milter processing") - -// milterSession keeps session state during MTA communication -type milterSession struct { - server *Server - actions OptAction - protocol OptProtocol - conn net.Conn - headers textproto.MIMEHeader - macros map[string]string - backend Milter +var errCloseSession = errors.New("stop current milter processing") + +// serverSession keeps session state during MTA communication +type serverSession struct { + server *Server + version uint32 + actions OptAction + protocol OptProtocol + maxDataSize DataSize + conn net.Conn + headers textproto.MIMEHeader + macros *macrosStages + backend Milter } -// ReadPacket reads incoming milter packet -func (c *milterSession) ReadPacket() (*Message, error) { - return readPacket(c.conn, 0) +// readPacket reads incoming milter packet +func (m *serverSession) readPacket() (*wire.Message, error) { + return wire.ReadPacket(m.conn, 0) } -func readPacket(conn net.Conn, timeout time.Duration) (*Message, error) { - if timeout != 0 { - conn.SetReadDeadline(time.Now().Add(timeout)) - defer conn.SetReadDeadline(time.Time{}) - } +// writePacket sends a milter response packet to socket stream +func (m *serverSession) writePacket(msg *wire.Message) error { + return wire.WritePacket(m.conn, msg, 0) +} - // read packet length - var length uint32 - if err := binary.Read(conn, binary.BigEndian, &length); err != nil { - return nil, err +func (m *serverSession) negotiate(msg *wire.Message, milterVersion uint32, milterActions OptAction, milterProtocol OptProtocol, callback NegotiationCallbackFunc, macroRequests macroRequests, usedMaxData DataSize) (*Response, error) { + if msg.Code != wire.CodeOptNeg { + return nil, fmt.Errorf("milter: negotiate: unexpected package with code %c", msg.Code) } - - // read packet data - data := make([]byte, length) - if _, err := io.ReadFull(conn, data); err != nil { - return nil, err + if len(msg.Data) < 4*3 /* version + action mask + proto mask */ { + return nil, fmt.Errorf("milter: negotiate: unexpected data size: %d", len(msg.Data)) } - - // prepare response data - message := Message{ - Code: data[0], - Data: data[1:], + mtaVersion := binary.BigEndian.Uint32(msg.Data[:4]) + mtaActionMask := OptAction(binary.BigEndian.Uint32(msg.Data[4:])) + mtaProtoMask := OptProtocol(binary.BigEndian.Uint32(msg.Data[8:])) + offeredMaxDataSize := DataSize64K + if uint32(mtaProtoMask)&optMds1M == optMds1M { + offeredMaxDataSize = DataSize1M + } else if uint32(mtaProtoMask)&optMds256K == optMds256K { + offeredMaxDataSize = DataSize256K } + mtaProtoMask = mtaProtoMask & (^OptProtocol(optInternal)) - return &message, nil -} - -// WritePacket sends a milter response packet to socket stream -func (m *milterSession) WritePacket(msg *Message) error { - return writePacket(m.conn, msg, 0) -} - -func writePacket(conn net.Conn, msg *Message, timeout time.Duration) error { - if timeout != 0 { - conn.SetWriteDeadline(time.Now().Add(timeout)) - defer conn.SetWriteDeadline(time.Time{}) + var err error + var maxDataSize DataSize + if callback != nil { + if m.version, m.actions, m.protocol, maxDataSize, err = callback(mtaVersion, milterVersion, mtaActionMask, milterActions, mtaProtoMask, milterProtocol, offeredMaxDataSize); err != nil { + return nil, err + } + } else { + if mtaVersion < 2 || mtaVersion > MaxServerProtocolVersion { + return nil, fmt.Errorf("milter: negotiate: unsupported protocol version: %d", mtaVersion) + } + m.version = mtaVersion + if milterActions&mtaActionMask != milterActions { + return nil, fmt.Errorf("milter: negotiate: MTA does not offer required actions. offered: %032b requested: %032b", mtaActionMask, milterActions) + } + m.actions = milterActions & mtaActionMask + if milterProtocol&mtaProtoMask != milterProtocol { + return nil, fmt.Errorf("milter: negotiate: MTA does not offer required protocol options. offered: %032b requested: %032b", mtaProtoMask, milterProtocol) + } + m.protocol = milterProtocol & mtaProtoMask + maxDataSize = offeredMaxDataSize } - - buffer := bufio.NewWriter(conn) - - // calculate and write response length - length := uint32(len(msg.Data) + 1) - if err := binary.Write(buffer, binary.BigEndian, length); err != nil { - return err + if m.version < 2 || m.version > MaxServerProtocolVersion { + return nil, fmt.Errorf("milter: negotiate: unsupported protocol version: %d", m.version) } - - // write response code - if err := buffer.WriteByte(msg.Code); err != nil { - return err + if maxDataSize != DataSize64K && maxDataSize != DataSize256K && maxDataSize != DataSize1M { + maxDataSize = DataSize64K } + if usedMaxData == 0 { + usedMaxData = maxDataSize + } + m.maxDataSize = usedMaxData + + // TODO: activate skip response according to m.version - // write response data - if _, err := buffer.Write(msg.Data); err != nil { - return err + sizeMask := uint32(0) + if maxDataSize == DataSize256K { + sizeMask = optMds256K + } else if maxDataSize == DataSize1M { + sizeMask = optMds1M } - // flush data to network socket stream - if err := buffer.Flush(); err != nil { - return err + // prepare response data + var buffer bytes.Buffer + for _, value := range []uint32{m.version, uint32(m.actions), uint32(m.protocol) | sizeMask} { + if err := binary.Write(&buffer, binary.BigEndian, value); err != nil { + return nil, fmt.Errorf("milter: negotiate: %w", err) + } + } + // send the macros we want to have in the response + if macroRequests != nil && mtaActionMask&OptSetMacros != 0 { + for st := 0; st < int(StageEndMarker) && st < len(macroRequests); st++ { + if macroRequests[st] != nil && len(macroRequests[st]) > 0 { + if err := binary.Write(&buffer, binary.BigEndian, uint32(st)); err != nil { + return nil, fmt.Errorf("milter: negotiate: %w", err) + } + buffer.WriteString(strings.Join(macroRequests[st], " ")) + buffer.WriteByte(0) + } + } + } else if macroRequests != nil { + LogWarning("milter could not send the needed macros since MTA does not support this") } + // build negotiation response + return newResponse(wire.CodeOptNeg, buffer.Bytes()), nil +} - return nil +func (m *serverSession) newBackend() Milter { + return m.server.options.newMilter(m.version, m.actions, m.protocol, m.maxDataSize) } // Process processes incoming milter commands -func (m *milterSession) Process(msg *Message) (Response, error) { - switch Code(msg.Code) { - case CodeAbort: - // abort current message and start over - defer func() { - m.headers = nil - m.macros = nil - }() - return nil, m.backend.Abort(newModifier(m)) - - case CodeBody: - // body chunk - return m.backend.BodyChunk(msg.Data, newModifier(m)) - - case CodeConn: - // new connection, get hostname - hostname := readCString(msg.Data) +func (m *serverSession) Process(msg *wire.Message) (*Response, error) { + switch msg.Code { + case wire.CodeOptNeg: + return nil, fmt.Errorf("milter: negotiate: can only be called once in a connection") + + case wire.CodeConn: + if len(msg.Data) == 0 { + return nil, fmt.Errorf("milter: conn: unexpected data size: %d", len(msg.Data)) + } + m.macros.DelStageAndAbove(StageHelo) + hostname := wire.ReadCString(msg.Data) msg.Data = msg.Data[len(hostname)+1:] // get protocol family protocolFamily := msg.Data[0] msg.Data = msg.Data[1:] - // get port + // get port and address var port uint16 - if protocolFamily == '4' || protocolFamily == '6' { + var address string + if protocolFamily == 'L' || protocolFamily == '4' || protocolFamily == '6' { if len(msg.Data) < 2 { - return RespTempFail, nil + return nil, fmt.Errorf("milter: conn: unexpected data size: %d", len(msg.Data)) } port = binary.BigEndian.Uint16(msg.Data) msg.Data = msg.Data[2:] + // get address + address = wire.ReadCString(msg.Data) } - // get address - address := readCString(msg.Data) - // convert address and port to human readable string - family := map[byte]string{ - 'U': "unknown", - 'L': "unix", - '4': "tcp4", - '6': "tcp6", + // convert family to human-readable string and validate + family := "" + switch protocolFamily { + case 'U': + family = "unknown" + case 'L': + family = "unix" + case '4': + family = "tcp4" + addr := net.ParseIP(address) + if addr == nil || addr.To4() == nil { + return nil, fmt.Errorf("milter: conn: unexpected ip4 address: %q", address) + } + case '6': + family = "tcp6" + var addr net.IP + // also accept [dead::cafe] style IPv6 addresses + if len(address) > 2 && address[0] == '[' && address[len(address)-1] == ']' { + addr = net.ParseIP(address[1 : len(address)-1]) + if addr != nil { + address = addr.String() + } + } else { + addr = net.ParseIP(address) + } + if addr == nil { + return nil, fmt.Errorf("milter: conn: unexpected ip6 address: %q", address) + } + default: + return nil, fmt.Errorf("milter: conn: unexpected protocol family: %c", protocolFamily) } // run handler and return return m.backend.Connect( hostname, - family[protocolFamily], + family, port, - net.ParseIP(address), - newModifier(m)) + address, + newModifier(m, true)) - case CodeMacro: - // define macros - m.macros = make(map[string]string) - // convert data to Go strings - data := decodeCStrings(msg.Data[1:]) - if len(data) != 0 { - if len(data)%2 == 1 { - data = append(data, "") - } - - // store data in a map - for i := 0; i < len(data); i += 2 { - m.macros[data[i]] = data[i+1] - } + case wire.CodeHelo: + if len(msg.Data) == 0 { + return nil, fmt.Errorf("milter: helo: unexpected data size: %d", len(msg.Data)) } - // do not send response - return nil, nil - - case CodeEOB: - // call and return milter handler - return m.backend.Body(newModifier(m)) - - case CodeHelo: - // helo command - name := strings.TrimSuffix(string(msg.Data), null) - return m.backend.Helo(name, newModifier(m)) + m.macros.DelStageAndAbove(StageMail) + name := wire.ReadCString(msg.Data) + return m.backend.Helo(name, newModifier(m, true)) - case CodeHeader: + case wire.CodeMail: + if len(msg.Data) == 0 { + return nil, fmt.Errorf("milter: mail: unexpected data size: %d", len(msg.Data)) + } + m.macros.DelStageAndAbove(StageRcpt) + from := wire.ReadCString(msg.Data) + msg.Data = msg.Data[len(from)+1:] + esmtpArgs := wire.ReadCString(msg.Data) + return m.backend.MailFrom(removeHats(from), esmtpArgs, newModifier(m, true)) + + case wire.CodeRcpt: + if len(msg.Data) == 0 { + return nil, fmt.Errorf("milter: rcpt: unexpected data size: %d", len(msg.Data)) + } + m.macros.DelStageAndAbove(StageData) + to := wire.ReadCString(msg.Data) + msg.Data = msg.Data[len(to)+1:] + esmtpArgs := wire.ReadCString(msg.Data) + return m.backend.RcptTo(removeHats(to), esmtpArgs, newModifier(m, true)) + + case wire.CodeData: + m.macros.DelStageAndAbove(StageEOH) + return m.backend.Data(newModifier(m, true)) + + case wire.CodeHeader: + if len(msg.Data) < 2 { + return nil, fmt.Errorf("milter: header: unexpected data size: %d", len(msg.Data)) + } // make sure headers is initialized if m.headers == nil { m.headers = make(textproto.MIMEHeader) } // add new header to headers map - headerData := decodeCStrings(msg.Data) - // headers with an empty body appear as `text\x00\x00`, decodeCStrings will drop the empty body - if len(headerData) == 1 { - headerData = append(headerData, "") + headerData := wire.DecodeCStrings(msg.Data) + if len(headerData) != 2 { + return nil, fmt.Errorf("milter: header: unexpected number of strings: %d", len(headerData)) } - if len(headerData) == 2 { - m.headers.Add(headerData[0], headerData[1]) - // call and return milter handler - return m.backend.Header(headerData[0], headerData[1], newModifier(m)) + m.headers.Add(headerData[0], headerData[1]) + // call and return milter handler + resp, err := m.backend.Header(headerData[0], headerData[1], newModifier(m, true)) + m.macros.DelStageAndAbove(StageEndMarker) + return resp, err + + case wire.CodeEOH: + m.macros.DelStageAndAbove(StageEOM) + return m.backend.Headers(newModifier(m, true)) + + case wire.CodeBody: + resp, err := m.backend.BodyChunk(msg.Data, newModifier(m, true)) + m.macros.DelStageAndAbove(StageEndMarker) + return resp, err + + case wire.CodeEOB: + return m.backend.EndOfMessage(newModifier(m, false)) + + case wire.CodeUnknown: + cmd := wire.ReadCString(msg.Data) + resp, err := m.backend.Unknown(cmd, newModifier(m, true)) + m.macros.DelStageAndAbove(StageEndMarker) + return resp, err + + case wire.CodeMacro: + if len(msg.Data) == 0 { + return nil, fmt.Errorf("milter: macro: unexpected data size: %d", len(msg.Data)) } - - case CodeMail: - // envelope from address - from := readCString(msg.Data) - return m.backend.MailFrom(strings.Trim(from, "<>"), newModifier(m)) - - case CodeEOH: - // end of headers - return m.backend.Headers(m.headers, newModifier(m)) - - case CodeOptNeg: - // ignore request and prepare response buffer - var buffer bytes.Buffer - // prepare response data - for _, value := range []uint32{serverProtocolVersion, uint32(m.actions), uint32(m.protocol)} { - if err := binary.Write(&buffer, binary.BigEndian, value); err != nil { - return nil, err + code := wire.Code(msg.Data[0]) + var stage MacroStage + switch code { + case wire.CodeConn: + stage = StageConnect + case wire.CodeHelo: + stage = StageHelo + case wire.CodeMail: + stage = StageMail + case wire.CodeRcpt: + stage = StageRcpt + case wire.CodeData: + stage = StageData + case wire.CodeEOH: + stage = StageEOH + case wire.CodeEOB: + stage = StageEOM + case wire.CodeUnknown, wire.CodeHeader, wire.CodeAbort, wire.CodeBody: + stage = StageEndMarker // this stage gets cleared after the command + default: + LogWarning("MTA sent macro for %c. we cannot handle this so we ignore it", code) + return nil, nil + } + m.macros.DelStageAndAbove(stage) + // convert data to Go strings + data := wire.DecodeCStrings(msg.Data[1:]) + if len(data) != 0 { + if len(data)%2 == 1 { + data = append(data, "") } + m.macros.SetStage(stage, data...) } - // build and send packet - return NewResponse('O', buffer.Bytes()), nil + // do not send response + return nil, nil - case CodeQuit: - // client requested session close - return nil, errCloseSession + case wire.CodeAbort: + // abort current message and start over + err := m.backend.Abort(newModifier(m, true)) + m.headers = nil + m.macros.DelStageAndAbove(StageHelo) + return nil, err - case CodeRcpt: - // envelope to address - to := readCString(msg.Data) - return m.backend.RcptTo(strings.Trim(to, "<>"), newModifier(m)) + case wire.CodeQuitNewConn: + // abort current connection and start over + m.backend.Cleanup() + m.headers = nil + m.macros.DelStageAndAbove(StageConnect) + m.backend = m.newBackend() + // do not send response + return nil, nil - case CodeData: - // data, ignore + case wire.CodeQuit: + m.backend.Cleanup() + // client requested session close + return nil, errCloseSession default: // print error and close session - log.Printf("Unrecognized command code: %c", msg.Code) + LogWarning("Unrecognized command code: %c", msg.Code) return nil, errCloseSession } - - // by default continue with next milter message - return RespContinue, nil } -// HandleMilterComands processes all milter commands in the same connection -func (m *milterSession) HandleMilterCommands() { - defer m.conn.Close() +// HandleMilterCommands processes all milter commands in the same connection +func (m *serverSession) HandleMilterCommands() { + defer func() { + if m.backend != nil { + m.backend.Cleanup() + } + if m.conn != nil { + if err := m.conn.Close(); err != nil && err != io.EOF { + LogWarning("Error closing connection: %v", err) + } + } + }() + // first do the negotiation + msg, err := m.readPacket() + if err != nil { + LogWarning("Error reading milter command: %v", err) + return + } + resp, err := m.negotiate(msg, m.server.options.maxVersion, m.server.options.actions, m.server.options.protocol, m.server.options.negotiationCallback, m.server.options.macrosByStage, 0) + if err != nil { + LogWarning("Error negotiating: %v", err) + return + } + m.backend = m.newBackend() + if err = m.writePacket(resp.Response()); err != nil { + LogWarning("Error writing packet: %v", err) + return + } + + // now we can process the events for { - msg, err := m.ReadPacket() + msg, err := m.readPacket() if err != nil { if err != io.EOF { - log.Printf("Error reading milter command: %v", err) + LogWarning("Error reading milter command: %v", err) } return } @@ -247,23 +373,56 @@ func (m *milterSession) HandleMilterCommands() { if err != nil { if err != errCloseSession { // log error condition - log.Printf("Error performing milter command: %v", err) + LogWarning("Error performing milter command: %v", err) } return } - // ignore empty responses - if resp != nil { - // send back response message - if err = m.WritePacket(resp.Response()); err != nil { - log.Printf("Error writing packet: %v", err) - return - } + // ignore empty responses or responses we indicated to not send + if resp == nil || m.skipResponse(msg.Code) { + continue + } - if !resp.Continue() { - // prepare backend for next message - m.backend = m.server.NewMilter() - } + // send back response message + if err = m.writePacket(resp.Response()); err != nil { + LogWarning("Error writing packet: %v", err) + return } + + if !resp.Continue() { + m.backend.Cleanup() + // prepare backend for next message + m.backend = m.newBackend() + m.macros.DelStageAndAbove(StageMail) + } + } +} + +// protocolOption checks whether the option is set in negotiated options, that +// is, requested by the milter and offered by the MTA. +func (m *serverSession) protocolOption(opt OptProtocol) bool { + return m.protocol&opt != 0 +} + +func (m *serverSession) skipResponse(code wire.Code) bool { + switch code { + case wire.CodeConn: + return m.protocolOption(OptNoConnReply) + case wire.CodeHelo: + return m.protocolOption(OptNoHeloReply) + case wire.CodeMail: + return m.protocolOption(OptNoMailReply) + case wire.CodeRcpt: + return m.protocolOption(OptNoRcptReply) + case wire.CodeData: + return m.protocolOption(OptNoDataReply) + case wire.CodeUnknown: + return m.protocolOption(OptNoUnknownReply) + case wire.CodeEOH: + return m.protocolOption(OptNoEOHReply) + case wire.CodeBody: + return m.protocolOption(OptNoBodyReply) + default: + return false } } diff --git a/session_test.go b/session_test.go new file mode 100644 index 0000000..8f8e341 --- /dev/null +++ b/session_test.go @@ -0,0 +1,510 @@ +package milter + +import ( + "bytes" + "errors" + "net/textproto" + "reflect" + "testing" + + "github.com/d--j/go-milter/internal/wire" +) + +type processTestMilter struct { + cleanupCalled int + host string + family string + port uint16 + addr string + name string + from string + fromEsmtp string + rcptTo string + rcptEsmtp string + dataCalled bool + hdrName, hdrValue string + headers textproto.MIMEHeader + headersCalled bool + chunk []byte + eomCalled bool + abortCalled bool + cmd string +} + +func (p *processTestMilter) Connect(host string, family string, port uint16, addr string, m *Modifier) (*Response, error) { + p.host = host + p.family = family + p.port = port + p.addr = addr + return RespContinue, nil +} + +func (p *processTestMilter) Helo(name string, m *Modifier) (*Response, error) { + p.name = name + return RespContinue, nil +} + +func (p *processTestMilter) MailFrom(from string, esmtpArgs string, m *Modifier) (*Response, error) { + p.from = from + p.fromEsmtp = esmtpArgs + return RespContinue, nil +} + +func (p *processTestMilter) RcptTo(rcptTo string, esmtpArgs string, m *Modifier) (*Response, error) { + p.rcptTo = rcptTo + p.rcptEsmtp = esmtpArgs + return RespContinue, nil +} + +func (p *processTestMilter) Data(m *Modifier) (*Response, error) { + p.dataCalled = true + return RespContinue, nil +} + +func (p *processTestMilter) Header(name string, value string, m *Modifier) (*Response, error) { + p.hdrName = name + p.hdrValue = value + return RespContinue, nil +} + +func (p *processTestMilter) Headers(m *Modifier) (*Response, error) { + p.headers = m.Headers + p.headersCalled = true + return RespContinue, nil +} + +func (p *processTestMilter) BodyChunk(chunk []byte, m *Modifier) (*Response, error) { + p.chunk = chunk + return RespContinue, nil +} + +func (p *processTestMilter) EndOfMessage(m *Modifier) (*Response, error) { + p.eomCalled = true + return RespAccept, nil +} + +func (p *processTestMilter) Abort(_ *Modifier) error { + p.abortCalled = true + return nil +} + +func (p *processTestMilter) Unknown(cmd string, m *Modifier) (*Response, error) { + p.cmd = cmd + return RespContinue, nil +} + +func (p *processTestMilter) Cleanup() { + p.cleanupCalled++ +} + +var _ Milter = &processTestMilter{} + +func Test_milterSession_negotiate(t *testing.T) { + type fields struct { + milterVersion uint32 + milterActions OptAction + milterProtocol OptProtocol + callback NegotiationCallbackFunc + macroRequests macroRequests + } + + tests := []struct { + name string + fields fields + msg *wire.Message + want *wire.Message + wantErr bool + }{ + {"negotiation error 1", fields{}, &wire.Message{wire.CodeOptNeg, nil}, nil, true}, + {"negotiation error 2", fields{}, &wire.Message{wire.CodeOptNeg, []byte{0, 0, 0, 99, 0, 0, 0, 0, 0, 0, 0, 0}}, nil, true}, + {"negotiation error 3", fields{}, &wire.Message{wire.CodeOptNeg, []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}}, nil, true}, + {"negotiation error 4", fields{callback: func(mtaVersion, milterVersion uint32, mtaActions, milterActions OptAction, mtaProtocol, milterProtocol OptProtocol, offeredMaxData DataSize) (version uint32, actions OptAction, protocol OptProtocol, maxData DataSize, err error) { + return 0, 0, 0, 0, errors.New("error") + }}, &wire.Message{wire.CodeOptNeg, []byte{0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0}}, nil, true}, + {"negotiation", fields{callback: func(mtaVersion, milterVersion uint32, mtaActions, milterActions OptAction, mtaProtocol, milterProtocol OptProtocol, offeredMaxData DataSize) (version uint32, actions OptAction, protocol OptProtocol, maxData DataSize, err error) { + return milterVersion, OptAddHeader, OptNoConnect, DataSize64K, nil + }}, &wire.Message{wire.CodeOptNeg, []byte{0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0}}, &wire.Message{wire.CodeOptNeg, []byte{0, 0, 0, 6, 0, 0, 0, 1, 0, 0, 0, 1}}, false}, + {"negotiation macros", fields{milterActions: OptSetMacros, macroRequests: macroRequests{{"j", "_"}, {"i"}}}, &wire.Message{wire.CodeOptNeg, []byte{0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 0}}, &wire.Message{wire.CodeOptNeg, []byte{0, 0, 0, 2, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 'j', ' ', '_', 0, 0, 0, 0, 1, 'i', 0}}, false}, + } + for _, tt_ := range tests { + t.Run(tt_.name, func(t *testing.T) { + tt := tt_ + t.Parallel() + m := &serverSession{} + milterVersion := tt.fields.milterVersion + if milterVersion == 0 { + milterVersion = MaxServerProtocolVersion + } + gotR, err := m.negotiate(tt.msg, milterVersion, tt.fields.milterActions, tt.fields.milterProtocol, tt.fields.callback, tt.fields.macroRequests, 0) + if (err != nil) != tt.wantErr { + t.Errorf("Process() error = %v, wantErr %v", err, tt.wantErr) + return + } + var got *wire.Message + if gotR != nil { + got = gotR.Response() + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Process() got = %v, want %v", got, tt.want) + } + }) + } +} + +func Test_milterSession_Process(t *testing.T) { + type fields struct { + actions OptAction + protocol OptProtocol + backend Milter + check func(*testing.T, *serverSession) + } + cont := &wire.Message{wire.Code(wire.ActContinue), nil} + + tests := []struct { + name string + fields fields + msg *wire.Message + want *wire.Message + wantErr bool + }{ + {"abort", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.cleanupCalled != 0 { + t.Errorf("Cleanup() called %d times", p.cleanupCalled) + } + if !p.abortCalled { + t.Errorf("Abort() not called") + } + }, + }, &wire.Message{wire.CodeAbort, nil}, nil, false}, + {"quit-new-conn", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + if s.backend.(*processTestMilter).cleanupCalled != 1 { + t.Fatalf("Cleanup() called %d times", s.backend.(*processTestMilter).cleanupCalled) + } + }, + }, &wire.Message{wire.CodeQuitNewConn, nil}, nil, false}, + {"quit", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + if s.backend.(*processTestMilter).cleanupCalled != 1 { + t.Fatalf("Cleanup() called %d times", s.backend.(*processTestMilter).cleanupCalled) + } + }, + }, &wire.Message{wire.CodeQuit, nil}, nil, true}, + {"unknown", fields{ + backend: &processTestMilter{}, + }, &wire.Message{wire.Code('@'), nil}, nil, true}, + {"conn err 1", fields{backend: &processTestMilter{}}, &wire.Message{wire.CodeConn, nil}, nil, true}, + {"conn unknown protocol", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.family != "unknown" { + t.Errorf("expected unknown, got %q", p.family) + } + if p.addr != "" { + t.Errorf("expected \"\", got %q", p.addr) + } + if p.port != 0 { + t.Errorf("expected 0, got %v", p.port) + } + if p.host != "" { + t.Errorf("expected \"\", got %q", p.host) + } + }, + }, &wire.Message{wire.CodeConn, []byte{0, 'U'}}, cont, false}, + {"conn unix protocol", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.family != "unix" { + t.Errorf("expected unix, got %q", p.family) + } + if p.addr != "/run" { + t.Errorf("expected /run, got %q", p.addr) + } + if p.port != 0 { + t.Errorf("expected 0, got %v", p.port) + } + if p.host != "h" { + t.Errorf("expected \"h\", got %q", p.host) + } + }, + }, &wire.Message{wire.CodeConn, []byte{'h', 0, 'L', 0, 0, '/', 'r', 'u', 'n', 0}}, cont, false}, + {"conn tcp4 protocol", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.family != "tcp4" { + t.Errorf("expected tcp4, got %q", p.family) + } + if p.addr != "127.0.0.12" { + t.Errorf("expected 127.0.0.12, got %q", p.addr) + } + if p.port != 2555 { + t.Errorf("expected 2555, got %v", p.port) + } + if p.host != "h" { + t.Errorf("expected \"h\", got %q", p.host) + } + }, + }, &wire.Message{wire.CodeConn, []byte{'h', 0, '4', 9, 251, '1', '2', '7', '.', '0', '.', '0', '.', '1', '2', 0}}, cont, false}, + {"conn tcp4 protocol err", fields{ + backend: &processTestMilter{}, + }, &wire.Message{wire.CodeConn, []byte{'h', 0, '4', 9, 251, '6', '6', '6', '.', '0', '.', '0', '.', '1', '2', 0}}, nil, true}, + {"conn tcp6 protocol", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.family != "tcp6" { + t.Errorf("expected tcp4, got %q", p.family) + } + if p.addr != "::" { + t.Errorf("expected ::, got %q", p.addr) + } + if p.port != 2555 { + t.Errorf("expected 2555, got %v", p.port) + } + if p.host != "h" { + t.Errorf("expected \"h\", got %q", p.host) + } + }, + }, &wire.Message{wire.CodeConn, []byte{'h', 0, '6', 9, 251, ':', ':', 0}}, cont, false}, + {"conn tcp6 protocol 2", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.family != "tcp6" { + t.Errorf("expected tcp4, got %q", p.family) + } + if p.addr != "::" { + t.Errorf("expected ::, got %q", p.addr) + } + if p.port != 2555 { + t.Errorf("expected 2555, got %v", p.port) + } + if p.host != "h" { + t.Errorf("expected \"h\", got %q", p.host) + } + }, + }, &wire.Message{wire.CodeConn, []byte{'h', 0, '6', 9, 251, '[', ':', ':', ']', 0}}, cont, false}, + {"conn tcp6 protocol err", fields{ + backend: &processTestMilter{}, + }, &wire.Message{wire.CodeConn, []byte{'h', 0, '6', 9, 251, '[', '@', ']', 0}}, nil, true}, + {"conn tcp6 protocol err 2", fields{ + backend: &processTestMilter{}, + }, &wire.Message{wire.CodeConn, []byte{'h', 0, '6', 9}}, nil, true}, + {"conn bogus protocol err", fields{ + backend: &processTestMilter{}, + }, &wire.Message{wire.CodeConn, []byte{'h', 0, '+', 9, 251, '[', ':', ':', ']', 0}}, nil, true}, + {"helo", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.name != "h" { + t.Errorf("expected h, got %q", p.name) + } + }, + }, &wire.Message{wire.CodeHelo, []byte{'h', 0}}, cont, false}, + {"helo err", fields{ + backend: &processTestMilter{}, + }, &wire.Message{wire.CodeHelo, []byte{}}, nil, true}, + {"mail", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.from != "r" { + t.Errorf("expected r, got %q", p.from) + } + if p.fromEsmtp != "" { + t.Errorf("expected \"\", got %q", p.fromEsmtp) + } + }, + }, &wire.Message{wire.CodeMail, []byte{'<', 'r', '>', 0}}, cont, false}, + {"mail esmtp", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.from != "r" { + t.Errorf("expected r, got %q", p.from) + } + if p.fromEsmtp != "A=B" { + t.Errorf("expected A=B, got %q", p.fromEsmtp) + } + }, + }, &wire.Message{wire.CodeMail, []byte{'<', 'r', '>', 0, 'A', '=', 'B', 0}}, cont, false}, + {"mail err", fields{ + backend: &processTestMilter{}, + }, &wire.Message{wire.CodeMail, []byte{}}, nil, true}, + {"rcpt", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.rcptTo != "r" { + t.Errorf("expected r, got %q", p.rcptTo) + } + if p.rcptEsmtp != "" { + t.Errorf("expected \"\", got %q", p.rcptEsmtp) + } + }, + }, &wire.Message{wire.CodeRcpt, []byte{'<', 'r', '>', 0}}, cont, false}, + {"rcpt esmtp", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.rcptTo != "r" { + t.Errorf("expected r, got %q", p.rcptTo) + } + if p.rcptEsmtp != "A=B" { + t.Errorf("expected A=B, got %q", p.rcptEsmtp) + } + }, + }, &wire.Message{wire.CodeRcpt, []byte{'<', 'r', '>', 0, 'A', '=', 'B', 0}}, cont, false}, + {"rcpt err", fields{ + backend: &processTestMilter{}, + }, &wire.Message{wire.CodeRcpt, []byte{}}, nil, true}, + {"data", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if !p.dataCalled { + t.Errorf("expected dataCalled true") + } + }, + }, &wire.Message{wire.CodeData, nil}, cont, false}, + {"header", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.hdrName != "To" { + t.Errorf("expected To, got %q", p.hdrName) + } + if p.hdrValue != "<>" { + t.Errorf("expected <>, got %q", p.hdrName) + } + }, + }, &wire.Message{wire.CodeHeader, []byte{'T', 'o', 0, '<', '>', 0}}, cont, false}, + {"header err 1", fields{ + backend: &processTestMilter{}, + }, &wire.Message{wire.CodeHeader, []byte{'T', 'o', 0}}, nil, true}, + {"header err 2", fields{ + backend: &processTestMilter{}, + }, &wire.Message{wire.CodeHeader, []byte{}}, nil, true}, + {"eoh", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if !p.headersCalled { + t.Errorf("Headers() not called") + } + }, + }, &wire.Message{wire.CodeEOH, nil}, cont, false}, + {"body empty", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.chunk == nil || len(p.chunk) > 0 { + t.Errorf("expected \"\", got %q", p.chunk) + } + }, + }, &wire.Message{wire.CodeBody, []byte{}}, cont, false}, + {"body", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if !bytes.Equal(p.chunk, []byte("abc")) { + t.Errorf("expected \"abc\", got %q", p.chunk) + } + }, + }, &wire.Message{wire.CodeBody, []byte{'a', 'b', 'c'}}, cont, false}, + {"end", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if !p.eomCalled { + t.Errorf("EndOfMessage() not called") + } + }, + }, &wire.Message{wire.CodeEOB, []byte{'a', 'b', 'c'}}, &wire.Message{wire.Code(wire.ActAccept), nil}, false}, + {"unknown", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + p := s.backend.(*processTestMilter) + if p.cmd != "abc" { + t.Errorf("expected abc, got %q", p.cmd) + } + }, + }, &wire.Message{wire.CodeUnknown, []byte{'a', 'b', 'c', 0}}, cont, false}, + {"macro 1", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + if s.macros.byStages[StageConnect] != nil { + t.Errorf("should be ") + } + }, + }, &wire.Message{wire.CodeMacro, []byte{byte(wire.CodeConn)}}, nil, false}, + {"macro 2", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + expect := map[MacroName]string{ + MacroMTAFullyQualifiedDomainName: "1", + MacroQueueId: "2", + } + if !reflect.DeepEqual(expect, s.macros.byStages[StageConnect]) { + t.Errorf("expect %+v, got %+v", expect, s.macros.byStages[StageConnect]) + } + }, + }, &wire.Message{wire.CodeMacro, []byte{byte(wire.CodeConn), 'j', 0, '1', 0, 'i', 0, '2', 0}}, nil, false}, + {"macro 3", fields{ + backend: &processTestMilter{}, + check: func(t *testing.T, s *serverSession) { + expect := map[MacroName]string{ + MacroMTAFullyQualifiedDomainName: "1", + MacroQueueId: "", + } + if !reflect.DeepEqual(expect, s.macros.byStages[StageConnect]) { + t.Errorf("expect %+v, got %+v", expect, s.macros.byStages[StageConnect]) + } + }, + }, &wire.Message{wire.CodeMacro, []byte{byte(wire.CodeConn), 'j', 0, '1', 0, 'i', 0}}, nil, false}, + {"macro err", fields{ + backend: &processTestMilter{}, + }, &wire.Message{wire.CodeMacro, []byte{}}, nil, true}, + } + for _, tt_ := range tests { + t.Run(tt_.name, func(t *testing.T) { + tt := tt_ + t.Parallel() + s := NewServer(WithMilter(func() Milter { + return tt.fields.backend + })) + m := &serverSession{ + server: s, + version: MaxServerProtocolVersion, + actions: tt.fields.actions, + protocol: tt.fields.protocol, + macros: newMacroStages(), + backend: tt.fields.backend, + } + gotR, err := m.Process(tt.msg) + if (err != nil) != tt.wantErr { + t.Errorf("Process() error = %v, wantErr %v", err, tt.wantErr) + return + } + var got *wire.Message + if gotR != nil { + got = gotR.Response() + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("Process() got = %v, want %v", got, tt.want) + } + if tt.fields.check != nil { + tt.fields.check(t, m) + } + }) + } +}