Skip to content

Commit

Permalink
Merge pull request #353 from ibice/limit-read-message-size
Browse files Browse the repository at this point in the history
Limit read message size
  • Loading branch information
mreiferson authored May 30, 2023
2 parents 0e8d7a7 + dc8315d commit c647fa6
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 9 deletions.
3 changes: 3 additions & 0 deletions config.go
Original file line number Diff line number Diff line change
Expand Up @@ -177,6 +177,9 @@ type Config struct {
// The server-side message timeout for messages delivered to this client
MsgTimeout time.Duration `opt:"msg_timeout" min:"0"`

// Maximum size of a single message in bytes (0 means no limit)
MaxMsgSize int32 `opt:"max_msg_size" min:"0" default:"0"`

// Secret for nsqd authentication (requires nsqd 0.2.29+)
AuthSecret string `opt:"auth_secret"`
// Use AuthSecret as 'Authorization: Bearer {AuthSecret}' on lookupd queries
Expand Down
12 changes: 6 additions & 6 deletions conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,7 +355,7 @@ func (c *Conn) identify() (*IdentifyResponse, error) {
return nil, ErrIdentify{err.Error()}
}

frameType, data, err := ReadUnpackedResponse(c)
frameType, data, err := ReadUnpackedResponse(c, c.config.MaxMsgSize)
if err != nil {
return nil, ErrIdentify{err.Error()}
}
Expand Down Expand Up @@ -434,7 +434,7 @@ func (c *Conn) upgradeTLS(tlsConf *tls.Config) error {
}
c.r = c.tlsConn
c.w = c.tlsConn
frameType, data, err := ReadUnpackedResponse(c)
frameType, data, err := ReadUnpackedResponse(c, c.config.MaxMsgSize)
if err != nil {
return err
}
Expand All @@ -452,7 +452,7 @@ func (c *Conn) upgradeDeflate(level int) error {
fw, _ := flate.NewWriter(conn, level)
c.r = flate.NewReader(conn)
c.w = fw
frameType, data, err := ReadUnpackedResponse(c)
frameType, data, err := ReadUnpackedResponse(c, c.config.MaxMsgSize)
if err != nil {
return err
}
Expand All @@ -469,7 +469,7 @@ func (c *Conn) upgradeSnappy() error {
}
c.r = snappy.NewReader(conn)
c.w = snappy.NewWriter(conn)
frameType, data, err := ReadUnpackedResponse(c)
frameType, data, err := ReadUnpackedResponse(c, c.config.MaxMsgSize)
if err != nil {
return err
}
Expand All @@ -490,7 +490,7 @@ func (c *Conn) auth(secret string) error {
return err
}

frameType, data, err := ReadUnpackedResponse(c)
frameType, data, err := ReadUnpackedResponse(c, c.config.MaxMsgSize)
if err != nil {
return err
}
Expand Down Expand Up @@ -518,7 +518,7 @@ func (c *Conn) readLoop() {
goto exit
}

frameType, data, err := ReadUnpackedResponse(c)
frameType, data, err := ReadUnpackedResponse(c, c.config.MaxMsgSize)
if err != nil {
if err == io.EOF && atomic.LoadInt32(&c.closeFlag) == 1 {
goto exit
Expand Down
23 changes: 23 additions & 0 deletions producer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"os"
"runtime"
"strconv"
"strings"
"sync"
"sync/atomic"
"testing"
Expand Down Expand Up @@ -255,6 +256,28 @@ func TestProducerHeartbeat(t *testing.T) {
readMessages(topicName, t, msgCount+1)
}

func TestProducerHTTPConnectionFails(t *testing.T) {
config := NewConfig()
laddr := "127.0.0.1"

config.LocalAddr, _ = net.ResolveTCPAddr("tcp", laddr+":0")
config.MaxMsgSize = 1048576

w, _ := NewProducer("127.0.0.1:4151", config)
w.SetLogger(nullLogger, LogLevelInfo)

err := w.Publish("write_test", []byte("test"))
if err == nil {
t.Fatal("should fail connecting to HTTP endpoint", err)
}

if !strings.Contains(err.Error(), "unexpected HTTP response") {
t.Fatalf("should detect unexpected HTTP response, but got err: %s", err)
}

w.Stop()
}

func readMessages(topicName string, t *testing.T, msgCount int) {
config := NewConfig()
config.DefaultRequeueDelay = 0
Expand Down
17 changes: 14 additions & 3 deletions protocol.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@ const (
FrameTypeMessage int32 = 2
)

// Used to detect if an unexpected HTTP response is read
const httpResponseMsgSize = 1213486160

var validTopicChannelNameRegex = regexp.MustCompile(`^[\.a-zA-Z0-9_-]+(#ephemeral)?$`)

// IsValidTopicName checks a topic name for correctness
Expand Down Expand Up @@ -48,7 +51,7 @@ func isValidName(name string) bool {
// | 4-byte || N-byte
// ------------------------...
// size data
func ReadResponse(r io.Reader) ([]byte, error) {
func ReadResponse(r io.Reader, maxMsgSize int32) ([]byte, error) {
var msgSize int32

// message size
Expand All @@ -60,6 +63,14 @@ func ReadResponse(r io.Reader) ([]byte, error) {
if msgSize < 0 {
return nil, fmt.Errorf("response msg size is negative: %v", msgSize)
}

if maxMsgSize > 0 && msgSize > maxMsgSize {
if msgSize == httpResponseMsgSize {
return nil, fmt.Errorf("unexpected HTTP response, a nsqd TCP endpoint is required")
}
return nil, fmt.Errorf("response msg size %v exceeds configured maximum (%v)", msgSize, maxMsgSize)
}

// message binary data
buf := make([]byte, msgSize)
_, err = io.ReadFull(r, buf)
Expand Down Expand Up @@ -91,8 +102,8 @@ func UnpackResponse(response []byte) (int32, []byte, error) {
// ReadUnpackedResponse reads and parses data from the underlying
// TCP connection according to the NSQ TCP protocol spec and
// returns the frameType, data or error
func ReadUnpackedResponse(r io.Reader) (int32, []byte, error) {
resp, err := ReadResponse(r)
func ReadUnpackedResponse(r io.Reader, maxMsgSize int32) (int32, []byte, error) {
resp, err := ReadResponse(r, maxMsgSize)
if err != nil {
return -1, nil, err
}
Expand Down

0 comments on commit c647fa6

Please sign in to comment.