diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index 9d397d0..b1e8a08 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -24,7 +24,7 @@ jobs: - name: Set up Go uses: actions/setup-go@v3 with: - go-version: 1.18 + go-version: 1.16 - uses: actions/cache@v3 with: @@ -43,4 +43,4 @@ jobs: # Exit with 1 when it find at least one finding. fail_on_error: true # Set staticcheck flags - staticcheck_flags: -checks=inherit,-SA1029 + staticcheck_flags: -checks=inherit,-SA1029,-SA6002 diff --git a/.gitignore b/.gitignore index 62c8935..22e98cf 100644 --- a/.gitignore +++ b/.gitignore @@ -1 +1,17 @@ -.idea/ \ No newline at end of file +# Binaries for programs and plugins +*.exe +*.exe~ +*.dll +*.so +*.dylib + +# Test binary, build with `go test -c` +*.test + +# Output of the go coverage tool, specifically when used with LiteIDE +*.out + +# goland +.idea +# vscode +.vscode diff --git a/.licenserc.yaml b/.licenserc.yaml new file mode 100644 index 0000000..cb59686 --- /dev/null +++ b/.licenserc.yaml @@ -0,0 +1,17 @@ +header: + license: + spdx-id: Apache-2.0 + copyright-owner: CloudWeGo Authors + content: | + // Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. + // Use of this source code is governed by a BSD-style + // license that can be found in the LICENSE file. + // + // This file may have been modified by CloudWeGo authors. All CloudWeGo + // Modifications are Copyright 2022 CloudWeGo Authors. + + paths: + - '**/*.go' + - '**/*.s' + + comment: on-failure diff --git a/AUTHORS b/AUTHORS new file mode 100644 index 0000000..1931f40 --- /dev/null +++ b/AUTHORS @@ -0,0 +1,9 @@ +# This is the official list of Gorilla WebSocket authors for copyright +# purposes. +# +# Please keep the list sorted. + +Gary Burd +Google LLC (https://opensource.google.com/) +Joachim Bauch + diff --git a/README.md b/README.md new file mode 100644 index 0000000..86e6404 --- /dev/null +++ b/README.md @@ -0,0 +1,57 @@ +# Hertz-WebSocket(This is a community driven project) + + +This repo is forked from [Gorilla WebSocket](https://github.com/gorilla/websocket/) and adapted to Hertz. + +### How to use +```go +package main + +import ( + "context" + "log" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/server" + "github.com/hertz-contrib/websocket" +) + +var upgrader = websocket.HertzUpgrader{} // use default options + +func echo(_ context.Context, c *app.RequestContext) { + err := upgrader.Upgrade(c, func(conn *websocket.Conn) { + for { + mt, message, err := conn.ReadMessage() + if err != nil { + log.Println("read:", err) + break + } + log.Printf("recv: %s", message) + err = conn.WriteMessage(mt, message) + if err != nil { + log.Println("write:", err) + break + } + } + }) + if err != nil { + log.Print("upgrade:", err) + return + } +} + + +func main() { + h := server.Default(server.WithHostPorts(addr)) + // https://github.com/cloudwego/hertz/issues/121 + h.NoHijackConnPool = true + h.GET("/echo", echo) + h.Spin() +} + +``` + +### More info + +See [examples](examples/) + diff --git a/_typos.toml b/_typos.toml new file mode 100644 index 0000000..03db66d --- /dev/null +++ b/_typos.toml @@ -0,0 +1,23 @@ +# Typo check: https://github.com/crate-ci/typos + +[files] +extend-exclude = ["go.sum"] + +[default.extend-identifiers] +# *sigh* this just isn't worth the cost of fixing +ConnTLSer = "ConnTLSer" +flate = "flate" +TestCompressFlateSerial = "TestCompressFlateSerial" +testCompressFlate = "testCompressFlate" +TestCompressFlateConcurrent = "TestCompressFlateConcurrent" +trUe = "trUe" +OPTIO = "OPTIO" +contant = "contant" +referer = "referer" +HeaderReferer = "HeaderReferer" +expectedReferer = "expectedReferer" +Referer = "Referer" +flateWriterPools = "flateWriterPools" +flateReaderPool = "flateReaderPool" +flateWriteWrapper = "flateWriteWrapper" +flateReadWrapper = "flateReadWrapper" diff --git a/compression.go b/compression.go new file mode 100644 index 0000000..2c3c52f --- /dev/null +++ b/compression.go @@ -0,0 +1,151 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package websocket + +import ( + "compress/flate" + "errors" + "io" + "strings" + "sync" +) + +const ( + minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 + maxCompressionLevel = flate.BestCompression + defaultCompressionLevel = 1 +) + +var ( + flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool + flateReaderPool = sync.Pool{New: func() interface{} { + return flate.NewReader(nil) + }} +) + +func decompressNoContextTakeover(r io.Reader) io.ReadCloser { + const tail = + // Add four bytes as specified in RFC + "\x00\x00\xff\xff" + + // Add final block to squelch unexpected EOF error from flate reader. + "\x01\x00\x00\xff\xff" + + fr, _ := flateReaderPool.Get().(io.ReadCloser) + fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) + return &flateReadWrapper{fr} +} + +func isValidCompressionLevel(level int) bool { + return minCompressionLevel <= level && level <= maxCompressionLevel +} + +func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { + p := &flateWriterPools[level-minCompressionLevel] + tw := &truncWriter{w: w} + fw, _ := p.Get().(*flate.Writer) + if fw == nil { + fw, _ = flate.NewWriter(tw, level) + } else { + fw.Reset(tw) + } + return &flateWriteWrapper{fw: fw, tw: tw, p: p} +} + +// truncWriter is an io.Writer that writes all but the last four bytes of the +// stream to another io.Writer. +type truncWriter struct { + w io.WriteCloser + n int + p [4]byte +} + +func (w *truncWriter) Write(p []byte) (int, error) { + n := 0 + + // fill buffer first for simplicity. + if w.n < len(w.p) { + n = copy(w.p[w.n:], p) + p = p[n:] + w.n += n + if len(p) == 0 { + return n, nil + } + } + + m := len(p) + if m > len(w.p) { + m = len(w.p) + } + + if nn, err := w.w.Write(w.p[:m]); err != nil { + return n + nn, err + } + + copy(w.p[:], w.p[m:]) + copy(w.p[len(w.p)-m:], p[len(p)-m:]) + nn, err := w.w.Write(p[:len(p)-m]) + return n + nn, err +} + +type flateWriteWrapper struct { + fw *flate.Writer + tw *truncWriter + p *sync.Pool +} + +func (w *flateWriteWrapper) Write(p []byte) (int, error) { + if w.fw == nil { + return 0, errWriteClosed + } + return w.fw.Write(p) +} + +func (w *flateWriteWrapper) Close() error { + if w.fw == nil { + return errWriteClosed + } + err1 := w.fw.Flush() + w.p.Put(w.fw) + w.fw = nil + if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { + return errors.New("websocket: internal error, unexpected bytes at end of flate stream") + } + err2 := w.tw.w.Close() + if err1 != nil { + return err1 + } + return err2 +} + +type flateReadWrapper struct { + fr io.ReadCloser +} + +func (r *flateReadWrapper) Read(p []byte) (int, error) { + if r.fr == nil { + return 0, io.ErrClosedPipe + } + n, err := r.fr.Read(p) + if err == io.EOF { + // Preemptively place the reader back in the pool. This helps with + // scenarios where the application does not call NextReader() soon after + // this final read. + r.Close() + } + return n, err +} + +func (r *flateReadWrapper) Close() error { + if r.fr == nil { + return io.ErrClosedPipe + } + err := r.fr.Close() + flateReaderPool.Put(r.fr) + r.fr = nil + return err +} diff --git a/compression_test.go b/compression_test.go new file mode 100644 index 0000000..8b71841 --- /dev/null +++ b/compression_test.go @@ -0,0 +1,87 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package websocket + +import ( + "bytes" + "fmt" + "io" + "io/ioutil" + "testing" +) + +type nopCloser struct{ io.Writer } + +func (nopCloser) Close() error { return nil } + +func TestTruncWriter(t *testing.T) { + const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321" + for n := 1; n <= 10; n++ { + var b bytes.Buffer + w := &truncWriter{w: nopCloser{&b}} + p := []byte(data) + for len(p) > 0 { + m := len(p) + if m > n { + m = n + } + w.Write(p[:m]) + p = p[m:] + } + if b.String() != data[:len(data)-len(w.p)] { + t.Errorf("%d: %q", n, b.String()) + } + } +} + +func textMessages(num int) [][]byte { + messages := make([][]byte, num) + for i := 0; i < num; i++ { + msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i) + messages[i] = []byte(msg) + } + return messages +} + +func BenchmarkWriteNoCompression(b *testing.B) { + w := ioutil.Discard + c := newTestConn(nil, w, false) + messages := textMessages(100) + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.WriteMessage(TextMessage, messages[i%len(messages)]) + } + b.ReportAllocs() +} + +func BenchmarkWriteWithCompression(b *testing.B) { + w := ioutil.Discard + c := newTestConn(nil, w, false) + messages := textMessages(100) + c.enableWriteCompression = true + c.newCompressionWriter = compressNoContextTakeover + b.ResetTimer() + for i := 0; i < b.N; i++ { + c.WriteMessage(TextMessage, messages[i%len(messages)]) + } + b.ReportAllocs() +} + +func TestValidCompressionLevel(t *testing.T) { + c := newTestConn(nil, nil, false) + for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { + if err := c.SetCompressionLevel(level); err == nil { + t.Errorf("no error for level %d", level) + } + } + for _, level := range []int{minCompressionLevel, maxCompressionLevel} { + if err := c.SetCompressionLevel(level); err != nil { + t.Errorf("error for level %d", level) + } + } +} diff --git a/conn.go b/conn.go new file mode 100644 index 0000000..2eb8874 --- /dev/null +++ b/conn.go @@ -0,0 +1,1238 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package websocket + +import ( + "bufio" + "encoding/binary" + "errors" + "io" + "io/ioutil" + "math/rand" + "net" + "strconv" + "strings" + "sync" + "time" + "unicode/utf8" +) + +const ( + // Frame header byte 0 bits from Section 5.2 of RFC 6455 + finalBit = 1 << 7 + rsv1Bit = 1 << 6 + rsv2Bit = 1 << 5 + rsv3Bit = 1 << 4 + + // Frame header byte 1 bits from Section 5.2 of RFC 6455 + maskBit = 1 << 7 + + maxFrameHeaderSize = 2 + 8 + 4 // Fixed header + length + mask + maxControlFramePayloadSize = 125 + + writeWait = time.Second + + defaultReadBufferSize = 4096 + defaultWriteBufferSize = 4096 + + continuationFrame = 0 + noFrame = -1 +) + +// Close codes defined in RFC 6455, section 11.7. +const ( + CloseNormalClosure = 1000 + CloseGoingAway = 1001 + CloseProtocolError = 1002 + CloseUnsupportedData = 1003 + CloseNoStatusReceived = 1005 + CloseAbnormalClosure = 1006 + CloseInvalidFramePayloadData = 1007 + ClosePolicyViolation = 1008 + CloseMessageTooBig = 1009 + CloseMandatoryExtension = 1010 + CloseInternalServerErr = 1011 + CloseServiceRestart = 1012 + CloseTryAgainLater = 1013 + CloseTLSHandshake = 1015 +) + +// The message types are defined in RFC 6455, section 11.8. +const ( + // TextMessage denotes a text data message. The text message payload is + // interpreted as UTF-8 encoded text data. + TextMessage = 1 + + // BinaryMessage denotes a binary data message. + BinaryMessage = 2 + + // CloseMessage denotes a close control message. The optional message + // payload contains a numeric code and text. Use the FormatCloseMessage + // function to format a close message payload. + CloseMessage = 8 + + // PingMessage denotes a ping control message. The optional message payload + // is UTF-8 encoded text. + PingMessage = 9 + + // PongMessage denotes a pong control message. The optional message payload + // is UTF-8 encoded text. + PongMessage = 10 +) + +// ErrCloseSent is returned when the application writes a message to the +// connection after sending a close message. +var ErrCloseSent = errors.New("websocket: close sent") + +// ErrReadLimit is returned when reading a message that is larger than the +// read limit set for the connection. +var ErrReadLimit = errors.New("websocket: read limit exceeded") + +// netError satisfies the net Error interface. +type netError struct { + msg string + temporary bool + timeout bool +} + +func (e *netError) Error() string { return e.msg } +func (e *netError) Temporary() bool { return e.temporary } +func (e *netError) Timeout() bool { return e.timeout } + +// CloseError represents a close message. +type CloseError struct { + // Code is defined in RFC 6455, section 11.7. + Code int + + // Text is the optional text payload. + Text string +} + +func (e *CloseError) Error() string { + s := []byte("websocket: close ") + s = strconv.AppendInt(s, int64(e.Code), 10) + switch e.Code { + case CloseNormalClosure: + s = append(s, " (normal)"...) + case CloseGoingAway: + s = append(s, " (going away)"...) + case CloseProtocolError: + s = append(s, " (protocol error)"...) + case CloseUnsupportedData: + s = append(s, " (unsupported data)"...) + case CloseNoStatusReceived: + s = append(s, " (no status)"...) + case CloseAbnormalClosure: + s = append(s, " (abnormal closure)"...) + case CloseInvalidFramePayloadData: + s = append(s, " (invalid payload data)"...) + case ClosePolicyViolation: + s = append(s, " (policy violation)"...) + case CloseMessageTooBig: + s = append(s, " (message too big)"...) + case CloseMandatoryExtension: + s = append(s, " (mandatory extension missing)"...) + case CloseInternalServerErr: + s = append(s, " (internal server error)"...) + case CloseTLSHandshake: + s = append(s, " (TLS handshake error)"...) + } + if e.Text != "" { + s = append(s, ": "...) + s = append(s, e.Text...) + } + return string(s) +} + +// IsCloseError returns boolean indicating whether the error is a *CloseError +// with one of the specified codes. +func IsCloseError(err error, codes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range codes { + if e.Code == code { + return true + } + } + } + return false +} + +// IsUnexpectedCloseError returns boolean indicating whether the error is a +// *CloseError with a code not in the list of expected codes. +func IsUnexpectedCloseError(err error, expectedCodes ...int) bool { + if e, ok := err.(*CloseError); ok { + for _, code := range expectedCodes { + if e.Code == code { + return false + } + } + return true + } + return false +} + +var ( + errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true} + errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} + errBadWriteOpCode = errors.New("websocket: bad write message type") + errWriteClosed = errors.New("websocket: write closed") + errInvalidControlFrame = errors.New("websocket: invalid control frame") +) + +func newMaskKey() [4]byte { + n := rand.Uint32() + return [4]byte{byte(n), byte(n >> 8), byte(n >> 16), byte(n >> 24)} +} + +func hideTempErr(err error) error { + if e, ok := err.(net.Error); ok && e.Temporary() { + err = &netError{msg: e.Error(), timeout: e.Timeout()} + } + return err +} + +func isControl(frameType int) bool { + return frameType == CloseMessage || frameType == PingMessage || frameType == PongMessage +} + +func isData(frameType int) bool { + return frameType == TextMessage || frameType == BinaryMessage +} + +var validReceivedCloseCodes = map[int]bool{ + // see http://www.iana.org/assignments/websocket/websocket.xhtml#close-code-number + CloseNormalClosure: true, + CloseGoingAway: true, + CloseProtocolError: true, + CloseUnsupportedData: true, + CloseNoStatusReceived: false, + CloseAbnormalClosure: false, + CloseInvalidFramePayloadData: true, + ClosePolicyViolation: true, + CloseMessageTooBig: true, + CloseMandatoryExtension: true, + CloseInternalServerErr: true, + CloseServiceRestart: true, + CloseTryAgainLater: true, + CloseTLSHandshake: false, +} + +func isValidReceivedCloseCode(code int) bool { + return validReceivedCloseCodes[code] || (code >= 3000 && code <= 4999) +} + +// BufferPool represents a pool of buffers. The *sync.Pool type satisfies this +// interface. The type of the value stored in a pool is not specified. +type BufferPool interface { + // Get gets a value from the pool or returns nil if the pool is empty. + Get() interface{} + // Put adds a value to the pool. + Put(interface{}) +} + +// writePoolData is the type added to the write buffer pool. This wrapper is +// used to prevent applications from peeking at and depending on the values +// added to the pool. +type writePoolData struct{ buf []byte } + +// The Conn type represents a WebSocket connection. +type Conn struct { + conn net.Conn + isServer bool + subprotocol string + + // Write fields + mu chan struct{} // used as mutex to protect write to conn + writeBuf []byte // frame is constructed in this buffer. + writePool BufferPool + writeBufSize int + writeDeadline time.Time + writer io.WriteCloser // the current writer returned to the application + isWriting bool // for best-effort concurrent write detection + + writeErrMu sync.Mutex + writeErr error + + enableWriteCompression bool + compressionLevel int + newCompressionWriter func(io.WriteCloser, int) io.WriteCloser + + // Read fields + reader io.ReadCloser // the current reader returned to the application + readErr error + br *bufio.Reader + // bytes remaining in current frame. + // set setReadRemaining to safely update this value and prevent overflow + readRemaining int64 + readFinal bool // true the current message has more frames. + readLength int64 // Message size. + readLimit int64 // Maximum message size. + readMaskPos int + readMaskKey [4]byte + handlePong func(string) error + handlePing func(string) error + handleClose func(int, string) error + readErrCount int + messageReader *messageReader // the current low-level reader + + readDecompress bool // whether last read frame had RSV1 set + newDecompressionReader func(io.Reader) io.ReadCloser +} + +func newConn(conn net.Conn, isServer bool, readBufferSize, writeBufferSize int, writeBufferPool BufferPool, br *bufio.Reader, writeBuf []byte) *Conn { + if br == nil { + if readBufferSize == 0 { + readBufferSize = defaultReadBufferSize + } else if readBufferSize < maxControlFramePayloadSize { + // must be large enough for control frame + readBufferSize = maxControlFramePayloadSize + } + br = bufio.NewReaderSize(conn, readBufferSize) + } + + if writeBufferSize <= 0 { + writeBufferSize = defaultWriteBufferSize + } + writeBufferSize += maxFrameHeaderSize + + if writeBuf == nil && writeBufferPool == nil { + writeBuf = make([]byte, writeBufferSize) + } + + mu := make(chan struct{}, 1) + mu <- struct{}{} + c := &Conn{ + isServer: isServer, + br: br, + conn: conn, + mu: mu, + readFinal: true, + writeBuf: writeBuf, + writePool: writeBufferPool, + writeBufSize: writeBufferSize, + enableWriteCompression: true, + compressionLevel: defaultCompressionLevel, + } + c.SetCloseHandler(nil) + c.SetPingHandler(nil) + c.SetPongHandler(nil) + return c +} + +// setReadRemaining tracks the number of bytes remaining on the connection. If n +// overflows, an ErrReadLimit is returned. +func (c *Conn) setReadRemaining(n int64) error { + if n < 0 { + return ErrReadLimit + } + + c.readRemaining = n + return nil +} + +// Subprotocol returns the negotiated protocol for the connection. +func (c *Conn) Subprotocol() string { + return c.subprotocol +} + +// Close closes the underlying network connection without sending or waiting +// for a close message. +func (c *Conn) Close() error { + return c.conn.Close() +} + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// Write methods + +func (c *Conn) writeFatal(err error) error { + err = hideTempErr(err) + c.writeErrMu.Lock() + if c.writeErr == nil { + c.writeErr = err + } + c.writeErrMu.Unlock() + return err +} + +func (c *Conn) read(n int) ([]byte, error) { + p, err := c.br.Peek(n) + if err == io.EOF { + err = errUnexpectedEOF + } + c.br.Discard(len(p)) + return p, err +} + +func (c *Conn) write(frameType int, deadline time.Time, buf0, buf1 []byte) error { + <-c.mu + defer func() { c.mu <- struct{}{} }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + c.conn.SetWriteDeadline(deadline) + if len(buf1) == 0 { + _, err = c.conn.Write(buf0) + } else { + err = c.writeBufs(buf0, buf1) + } + if err != nil { + return c.writeFatal(err) + } + if frameType == CloseMessage { + c.writeFatal(ErrCloseSent) + } + return nil +} + +func (c *Conn) writeBufs(bufs ...[]byte) error { + b := net.Buffers(bufs) + _, err := b.WriteTo(c.conn) + return err +} + +// WriteControl writes a control message with the given deadline. The allowed +// message types are CloseMessage, PingMessage and PongMessage. +func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) error { + if !isControl(messageType) { + return errBadWriteOpCode + } + if len(data) > maxControlFramePayloadSize { + return errInvalidControlFrame + } + + b0 := byte(messageType) | finalBit + b1 := byte(len(data)) + if !c.isServer { + b1 |= maskBit + } + + buf := make([]byte, 0, maxFrameHeaderSize+maxControlFramePayloadSize) + buf = append(buf, b0, b1) + + if c.isServer { + buf = append(buf, data...) + } else { + key := newMaskKey() + buf = append(buf, key[:]...) + buf = append(buf, data...) + maskBytes(key, 0, buf[6:]) + } + + d := 1000 * time.Hour + if !deadline.IsZero() { + d = time.Until(deadline) + if d < 0 { + return errWriteTimeout + } + } + + timer := time.NewTimer(d) + select { + case <-c.mu: + timer.Stop() + case <-timer.C: + return errWriteTimeout + } + defer func() { c.mu <- struct{}{} }() + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + c.conn.SetWriteDeadline(deadline) + _, err = c.conn.Write(buf) + if err != nil { + return c.writeFatal(err) + } + if messageType == CloseMessage { + c.writeFatal(ErrCloseSent) + } + return err +} + +// beginMessage prepares a connection and message writer for a new message. +func (c *Conn) beginMessage(mw *messageWriter, messageType int) error { + // Close previous writer if not already closed by the application. It's + // probably better to return an error in this situation, but we cannot + // change this without breaking existing applications. + if c.writer != nil { + c.writer.Close() + c.writer = nil + } + + if !isControl(messageType) && !isData(messageType) { + return errBadWriteOpCode + } + + c.writeErrMu.Lock() + err := c.writeErr + c.writeErrMu.Unlock() + if err != nil { + return err + } + + mw.c = c + mw.frameType = messageType + mw.pos = maxFrameHeaderSize + + if c.writeBuf == nil { + wpd, ok := c.writePool.Get().(writePoolData) + if ok { + c.writeBuf = wpd.buf + } else { + c.writeBuf = make([]byte, c.writeBufSize) + } + } + return nil +} + +// NextWriter returns a writer for the next message to send. The writer's Close +// method flushes the complete message to the network. +// +// There can be at most one open writer on a connection. NextWriter closes the +// previous writer if the application has not already done so. +// +// All message types (TextMessage, BinaryMessage, CloseMessage, PingMessage and +// PongMessage) are supported. +func (c *Conn) NextWriter(messageType int) (io.WriteCloser, error) { + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { + return nil, err + } + c.writer = &mw + if c.newCompressionWriter != nil && c.enableWriteCompression && isData(messageType) { + w := c.newCompressionWriter(c.writer, c.compressionLevel) + mw.compress = true + c.writer = w + } + return c.writer, nil +} + +type messageWriter struct { + c *Conn + compress bool // whether next call to flushFrame should set RSV1 + pos int // end of data in writeBuf. + frameType int // type of the current frame. + err error +} + +func (w *messageWriter) endMessage(err error) error { + if w.err != nil { + return err + } + c := w.c + w.err = err + c.writer = nil + if c.writePool != nil { + c.writePool.Put(writePoolData{buf: c.writeBuf}) + c.writeBuf = nil + } + return err +} + +// flushFrame writes buffered data and extra as a frame to the network. The +// final argument indicates that this is the last frame in the message. +func (w *messageWriter) flushFrame(final bool, extra []byte) error { + c := w.c + length := w.pos - maxFrameHeaderSize + len(extra) + + // Check for invalid control frames. + if isControl(w.frameType) && + (!final || length > maxControlFramePayloadSize) { + return w.endMessage(errInvalidControlFrame) + } + + b0 := byte(w.frameType) + if final { + b0 |= finalBit + } + if w.compress { + b0 |= rsv1Bit + } + w.compress = false + + b1 := byte(0) + if !c.isServer { + b1 |= maskBit + } + + // Assume that the frame starts at beginning of c.writeBuf. + framePos := 0 + if c.isServer { + // Adjust up if mask not included in the header. + framePos = 4 + } + + switch { + case length >= 65536: + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 127 + binary.BigEndian.PutUint64(c.writeBuf[framePos+2:], uint64(length)) + case length > 125: + framePos += 6 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | 126 + binary.BigEndian.PutUint16(c.writeBuf[framePos+2:], uint16(length)) + default: + framePos += 8 + c.writeBuf[framePos] = b0 + c.writeBuf[framePos+1] = b1 | byte(length) + } + + if !c.isServer { + key := newMaskKey() + copy(c.writeBuf[maxFrameHeaderSize-4:], key[:]) + maskBytes(key, 0, c.writeBuf[maxFrameHeaderSize:w.pos]) + if len(extra) > 0 { + return w.endMessage(c.writeFatal(errors.New("websocket: internal error, extra used in client mode"))) + } + } + + // Write the buffers to the connection with best-effort detection of + // concurrent writes. See the concurrency section in the package + // documentation for more info. + + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + + err := c.write(w.frameType, c.writeDeadline, c.writeBuf[framePos:w.pos], extra) + + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + + if err != nil { + return w.endMessage(err) + } + + if final { + w.endMessage(errWriteClosed) + return nil + } + + // Setup for next frame. + w.pos = maxFrameHeaderSize + w.frameType = continuationFrame + return nil +} + +func (w *messageWriter) ncopy(max int) (int, error) { + n := len(w.c.writeBuf) - w.pos + if n <= 0 { + if err := w.flushFrame(false, nil); err != nil { + return 0, err + } + n = len(w.c.writeBuf) - w.pos + } + if n > max { + n = max + } + return n, nil +} + +func (w *messageWriter) Write(p []byte) (int, error) { + if w.err != nil { + return 0, w.err + } + + if len(p) > 2*len(w.c.writeBuf) && w.c.isServer { + // Don't buffer large messages. + err := w.flushFrame(false, p) + if err != nil { + return 0, err + } + return len(p), nil + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) WriteString(p string) (int, error) { + if w.err != nil { + return 0, w.err + } + + nn := len(p) + for len(p) > 0 { + n, err := w.ncopy(len(p)) + if err != nil { + return 0, err + } + copy(w.c.writeBuf[w.pos:], p[:n]) + w.pos += n + p = p[n:] + } + return nn, nil +} + +func (w *messageWriter) ReadFrom(r io.Reader) (nn int64, err error) { + if w.err != nil { + return 0, w.err + } + for { + if w.pos == len(w.c.writeBuf) { + err = w.flushFrame(false, nil) + if err != nil { + break + } + } + var n int + n, err = r.Read(w.c.writeBuf[w.pos:]) + w.pos += n + nn += int64(n) + if err != nil { + if err == io.EOF { + err = nil + } + break + } + } + return nn, err +} + +func (w *messageWriter) Close() error { + if w.err != nil { + return w.err + } + return w.flushFrame(true, nil) +} + +// WritePreparedMessage writes prepared message into connection. +func (c *Conn) WritePreparedMessage(pm *PreparedMessage) error { + frameType, frameData, err := pm.frame(prepareKey{ + isServer: c.isServer, + compress: c.newCompressionWriter != nil && c.enableWriteCompression && isData(pm.messageType), + compressionLevel: c.compressionLevel, + }) + if err != nil { + return err + } + if c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = true + err = c.write(frameType, c.writeDeadline, frameData, nil) + if !c.isWriting { + panic("concurrent write to websocket connection") + } + c.isWriting = false + return err +} + +// WriteMessage is a helper method for getting a writer using NextWriter, +// writing the message and closing the writer. +func (c *Conn) WriteMessage(messageType int, data []byte) error { + if c.isServer && (c.newCompressionWriter == nil || !c.enableWriteCompression) { + // Fast path with no allocations and single frame. + + var mw messageWriter + if err := c.beginMessage(&mw, messageType); err != nil { + return err + } + n := copy(c.writeBuf[mw.pos:], data) + mw.pos += n + data = data[n:] + return mw.flushFrame(true, data) + } + + w, err := c.NextWriter(messageType) + if err != nil { + return err + } + if _, err = w.Write(data); err != nil { + return err + } + return w.Close() +} + +// SetWriteDeadline sets the write deadline on the underlying network +// connection. After a write has timed out, the websocket state is corrupt and +// all future writes will return an error. A zero value for t means writes will +// not time out. +func (c *Conn) SetWriteDeadline(t time.Time) error { + c.writeDeadline = t + return nil +} + +// Read methods + +func (c *Conn) advanceFrame() (int, error) { + // 1. Skip remainder of previous frame. + + if c.readRemaining > 0 { + if _, err := io.CopyN(ioutil.Discard, c.br, c.readRemaining); err != nil { + return noFrame, err + } + } + + // 2. Read and parse first two bytes of frame header. + // To aid debugging, collect and report all errors in the first two bytes + // of the header. + + var errors []string + + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + frameType := int(p[0] & 0xf) + final := p[0]&finalBit != 0 + rsv1 := p[0]&rsv1Bit != 0 + rsv2 := p[0]&rsv2Bit != 0 + rsv3 := p[0]&rsv3Bit != 0 + mask := p[1]&maskBit != 0 + c.setReadRemaining(int64(p[1] & 0x7f)) + + c.readDecompress = false + if rsv1 { + if c.newDecompressionReader != nil { + c.readDecompress = true + } else { + errors = append(errors, "RSV1 set") + } + } + + if rsv2 { + errors = append(errors, "RSV2 set") + } + + if rsv3 { + errors = append(errors, "RSV3 set") + } + + switch frameType { + case CloseMessage, PingMessage, PongMessage: + if c.readRemaining > maxControlFramePayloadSize { + errors = append(errors, "len > 125 for control") + } + if !final { + errors = append(errors, "FIN not set on control") + } + case TextMessage, BinaryMessage: + if !c.readFinal { + errors = append(errors, "data before FIN") + } + c.readFinal = final + case continuationFrame: + if c.readFinal { + errors = append(errors, "continuation after FIN") + } + c.readFinal = final + default: + errors = append(errors, "bad opcode "+strconv.Itoa(frameType)) + } + + if mask != c.isServer { + errors = append(errors, "bad MASK") + } + + if len(errors) > 0 { + return noFrame, c.handleProtocolError(strings.Join(errors, ", ")) + } + + // 3. Read and parse frame length as per + // https://tools.ietf.org/html/rfc6455#section-5.2 + // + // The length of the "Payload data", in bytes: if 0-125, that is the payload + // length. + // - If 126, the following 2 bytes interpreted as a 16-bit unsigned + // integer are the payload length. + // - If 127, the following 8 bytes interpreted as + // a 64-bit unsigned integer (the most significant bit MUST be 0) are the + // payload length. Multibyte length quantities are expressed in network byte + // order. + + switch c.readRemaining { + case 126: + p, err := c.read(2) + if err != nil { + return noFrame, err + } + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint16(p))); err != nil { + return noFrame, err + } + case 127: + p, err := c.read(8) + if err != nil { + return noFrame, err + } + + if err := c.setReadRemaining(int64(binary.BigEndian.Uint64(p))); err != nil { + return noFrame, err + } + } + + // 4. Handle frame masking. + + if mask { + c.readMaskPos = 0 + p, err := c.read(len(c.readMaskKey)) + if err != nil { + return noFrame, err + } + copy(c.readMaskKey[:], p) + } + + // 5. For text and binary messages, enforce read limit and return. + + if frameType == continuationFrame || frameType == TextMessage || frameType == BinaryMessage { + + c.readLength += c.readRemaining + // Don't allow readLength to overflow in the presence of a large readRemaining + // counter. + if c.readLength < 0 { + return noFrame, ErrReadLimit + } + + if c.readLimit > 0 && c.readLength > c.readLimit { + c.WriteControl(CloseMessage, FormatCloseMessage(CloseMessageTooBig, ""), time.Now().Add(writeWait)) + return noFrame, ErrReadLimit + } + + return frameType, nil + } + + // 6. Read control frame payload. + + var payload []byte + if c.readRemaining > 0 { + payload, err = c.read(int(c.readRemaining)) + c.setReadRemaining(0) + if err != nil { + return noFrame, err + } + if c.isServer { + maskBytes(c.readMaskKey, 0, payload) + } + } + + // 7. Process control frame payload. + + switch frameType { + case PongMessage: + if err := c.handlePong(string(payload)); err != nil { + return noFrame, err + } + case PingMessage: + if err := c.handlePing(string(payload)); err != nil { + return noFrame, err + } + case CloseMessage: + closeCode := CloseNoStatusReceived + closeText := "" + if len(payload) >= 2 { + closeCode = int(binary.BigEndian.Uint16(payload)) + if !isValidReceivedCloseCode(closeCode) { + return noFrame, c.handleProtocolError("bad close code " + strconv.Itoa(closeCode)) + } + closeText = string(payload[2:]) + if !utf8.ValidString(closeText) { + return noFrame, c.handleProtocolError("invalid utf8 payload in close frame") + } + } + if err := c.handleClose(closeCode, closeText); err != nil { + return noFrame, err + } + return noFrame, &CloseError{Code: closeCode, Text: closeText} + } + + return frameType, nil +} + +func (c *Conn) handleProtocolError(message string) error { + data := FormatCloseMessage(CloseProtocolError, message) + if len(data) > maxControlFramePayloadSize { + data = data[:maxControlFramePayloadSize] + } + c.WriteControl(CloseMessage, data, time.Now().Add(writeWait)) + return errors.New("websocket: " + message) +} + +// NextReader returns the next data message received from the peer. The +// returned messageType is either TextMessage or BinaryMessage. +// +// There can be at most one open reader on a connection. NextReader discards +// the previous message if the application has not already consumed it. +// +// Applications must break out of the application's read loop when this method +// returns a non-nil error value. Errors returned from this method are +// permanent. Once this method returns a non-nil error, all subsequent calls to +// this method return the same error. +func (c *Conn) NextReader() (messageType int, r io.Reader, err error) { + // Close previous reader, only relevant for decompression. + if c.reader != nil { + c.reader.Close() + c.reader = nil + } + + c.messageReader = nil + c.readLength = 0 + + for c.readErr == nil { + frameType, err := c.advanceFrame() + if err != nil { + c.readErr = hideTempErr(err) + break + } + + if frameType == TextMessage || frameType == BinaryMessage { + c.messageReader = &messageReader{c} + c.reader = c.messageReader + if c.readDecompress { + c.reader = c.newDecompressionReader(c.reader) + } + return frameType, c.reader, nil + } + } + + // Applications that do handle the error returned from this method spin in + // tight loop on connection failure. To help application developers detect + // this error, panic on repeated reads to the failed connection. + c.readErrCount++ + if c.readErrCount >= 1000 { + panic("repeated read on failed websocket connection") + } + + return noFrame, nil, c.readErr +} + +type messageReader struct{ c *Conn } + +func (r *messageReader) Read(b []byte) (int, error) { + c := r.c + if c.messageReader != r { + return 0, io.EOF + } + + for c.readErr == nil { + + if c.readRemaining > 0 { + if int64(len(b)) > c.readRemaining { + b = b[:c.readRemaining] + } + n, err := c.br.Read(b) + c.readErr = hideTempErr(err) + if c.isServer { + c.readMaskPos = maskBytes(c.readMaskKey, c.readMaskPos, b[:n]) + } + rem := c.readRemaining + rem -= int64(n) + c.setReadRemaining(rem) + if c.readRemaining > 0 && c.readErr == io.EOF { + c.readErr = errUnexpectedEOF + } + return n, c.readErr + } + + if c.readFinal { + c.messageReader = nil + return 0, io.EOF + } + + frameType, err := c.advanceFrame() + switch { + case err != nil: + c.readErr = hideTempErr(err) + case frameType == TextMessage || frameType == BinaryMessage: + c.readErr = errors.New("websocket: internal error, unexpected text or binary in Reader") + } + } + + err := c.readErr + if err == io.EOF && c.messageReader == r { + err = errUnexpectedEOF + } + return 0, err +} + +func (r *messageReader) Close() error { + return nil +} + +// ReadMessage is a helper method for getting a reader using NextReader and +// reading from that reader to a buffer. +func (c *Conn) ReadMessage() (messageType int, p []byte, err error) { + var r io.Reader + messageType, r, err = c.NextReader() + if err != nil { + return messageType, nil, err + } + p, err = ioutil.ReadAll(r) + return messageType, p, err +} + +// SetReadDeadline sets the read deadline on the underlying network connection. +// After a read has timed out, the websocket connection state is corrupt and +// all future reads will return an error. A zero value for t means reads will +// not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetReadLimit sets the maximum size in bytes for a message read from the peer. If a +// message exceeds the limit, the connection sends a close message to the peer +// and returns ErrReadLimit to the application. +func (c *Conn) SetReadLimit(limit int64) { + c.readLimit = limit +} + +// CloseHandler returns the current close handler +func (c *Conn) CloseHandler() func(code int, text string) error { + return c.handleClose +} + +// SetCloseHandler sets the handler for close messages received from the peer. +// The code argument to h is the received close code or CloseNoStatusReceived +// if the close message is empty. The default close handler sends a close +// message back to the peer. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// close messages as described in the section on Control Messages above. +// +// The connection read methods return a CloseError when a close message is +// received. Most applications should handle close messages as part of their +// normal error handling. Applications should only set a close handler when the +// application must perform some action before sending a close message back to +// the peer. +func (c *Conn) SetCloseHandler(h func(code int, text string) error) { + if h == nil { + h = func(code int, text string) error { + message := FormatCloseMessage(code, "") + c.WriteControl(CloseMessage, message, time.Now().Add(writeWait)) + return nil + } + } + c.handleClose = h +} + +// PingHandler returns the current ping handler +func (c *Conn) PingHandler() func(appData string) error { + return c.handlePing +} + +// SetPingHandler sets the handler for ping messages received from the peer. +// The appData argument to h is the PING message application data. The default +// ping handler sends a pong to the peer. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// ping messages as described in the section on Control Messages above. +func (c *Conn) SetPingHandler(h func(appData string) error) { + if h == nil { + h = func(message string) error { + err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) + if err == ErrCloseSent { + return nil + } else if e, ok := err.(net.Error); ok && e.Temporary() { + return nil + } + return err + } + } + c.handlePing = h +} + +// PongHandler returns the current pong handler +func (c *Conn) PongHandler() func(appData string) error { + return c.handlePong +} + +// SetPongHandler sets the handler for pong messages received from the peer. +// The appData argument to h is the PONG message application data. The default +// pong handler does nothing. +// +// The handler function is called from the NextReader, ReadMessage and message +// reader Read methods. The application must read the connection to process +// pong messages as described in the section on Control Messages above. +func (c *Conn) SetPongHandler(h func(appData string) error) { + if h == nil { + h = func(string) error { return nil } + } + c.handlePong = h +} + +// NetConn returns the underlying connection that is wrapped by c. +// Note that writing to or reading from this connection directly will corrupt the +// WebSocket connection. +func (c *Conn) NetConn() net.Conn { + return c.conn +} + +// UnderlyingConn returns the internal net.Conn. This can be used to further +// modifications to connection specific flags. +// Deprecated: Use the NetConn method. +func (c *Conn) UnderlyingConn() net.Conn { + return c.conn +} + +// EnableWriteCompression enables and disables write compression of +// subsequent text and binary messages. This function is a noop if +// compression was not negotiated with the peer. +func (c *Conn) EnableWriteCompression(enable bool) { + c.enableWriteCompression = enable +} + +// SetCompressionLevel sets the flate compression level for subsequent text and +// binary messages. This function is a noop if compression was not negotiated +// with the peer. See the compress/flate package for a description of +// compression levels. +func (c *Conn) SetCompressionLevel(level int) error { + if !isValidCompressionLevel(level) { + return errors.New("websocket: invalid compression level") + } + c.compressionLevel = level + return nil +} + +// FormatCloseMessage formats closeCode and text as a WebSocket close message. +// An empty message is returned for code CloseNoStatusReceived. +func FormatCloseMessage(closeCode int, text string) []byte { + if closeCode == CloseNoStatusReceived { + // Return empty message because it's illegal to send + // CloseNoStatusReceived. Return non-nil value in case application + // checks for nil. + return []byte{} + } + buf := make([]byte, 2+len(text)) + binary.BigEndian.PutUint16(buf, uint16(closeCode)) + copy(buf[2:], text) + return buf +} diff --git a/conn_broadcast_test.go b/conn_broadcast_test.go new file mode 100644 index 0000000..fa3c016 --- /dev/null +++ b/conn_broadcast_test.go @@ -0,0 +1,134 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package websocket + +import ( + "io" + "io/ioutil" + "sync/atomic" + "testing" +) + +// broadcastBench allows to run broadcast benchmarks. +// In every broadcast benchmark we create many connections, then send the same +// message into every connection and wait for all writes complete. This emulates +// an application where many connections listen to the same data - i.e. PUB/SUB +// scenarios with many subscribers in one channel. +type broadcastBench struct { + w io.Writer + closeCh chan struct{} + doneCh chan struct{} + count int32 + conns []*broadcastConn + compression bool + usePrepared bool +} + +type broadcastMessage struct { + payload []byte + prepared *PreparedMessage +} + +type broadcastConn struct { + conn *Conn + msgCh chan *broadcastMessage +} + +func newBroadcastConn(c *Conn) *broadcastConn { + return &broadcastConn{ + conn: c, + msgCh: make(chan *broadcastMessage, 1), + } +} + +func newBroadcastBench(usePrepared, compression bool) *broadcastBench { + bench := &broadcastBench{ + w: ioutil.Discard, + doneCh: make(chan struct{}), + closeCh: make(chan struct{}), + usePrepared: usePrepared, + compression: compression, + } + bench.makeConns(10000) + return bench +} + +func (b *broadcastBench) makeConns(numConns int) { + conns := make([]*broadcastConn, numConns) + + for i := 0; i < numConns; i++ { + c := newTestConn(nil, b.w, true) + if b.compression { + c.enableWriteCompression = true + c.newCompressionWriter = compressNoContextTakeover + } + conns[i] = newBroadcastConn(c) + go func(c *broadcastConn) { + for { + select { + case msg := <-c.msgCh: + if msg.prepared != nil { + c.conn.WritePreparedMessage(msg.prepared) + } else { + c.conn.WriteMessage(TextMessage, msg.payload) + } + val := atomic.AddInt32(&b.count, 1) + if val%int32(numConns) == 0 { + b.doneCh <- struct{}{} + } + case <-b.closeCh: + return + } + } + }(conns[i]) + } + b.conns = conns +} + +func (b *broadcastBench) close() { + close(b.closeCh) +} + +func (b *broadcastBench) broadcastOnce(msg *broadcastMessage) { + for _, c := range b.conns { + c.msgCh <- msg + } + <-b.doneCh +} + +func BenchmarkBroadcast(b *testing.B) { + benchmarks := []struct { + name string + usePrepared bool + compression bool + }{ + {"NoCompression", false, false}, + {"Compression", false, true}, + {"NoCompressionPrepared", true, false}, + {"CompressionPrepared", true, true}, + } + payload := textMessages(1)[0] + for _, bm := range benchmarks { + b.Run(bm.name, func(b *testing.B) { + bench := newBroadcastBench(bm.usePrepared, bm.compression) + defer bench.close() + b.ResetTimer() + for i := 0; i < b.N; i++ { + message := &broadcastMessage{ + payload: payload, + } + if bench.usePrepared { + pm, _ := NewPreparedMessage(TextMessage, message.payload) + message.prepared = pm + } + bench.broadcastOnce(message) + } + b.ReportAllocs() + }) + } +} diff --git a/conn_test.go b/conn_test.go new file mode 100644 index 0000000..4961797 --- /dev/null +++ b/conn_test.go @@ -0,0 +1,708 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package websocket + +import ( + "bufio" + "bytes" + "errors" + "fmt" + "io" + "io/ioutil" + "net" + "reflect" + "sync" + "testing" + "testing/iotest" + "time" +) + +var _ net.Error = errWriteTimeout + +type fakeNetConn struct { + io.Reader + io.Writer +} + +func (c fakeNetConn) Close() error { return nil } +func (c fakeNetConn) LocalAddr() net.Addr { return localAddr } +func (c fakeNetConn) RemoteAddr() net.Addr { return remoteAddr } +func (c fakeNetConn) SetDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetReadDeadline(t time.Time) error { return nil } +func (c fakeNetConn) SetWriteDeadline(t time.Time) error { return nil } + +type fakeAddr int + +var ( + localAddr = fakeAddr(1) + remoteAddr = fakeAddr(2) +) + +func (a fakeAddr) Network() string { + return "net" +} + +func (a fakeAddr) String() string { + return "str" +} + +// newTestConn creates a connection backed by a fake network connection using +// default values for buffering. +func newTestConn(r io.Reader, w io.Writer, isServer bool) *Conn { + return newConn(fakeNetConn{Reader: r, Writer: w}, isServer, 1024, 1024, nil, nil, nil) +} + +func TestFraming(t *testing.T) { + frameSizes := []int{ + 0, 1, 2, 124, 125, 126, 127, 128, 129, 65534, 65535, + // 65536, 65537 + } + readChunkers := []struct { + name string + f func(io.Reader) io.Reader + }{ + {"half", iotest.HalfReader}, + {"one", iotest.OneByteReader}, + {"asis", func(r io.Reader) io.Reader { return r }}, + } + writeBuf := make([]byte, 65537) + for i := range writeBuf { + writeBuf[i] = byte(i) + } + writers := []struct { + name string + f func(w io.Writer, n int) (int, error) + }{ + {"iocopy", func(w io.Writer, n int) (int, error) { + nn, err := io.Copy(w, bytes.NewReader(writeBuf[:n])) + return int(nn), err + }}, + {"write", func(w io.Writer, n int) (int, error) { + return w.Write(writeBuf[:n]) + }}, + {"string", func(w io.Writer, n int) (int, error) { + return io.WriteString(w, string(writeBuf[:n])) + }}, + } + + for _, compress := range []bool{false, true} { + for _, isServer := range []bool{true, false} { + for _, chunker := range readChunkers { + + var connBuf bytes.Buffer + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(chunker.f(&connBuf), nil, !isServer) + if compress { + wc.newCompressionWriter = compressNoContextTakeover + rc.newDecompressionReader = decompressNoContextTakeover + } + for _, n := range frameSizes { + for _, writer := range writers { + name := fmt.Sprintf("z:%v, s:%v, r:%s, n:%d w:%s", compress, isServer, chunker.name, n, writer.name) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + nn, err := writer.f(w, n) + if err != nil || nn != n { + t.Errorf("%s: w.Write(writeBuf[:n]) returned %d, %v", name, nn, err) + continue + } + err = w.Close() + if err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } + + opCode, r, err := rc.NextReader() + if err != nil || opCode != TextMessage { + t.Errorf("%s: NextReader() returned %d, r, %v", name, opCode, err) + continue + } + + t.Logf("frame size: %d", n) + rbuf, err := ioutil.ReadAll(r) + if err != nil { + t.Errorf("%s: ReadFull() returned rbuf, %v", name, err) + continue + } + + if len(rbuf) != n { + t.Errorf("%s: len(rbuf) is %d, want %d", name, len(rbuf), n) + continue + } + + for i, b := range rbuf { + if byte(i) != b { + t.Errorf("%s: bad byte at offset %d", name, i) + break + } + } + } + } + } + } + } +} + +func TestControl(t *testing.T) { + const message = "this is a ping/pong message" + for _, isServer := range []bool{true, false} { + for _, isWriteControl := range []bool{true, false} { + name := fmt.Sprintf("s:%v, wc:%v", isServer, isWriteControl) + var connBuf bytes.Buffer + wc := newTestConn(nil, &connBuf, isServer) + rc := newTestConn(&connBuf, nil, !isServer) + if isWriteControl { + wc.WriteControl(PongMessage, []byte(message), time.Now().Add(time.Second)) + } else { + w, err := wc.NextWriter(PongMessage) + if err != nil { + t.Errorf("%s: wc.NextWriter() returned %v", name, err) + continue + } + if _, err := w.Write([]byte(message)); err != nil { + t.Errorf("%s: w.Write() returned %v", name, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: w.Close() returned %v", name, err) + continue + } + var actualMessage string + rc.SetPongHandler(func(s string) error { actualMessage = s; return nil }) + rc.NextReader() + if actualMessage != message { + t.Errorf("%s: pong=%q, want %q", name, actualMessage, message) + continue + } + } + } + } +} + +// simpleBufferPool is an implementation of BufferPool for TestWriteBufferPool. +type simpleBufferPool struct { + v interface{} +} + +func (p *simpleBufferPool) Get() interface{} { + v := p.v + p.v = nil + return v +} + +func (p *simpleBufferPool) Put(v interface{}) { + p.v = v +} + +func TestWriteBufferPool(t *testing.T) { + const message = "Now is the time for all good people to come to the aid of the party." + + var buf bytes.Buffer + var pool simpleBufferPool + rc := newTestConn(&buf, nil, false) + + // Specify writeBufferSize smaller than message size to ensure that pooling + // works with fragmented messages. + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, len(message)-1, &pool, nil, nil) + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after create") + } + + // Part 1: test NextWriter/Write/Close + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + if _, err := io.WriteString(w, message); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err != nil { + t.Fatalf("w.Close() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after w.Close()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + + // Part 2: Test WriteMessage. + + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + + if wc.writeBuf != nil { + t.Fatal("writeBuf not nil after wc.WriteMessage()") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool after WriteMessage") + } + + opCode, p, err = rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } +} + +// TestWriteBufferPoolSync ensures that *sync.Pool works as a buffer pool. +func TestWriteBufferPoolSync(t *testing.T) { + var buf bytes.Buffer + var pool sync.Pool + wc := newConn(fakeNetConn{Writer: &buf}, true, 1024, 1024, &pool, nil, nil) + rc := newTestConn(&buf, nil, false) + + const message = "Hello World!" + for i := 0; i < 3; i++ { + if err := wc.WriteMessage(TextMessage, []byte(message)); err != nil { + t.Fatalf("wc.WriteMessage() returned %v", err) + } + opCode, p, err := rc.ReadMessage() + if opCode != TextMessage || err != nil { + t.Fatalf("ReadMessage() returned %d, p, %v", opCode, err) + } + if s := string(p); s != message { + t.Fatalf("message is %s, want %s", s, message) + } + } +} + +// errorWriter is an io.Writer than returns an error on all writes. +type errorWriter struct{} + +func (ew errorWriter) Write(p []byte) (int, error) { return 0, errors.New("error") } + +// TestWriteBufferPoolError ensures that buffer is returned to pool after error +// on write. +func TestWriteBufferPoolError(t *testing.T) { + // Part 1: Test NextWriter/Write/Close + + var pool simpleBufferPool + wc := newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) + + w, err := wc.NextWriter(TextMessage) + if err != nil { + t.Fatalf("wc.NextWriter() returned %v", err) + } + + if wc.writeBuf == nil { + t.Fatal("writeBuf is nil after NextWriter") + } + + writeBufAddr := &wc.writeBuf[0] + + if _, err := io.WriteString(w, "Hello"); err != nil { + t.Fatalf("io.WriteString(w, message) returned %v", err) + } + + if err := w.Close(); err == nil { + t.Fatalf("w.Close() did not return error") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } + + // Part 2: Test WriteMessage + + wc = newConn(fakeNetConn{Writer: errorWriter{}}, true, 1024, 1024, &pool, nil, nil) + + if err := wc.WriteMessage(TextMessage, []byte("Hello")); err == nil { + t.Fatalf("wc.WriteMessage did not return error") + } + + if wpd, ok := pool.v.(writePoolData); !ok || len(wpd.buf) == 0 || &wpd.buf[0] != writeBufAddr { + t.Fatal("writeBuf not returned to pool") + } +} + +func TestCloseFrameBeforeFinalMessageFrame(t *testing.T) { + const bufSize = 512 + + expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} + + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize+bufSize/2)) + wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) + w.Close() + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) + } + _, _, err = rc.NextReader() + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("NextReader() returned %v, want %v", err, expectedErr) + } +} + +func TestEOFWithinFrame(t *testing.T) { + const bufSize = 64 + + for n := 0; ; n++ { + var b bytes.Buffer + wc := newTestConn(nil, &b, false) + rc := newTestConn(&b, nil, true) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize)) + w.Close() + + if n >= b.Len() { + break + } + b.Truncate(n) + + op, r, err := rc.NextReader() + if err == errUnexpectedEOF { + continue + } + if op != BinaryMessage || err != nil { + t.Fatalf("%d: NextReader() returned %d, %v", n, op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != errUnexpectedEOF { + t.Fatalf("%d: io.Copy() returned %v, want %v", n, err, errUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != errUnexpectedEOF { + t.Fatalf("%d: NextReader() returned %v, want %v", n, err, errUnexpectedEOF) + } + } +} + +func TestEOFBeforeFinalFrame(t *testing.T) { + const bufSize = 512 + + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, bufSize, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(make([]byte, bufSize+bufSize/2)) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != errUnexpectedEOF { + t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) + } + _, _, err = rc.NextReader() + if err != errUnexpectedEOF { + t.Fatalf("NextReader() returned %v, want %v", err, errUnexpectedEOF) + } +} + +func TestWriteAfterMessageWriterClose(t *testing.T) { + wc := newTestConn(nil, &bytes.Buffer{}, false) + w, _ := wc.NextWriter(BinaryMessage) + io.WriteString(w, "hello") + if err := w.Close(); err != nil { + t.Fatalf("unexpected error closing message writer, %v", err) + } + + if _, err := io.WriteString(w, "world"); err == nil { + t.Fatalf("no error writing after close") + } + + w, _ = wc.NextWriter(BinaryMessage) + io.WriteString(w, "hello") + + // close w by getting next writer + _, err := wc.NextWriter(BinaryMessage) + if err != nil { + t.Fatalf("unexpected error getting next writer, %v", err) + } + + if _, err := io.WriteString(w, "world"); err == nil { + t.Fatalf("no error writing after close") + } +} + +func TestReadLimit(t *testing.T) { + t.Run("Test ReadLimit is enforced", func(t *testing.T) { + const readLimit = 512 + message := make([]byte, readLimit+1) + + var b1, b2 bytes.Buffer + wc := newConn(&fakeNetConn{Writer: &b1}, false, 1024, readLimit-2, nil, nil, nil) + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) + + // Send message at the limit with interleaved pong. + w, _ := wc.NextWriter(BinaryMessage) + w.Write(message[:readLimit-1]) + wc.WriteControl(PongMessage, []byte("this is a pong"), time.Now().Add(10*time.Second)) + w.Write(message[:1]) + w.Close() + + // Send message larger than the limit. + wc.WriteMessage(BinaryMessage, message[:readLimit+1]) + + op, _, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("2: NextReader() returned %d, %v", op, err) + } + _, err = io.Copy(ioutil.Discard, r) + if err != ErrReadLimit { + t.Fatalf("io.Copy() returned %v", err) + } + }) + + t.Run("Test that ReadLimit cannot be overflowed", func(t *testing.T) { + const readLimit = 1 + + var b1, b2 bytes.Buffer + rc := newTestConn(&b1, &b2, true) + rc.SetReadLimit(readLimit) + + // First, send a non-final binary message + b1.Write([]byte("\x02\x81")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // First payload + b1.Write([]byte("A")) + + // Next, send a negative-length, non-final continuation frame + b1.Write([]byte("\x00\xFF\x80\x00\x00\x00\x00\x00\x00\x00")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Next, send a too long, final continuation frame + b1.Write([]byte("\x80\xFF\x00\x00\x00\x00\x00\x00\x00\x05")) + + // Mask key + b1.Write([]byte("\x00\x00\x00\x00")) + + // Too-long payload + b1.Write([]byte("BCDEF")) + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("1: NextReader() returned %d, %v", op, err) + } + + var buf [10]byte + var read int + n, err := r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + n, err = r.Read(buf[:]) + if err != nil && err != ErrReadLimit { + t.Fatalf("unexpected error testing read limit: %v", err) + } + read += n + + if err == nil && read > readLimit { + t.Fatalf("read limit exceeded: limit %d, read %d", readLimit, read) + } + }) +} + +func TestAddrs(t *testing.T) { + c := newTestConn(nil, nil, true) + if c.LocalAddr() != localAddr { + t.Errorf("LocalAddr = %v, want %v", c.LocalAddr(), localAddr) + } + if c.RemoteAddr() != remoteAddr { + t.Errorf("RemoteAddr = %v, want %v", c.RemoteAddr(), remoteAddr) + } +} + +func TestDeprecatedUnderlyingConn(t *testing.T) { + var b1, b2 bytes.Buffer + fc := fakeNetConn{Reader: &b1, Writer: &b2} + c := newConn(fc, true, 1024, 1024, nil, nil, nil) + ul := c.UnderlyingConn() + if ul != fc { + t.Fatalf("Underlying conn is not what it should be.") + } +} + +func TestNetConn(t *testing.T) { + var b1, b2 bytes.Buffer + fc := fakeNetConn{Reader: &b1, Writer: &b2} + c := newConn(fc, true, 1024, 1024, nil, nil, nil) + ul := c.NetConn() + if ul != fc { + t.Fatalf("Underlying conn is not what it should be.") + } +} + +func TestBufioReadBytes(t *testing.T) { + // Test calling bufio.ReadBytes for value longer than read buffer size. + + m := make([]byte, 512) + m[len(m)-1] = '\n' + + var b1, b2 bytes.Buffer + wc := newConn(fakeNetConn{Writer: &b1}, false, len(m)+64, len(m)+64, nil, nil, nil) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64, nil, nil, nil) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(m) + w.Close() + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + + br := bufio.NewReader(r) + p, err := br.ReadBytes('\n') + if err != nil { + t.Fatalf("ReadBytes() returned %v", err) + } + if len(p) != len(m) { + t.Fatalf("read returned %d bytes, want %d bytes", len(p), len(m)) + } +} + +var closeErrorTests = []struct { + err error + codes []int + ok bool +}{ + {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, true}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, false}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, true}, + {errors.New("hello"), []int{CloseNormalClosure}, false}, +} + +func TestCloseError(t *testing.T) { + for _, tt := range closeErrorTests { + ok := IsCloseError(tt.err, tt.codes...) + if ok != tt.ok { + t.Errorf("IsCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) + } + } +} + +var unexpectedCloseErrorTests = []struct { + err error + codes []int + ok bool +}{ + {&CloseError{Code: CloseNormalClosure}, []int{CloseNormalClosure}, false}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived}, true}, + {&CloseError{Code: CloseNormalClosure}, []int{CloseNoStatusReceived, CloseNormalClosure}, false}, + {errors.New("hello"), []int{CloseNormalClosure}, false}, +} + +func TestUnexpectedCloseErrors(t *testing.T) { + for _, tt := range unexpectedCloseErrorTests { + ok := IsUnexpectedCloseError(tt.err, tt.codes...) + if ok != tt.ok { + t.Errorf("IsUnexpectedCloseError(%#v, %#v) returned %v, want %v", tt.err, tt.codes, ok, tt.ok) + } + } +} + +type blockingWriter struct { + c1, c2 chan struct{} +} + +func (w blockingWriter) Write(p []byte) (int, error) { + // Allow main to continue + close(w.c1) + // Wait for panic in main + <-w.c2 + return len(p), nil +} + +func TestConcurrentWritePanic(t *testing.T) { + w := blockingWriter{make(chan struct{}), make(chan struct{})} + c := newTestConn(nil, w, false) + go func() { + c.WriteMessage(TextMessage, []byte{}) + }() + + // wait for goroutine to block in write. + <-w.c1 + + defer func() { + close(w.c2) + if v := recover(); v != nil { + return + } + }() + + c.WriteMessage(TextMessage, []byte{}) + t.Fatal("should not get here") +} + +type failingReader struct{} + +func (r failingReader) Read(p []byte) (int, error) { + return 0, io.EOF +} + +func TestFailedConnectionReadPanic(t *testing.T) { + c := newTestConn(failingReader{}, nil, false) + + defer func() { + if v := recover(); v != nil { + return + } + }() + + for i := 0; i < 20000; i++ { + c.ReadMessage() + } + t.Fatal("should not get here") +} diff --git a/examples/autobahn/README.md b/examples/autobahn/README.md new file mode 100644 index 0000000..cc954fe --- /dev/null +++ b/examples/autobahn/README.md @@ -0,0 +1,18 @@ +# Test Server + +This package contains a server for the [Autobahn WebSockets Test Suite](https://github.com/crossbario/autobahn-testsuite). + +To test the server, run + + go run server.go + +and start the client test driver + + mkdir -p reports + docker run -it --rm \ + -v ${PWD}/config:/config \ + -v ${PWD}/reports:/reports \ + crossbario/autobahn-testsuite \ + wstest -m fuzzingclient -s /config/fuzzingclient.json + +When the client completes, it writes a report to reports/index.html. diff --git a/examples/autobahn/config/fuzzingclient.json b/examples/autobahn/config/fuzzingclient.json new file mode 100644 index 0000000..eda4e66 --- /dev/null +++ b/examples/autobahn/config/fuzzingclient.json @@ -0,0 +1,29 @@ +{ + "cases": ["*"], + "exclude-cases": [], + "exclude-agent-cases": {}, + "outdir": "/reports", + "options": {"failByDrop": false}, + "servers": [ + { + "agent": "ReadAllWriteMessage", + "url": "ws://host.docker.internal:9000/m" + }, + { + "agent": "ReadAllWritePreparedMessage", + "url": "ws://host.docker.internal:9000/p" + }, + { + "agent": "CopyFull", + "url": "ws://host.docker.internal:9000/f" + }, + { + "agent": "ReadAllWrite", + "url": "ws://host.docker.internal:9000/r" + }, + { + "agent": "CopyWriterOnly", + "url": "ws://host.docker.internal:9000/c" + } + ] +} diff --git a/examples/autobahn/server.go b/examples/autobahn/server.go new file mode 100644 index 0000000..2231c6a --- /dev/null +++ b/examples/autobahn/server.go @@ -0,0 +1,277 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +// Command server is a test server for the Autobahn WebSockets Test Suite. +package main + +import ( + "context" + "errors" + "io" + "log" + "net/http" + "time" + "unicode/utf8" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/server" + "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/hertz-contrib/websocket" +) + +var upgrader = websocket.HertzUpgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + EnableCompression: true, + CheckOrigin: func(ctx *app.RequestContext) bool { + return true + }, +} + +// echoCopy echoes messages from the client using io.Copy. +func echoCopy(ctx *app.RequestContext, writerOnly bool) { + err := upgrader.Upgrade(ctx, func(conn *websocket.Conn) { + defer conn.Close() + for { + mt, r, err := conn.NextReader() + if err != nil { + if err != io.EOF { + log.Println("NextReader:", err) + } + return + } + if mt == websocket.TextMessage { + r = &validator{r: r} + } + w, err := conn.NextWriter(mt) + if err != nil { + log.Println("NextWriter:", err) + return + } + if mt == websocket.TextMessage { + r = &validator{r: r} + } + if writerOnly { + _, err = io.Copy(struct{ io.Writer }{w}, r) + } else { + _, err = io.Copy(w, r) + } + if err != nil { + if err == errInvalidUTF8 { + conn.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseInvalidFramePayloadData, ""), + time.Time{}) + } + log.Println("Copy:", err) + return + } + err = w.Close() + if err != nil { + log.Println("Close:", err) + return + } + } + }) + if err != nil { + log.Println("Upgrade:", err) + return + } +} + +func echoCopyWriterOnly(_ context.Context, c *app.RequestContext) { + echoCopy(c, true) +} + +func echoCopyFull(_ context.Context, c *app.RequestContext) { + echoCopy(c, false) +} + +// echoReadAll echoes messages from the client by reading the entire message +// with ioutil.ReadAll. +func echoReadAll(ctx *app.RequestContext, writeMessage, writePrepared bool) { + err := upgrader.Upgrade(ctx, func(conn *websocket.Conn) { + defer conn.Close() + for { + mt, b, err := conn.ReadMessage() + if err != nil { + if err != io.EOF { + log.Println("NextReader:", err) + } + return + } + if mt == websocket.TextMessage { + if !utf8.Valid(b) { + conn.WriteControl(websocket.CloseMessage, + websocket.FormatCloseMessage(websocket.CloseInvalidFramePayloadData, ""), + time.Time{}) + log.Println("ReadAll: invalid utf8") + } + } + if writeMessage { + if !writePrepared { + err = conn.WriteMessage(mt, b) + if err != nil { + log.Println("WriteMessage:", err) + } + } else { + pm, err := websocket.NewPreparedMessage(mt, b) + if err != nil { + log.Println("NewPreparedMessage:", err) + return + } + err = conn.WritePreparedMessage(pm) + if err != nil { + log.Println("WritePreparedMessage:", err) + } + } + } else { + w, err := conn.NextWriter(mt) + if err != nil { + log.Println("NextWriter:", err) + return + } + if _, err := w.Write(b); err != nil { + log.Println("Writer:", err) + return + } + if err := w.Close(); err != nil { + log.Println("Close:", err) + return + } + } + } + }) + if err != nil { + log.Println("Upgrade:", err) + return + } +} + +func echoReadAllWriter(_ context.Context, c *app.RequestContext) { + echoReadAll(c, false, false) +} + +func echoReadAllWriteMessage(_ context.Context, c *app.RequestContext) { + echoReadAll(c, true, false) +} + +func echoReadAllWritePreparedMessage(_ context.Context, c *app.RequestContext) { + echoReadAll(c, true, true) +} + +func serveHome(_ context.Context, c *app.RequestContext) { + if string(c.URI().Path()) != "/" { + _ = c.AbortWithError(http.StatusNotFound, errors.New("not found")) + return + } + if !c.IsGet() { + _ = c.AbortWithError(http.StatusMethodNotAllowed, errors.New("method not allowed")) + return + } + c.Response.Header.Set("Content-Type", "text/html; charset=utf-8") + c.String(consts.StatusOK, "Echo Server") +} + +var addr = ":9000" + +func main() { + // server.Default() creates a Hertz with recovery middleware. + // If you need a pure hertz, you can use server.New() + h := server.Default(server.WithHostPorts(addr)) + h.GET("/", serveHome) + h.GET("/c", echoCopyWriterOnly) + h.GET("/f", echoCopyFull) + h.GET("/r", echoReadAllWriter) + h.GET("/m", echoReadAllWriteMessage) + h.GET("/p", echoReadAllWritePreparedMessage) + + h.NoRoute(func(c context.Context, ctx *app.RequestContext) { + ctx.AbortWithMsg("Unsupported path", consts.StatusNotFound) + }) + + h.Spin() +} + +type validator struct { + state int + x rune + r io.Reader +} + +var errInvalidUTF8 = errors.New("invalid utf8") + +func (r *validator) Read(p []byte) (int, error) { + n, err := r.r.Read(p) + state := r.state + x := r.x + for _, b := range p[:n] { + state, x = decode(state, x, b) + if state == utf8Reject { + break + } + } + r.state = state + r.x = x + if state == utf8Reject || (err == io.EOF && state != utf8Accept) { + return n, errInvalidUTF8 + } + return n, err +} + +// UTF-8 decoder from http://bjoern.hoehrmann.de/utf-8/decoder/dfa/ +// +// Copyright (c) 2008-2009 Bjoern Hoehrmann +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to +// deal in the Software without restriction, including without limitation the +// rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +// sell copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in +// all copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +// FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +// IN THE SOFTWARE. +var utf8d = [...]byte{ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 00..1f + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 20..3f + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 40..5f + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 60..7f + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, // 80..9f + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, // a0..bf + 8, 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, // c0..df + 0xa, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x3, 0x4, 0x3, 0x3, // e0..ef + 0xb, 0x6, 0x6, 0x6, 0x5, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, 0x8, // f0..ff + 0x0, 0x1, 0x2, 0x3, 0x5, 0x8, 0x7, 0x1, 0x1, 0x1, 0x4, 0x6, 0x1, 0x1, 0x1, 0x1, // s0..s0 + 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 0, 1, 1, 1, 1, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, // s1..s2 + 1, 2, 1, 1, 1, 1, 1, 2, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, // s3..s4 + 1, 2, 1, 1, 1, 1, 1, 1, 1, 2, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, // s5..s6 + 1, 3, 1, 1, 1, 1, 1, 3, 1, 3, 1, 1, 1, 1, 1, 1, 1, 3, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, // s7..s8 +} + +const ( + utf8Accept = 0 + utf8Reject = 1 +) + +func decode(state int, x rune, b byte) (int, rune) { + t := utf8d[b] + if state != utf8Accept { + x = rune(b&0x3f) | (x << 6) + } else { + x = rune((0xff >> t) & b) + } + state = int(utf8d[256+state*16+int(t)]) + return state, x +} diff --git a/examples/chat/README.md b/examples/chat/README.md new file mode 100644 index 0000000..7bff1e1 --- /dev/null +++ b/examples/chat/README.md @@ -0,0 +1,102 @@ +# Chat Example + +This application shows how to use the +[websocket](https://github.com/hertz-contrib/websocket) package to implement a simple +web chat application. + +## Running the example + +The example requires a working Go development environment. The [Getting +Started](http://golang.org/doc/install) page describes how to install the +development environment. + +Once you have Go up and running, you can download, build and run the example +using the following commands. + + $ go get https://github.com/hertz-contrib/websocket + $ cd `go list -f '{{.Dir}}' https://github.com/hertz-contrib/websocket/examples/chat` + $ go run *.go + +To use the chat example, open http://localhost:8080/ in your browser. + +## Server + +The server application defines two types, `Client` and `Hub`. The server +creates an instance of the `Client` type for each websocket connection. A +`Client` acts as an intermediary between the websocket connection and a single +instance of the `Hub` type. The `Hub` maintains a set of registered clients and +broadcasts messages to the clients. + +The application runs one goroutine for the `Hub` and two goroutines for each +`Client`. The goroutines communicate with each other using channels. The `Hub` +has channels for registering clients, unregistering clients and broadcasting +messages. A `Client` has a buffered channel of outbound messages. One of the +client's goroutines reads messages from this channel and writes the messages to +the websocket. The other client goroutine reads messages from the websocket and +sends them to the hub. + +### Hub + +The code for the `Hub` type is in +[hub.go](https://https://github.com/hertz-contrib/websocket/blob/master/examples/chat/hub.go). +The application's `main` function starts the hub's `run` method as a goroutine. +Clients send requests to the hub using the `register`, `unregister` and +`broadcast` channels. + +The hub registers clients by adding the client pointer as a key in the +`clients` map. The map value is always true. + +The unregister code is a little more complicated. In addition to deleting the +client pointer from the `clients` map, the hub closes the clients's `send` +channel to signal the client that no more messages will be sent to the client. + +The hub handles messages by looping over the registered clients and sending the +message to the client's `send` channel. If the client's `send` buffer is full, +then the hub assumes that the client is dead or stuck. In this case, the hub +unregisters the client and closes the websocket. + +### Client + +The code for the `Client` type is in [client.go](https://https://github.com/hertz-contrib/websocket/blob/master/examples/chat/client.go). + +The `serveWs` function is registered by the application's `main` function as +an HTTP handler. The handler upgrades the HTTP connection to the WebSocket +protocol, creates a client, registers the client with the hub and schedules the +client to be unregistered using a defer statement. + +Next, the HTTP handler starts the client's `writePump` method as a goroutine. +This method transfers messages from the client's send channel to the websocket +connection. The writer method exits when the channel is closed by the hub or +there's an error writing to the websocket connection. + +Finally, the HTTP handler calls the client's `readPump` method. This method +transfers inbound messages from the websocket to the hub. + +WebSocket connections [support one concurrent reader and one concurrent +writer](https://godoc.org/github.com/gorilla/websocket#hdr-Concurrency). The +application ensures that these concurrency requirements are met by executing +all reads from the `readPump` goroutine and all writes from the `writePump` +goroutine. + +To improve efficiency under high load, the `writePump` function coalesces +pending chat messages in the `send` channel to a single WebSocket message. This +reduces the number of system calls and the amount of data sent over the +network. + +## Frontend + +The frontend code is in [home.html](https://https://github.com/hertz-contrib/websocket/blob/master/examples/chat/home.html). + +On document load, the script checks for websocket functionality in the browser. +If websocket functionality is available, then the script opens a connection to +the server and registers a callback to handle messages from the server. The +callback appends the message to the chat log using the appendLog function. + +To allow the user to manually scroll through the chat log without interruption +from new messages, the `appendLog` function checks the scroll position before +adding new content. If the chat log is scrolled to the bottom, then the +function scrolls new content into view after adding the content. Otherwise, the +scroll position is not changed. + +The form handler writes the user input to the websocket and clears the input +field. diff --git a/examples/chat/client.go b/examples/chat/client.go new file mode 100644 index 0000000..34ebbc2 --- /dev/null +++ b/examples/chat/client.go @@ -0,0 +1,133 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package main + +import ( + "bytes" + "log" + "time" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/hertz-contrib/websocket" +) + +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 + + // Maximum message size allowed from peer. + maxMessageSize = 512 +) + +var ( + newline = []byte{'\n'} + space = []byte{' '} +) + +// Client is a middleman between the websocket connection and the hub. +type Client struct { + hub *Hub + + // The websocket connection. + conn *websocket.Conn + + // Buffered channel of outbound messages. + send chan []byte +} + +// readPump pumps messages from the websocket connection to the hub. +// +// The application runs readPump in a per-connection goroutine. The application +// ensures that there is at most one reader on a connection by executing all +// reads from this goroutine. +func (c *Client) readPump() { + defer func() { + c.hub.unregister <- c + c.conn.Close() + }() + c.conn.SetReadLimit(maxMessageSize) + c.conn.SetReadDeadline(time.Now().Add(pongWait)) + c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, message, err := c.conn.ReadMessage() + if err != nil { + if websocket.IsUnexpectedCloseError(err, websocket.CloseGoingAway, websocket.CloseAbnormalClosure) { + log.Printf("error: %v", err) + } + break + } + message = bytes.TrimSpace(bytes.Replace(message, newline, space, -1)) + c.hub.broadcast <- message + } +} + +// writePump pumps messages from the hub to the websocket connection. +// +// A goroutine running writePump is started for each connection. The +// application ensures that there is at most one writer to a connection by +// executing all writes from this goroutine. +func (c *Client) writePump() { + ticker := time.NewTicker(pingPeriod) + defer func() { + ticker.Stop() + c.conn.Close() + }() + for { + select { + case message, ok := <-c.send: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if !ok { + // The hub closed the channel. + c.conn.WriteMessage(websocket.CloseMessage, []byte{}) + return + } + + w, err := c.conn.NextWriter(websocket.TextMessage) + if err != nil { + return + } + w.Write(message) + + // Add queued chat messages to the current websocket message. + n := len(c.send) + for i := 0; i < n; i++ { + w.Write(newline) + w.Write(<-c.send) + } + + if err := w.Close(); err != nil { + return + } + case <-ticker.C: + c.conn.SetWriteDeadline(time.Now().Add(writeWait)) + if err := c.conn.WriteMessage(websocket.PingMessage, nil); err != nil { + return + } + } + } +} + +// serveWs handles websocket requests from the peer. +func serveWs(ctx *app.RequestContext, hub *Hub) { + err := upgrader.Upgrade(ctx, func(conn *websocket.Conn) { + client := &Client{hub: hub, conn: conn, send: make(chan []byte, 256)} + client.hub.register <- client + + go client.writePump() + client.readPump() + }) + if err != nil { + log.Println(err) + } +} diff --git a/examples/chat/home.html b/examples/chat/home.html new file mode 100644 index 0000000..bf866af --- /dev/null +++ b/examples/chat/home.html @@ -0,0 +1,98 @@ + + + +Chat Example + + + + +
+
+ + +
+ + diff --git a/examples/chat/hub.go b/examples/chat/hub.go new file mode 100644 index 0000000..fb25735 --- /dev/null +++ b/examples/chat/hub.go @@ -0,0 +1,56 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package main + +// Hub maintains the set of active clients and broadcasts messages to the +// clients. +type Hub struct { + // Registered clients. + clients map[*Client]bool + + // Inbound messages from the clients. + broadcast chan []byte + + // Register requests from the clients. + register chan *Client + + // Unregister requests from clients. + unregister chan *Client +} + +func newHub() *Hub { + return &Hub{ + broadcast: make(chan []byte), + register: make(chan *Client), + unregister: make(chan *Client), + clients: make(map[*Client]bool), + } +} + +func (h *Hub) run() { + for { + select { + case client := <-h.register: + h.clients[client] = true + case client := <-h.unregister: + if _, ok := h.clients[client]; ok { + delete(h.clients, client) + close(client.send) + } + case message := <-h.broadcast: + for client := range h.clients { + select { + case client.send <- message: + default: + close(client.send) + delete(h.clients, client) + } + } + } + } +} diff --git a/examples/chat/main.go b/examples/chat/main.go new file mode 100644 index 0000000..63863ff --- /dev/null +++ b/examples/chat/main.go @@ -0,0 +1,53 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package main + +import ( + "context" + "net/http" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/server" + "github.com/cloudwego/hertz/pkg/common/hlog" + "github.com/hertz-contrib/websocket" +) + +var upgrader = websocket.HertzUpgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, +} + +var addr = ":8080" + +func serveHome(_ context.Context, c *app.RequestContext) { + if string(c.URI().Path()) != "/" { + hlog.Error("Not found", http.StatusNotFound) + return + } + if !c.IsGet() { + hlog.Error("Method not allowed", http.StatusMethodNotAllowed) + return + } + c.HTML(http.StatusOK, "home.html", nil) +} + +func main() { + hub := newHub() + go hub.run() + // server.Default() creates a Hertz with recovery middleware. + // If you need a pure hertz, you can use server.New() + h := server.Default(server.WithHostPorts(addr)) + h.LoadHTMLGlob("home.html") + + h.GET("/", serveHome) + h.GET("/ws", func(c context.Context, ctx *app.RequestContext) { + serveWs(ctx, hub) + }) + + h.Spin() +} diff --git a/examples/command/README.md b/examples/command/README.md new file mode 100644 index 0000000..901d666 --- /dev/null +++ b/examples/command/README.md @@ -0,0 +1,17 @@ +# Command example + +This example connects a websocket connection to stdin and stdout of a command. +Received messages are written to stdin followed by a `\n`. Each line read from +standard out is sent as a message to the client. + + $ go run main.go + # Open http://localhost:8080/ . + +Try the following commands. + + # Echo sent messages to the output area. + $ go run main.go cat + + # Run a shell.Try sending "ls" and "cat main.go". + $ go run main.go sh + diff --git a/examples/command/home.html b/examples/command/home.html new file mode 100644 index 0000000..19c4612 --- /dev/null +++ b/examples/command/home.html @@ -0,0 +1,102 @@ + + + +Command Example + + + + +
+
+ + +
+ + diff --git a/examples/command/main.go b/examples/command/main.go new file mode 100644 index 0000000..3fa1217 --- /dev/null +++ b/examples/command/main.go @@ -0,0 +1,203 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package main + +import ( + "bufio" + "context" + "flag" + "io" + "log" + "net/http" + "os" + "os/exec" + "time" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/server" + "github.com/cloudwego/hertz/pkg/common/hlog" + "github.com/hertz-contrib/websocket" +) + +var ( + addr = flag.String("addr", "127.0.0.1:8080", "http service address") + cmdPath string +) + +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Maximum message size allowed from peer. + maxMessageSize = 8192 + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 + + // Time to wait before force close on connection. + closeGracePeriod = 10 * time.Second +) + +func pumpStdin(ws *websocket.Conn, w io.Writer) { + defer ws.Close() + ws.SetReadLimit(maxMessageSize) + ws.SetReadDeadline(time.Now().Add(pongWait)) + ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, message, err := ws.ReadMessage() + if err != nil { + break + } + message = append(message, '\n') + if _, err := w.Write(message); err != nil { + break + } + } +} + +func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) { + defer func() { + }() + s := bufio.NewScanner(r) + for s.Scan() { + ws.SetWriteDeadline(time.Now().Add(writeWait)) + if err := ws.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil { + ws.Close() + break + } + } + if s.Err() != nil { + log.Println("scan:", s.Err()) + } + close(done) + + ws.SetWriteDeadline(time.Now().Add(writeWait)) + ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + time.Sleep(closeGracePeriod) + ws.Close() +} + +func ping(ws *websocket.Conn, done chan struct{}) { + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil { + log.Println("ping:", err) + } + case <-done: + return + } + } +} + +func internalError(ws *websocket.Conn, msg string, err error) { + log.Println(msg, err) + ws.WriteMessage(websocket.TextMessage, []byte("Internal server error.")) +} + +var upgrader = websocket.HertzUpgrader{} + +func serveWs(c context.Context, ctx *app.RequestContext) { + err := upgrader.Upgrade(ctx, func(ws *websocket.Conn) { + defer ws.Close() + + outr, outw, err := os.Pipe() + if err != nil { + internalError(ws, "stdout:", err) + return + } + defer outr.Close() + defer outw.Close() + + inr, inw, err := os.Pipe() + if err != nil { + internalError(ws, "stdin:", err) + return + } + defer inr.Close() + defer inw.Close() + + proc, err := os.StartProcess(cmdPath, flag.Args(), &os.ProcAttr{ + Files: []*os.File{inr, outw, outw}, + }) + if err != nil { + internalError(ws, "start:", err) + return + } + + inr.Close() + outw.Close() + + stdoutDone := make(chan struct{}) + go pumpStdout(ws, outr, stdoutDone) + go ping(ws, stdoutDone) + + pumpStdin(ws, inw) + + // Some commands will exit when stdin is closed. + inw.Close() + + // Other commands need a bonk on the head. + if err := proc.Signal(os.Interrupt); err != nil { + log.Println("inter:", err) + } + + select { + case <-stdoutDone: + case <-time.After(time.Second): + // A bigger bonk on the head. + if err := proc.Signal(os.Kill); err != nil { + log.Println("term:", err) + } + <-stdoutDone + } + + if _, err := proc.Wait(); err != nil { + log.Println("wait:", err) + } + }) + if err != nil { + log.Println("upgrade:", err) + return + } +} + +func serveHome(_ context.Context, c *app.RequestContext) { + log.Println(string(c.URI().FullURI())) + if string(c.URI().Path()) != "/" { + hlog.Error("Not found", http.StatusNotFound) + return + } + if !c.IsGet() { + hlog.Error("Method not allowed", http.StatusMethodNotAllowed) + return + } + c.HTML(http.StatusOK, "home.html", nil) +} + +func main() { + flag.Parse() + if len(flag.Args()) < 1 { + log.Fatal("must specify at least one argument") + } + var err error + cmdPath, err = exec.LookPath(flag.Args()[0]) + if err != nil { + log.Fatal(err) + } + h := server.Default(server.WithHostPorts(*addr)) + h.LoadHTMLGlob("home.html") + h.GET("/", serveHome) + h.GET("/ws", serveWs) + h.Spin() +} diff --git a/examples/echo/README.md b/examples/echo/README.md new file mode 100644 index 0000000..6ad79ed --- /dev/null +++ b/examples/echo/README.md @@ -0,0 +1,17 @@ +# Client and server example + +This example shows a simple client and server. + +The server echoes messages sent to it. The client sends a message every second +and prints all messages received. + +To run the example, start the server: + + $ go run server.go + +Next, start the client: + + $ go run client.go + +The server includes a simple web client. To use the client, open +http://127.0.0.1:8080 in the browser and follow the instructions on the page. diff --git a/examples/echo/client.go b/examples/echo/client.go new file mode 100644 index 0000000..a225ce4 --- /dev/null +++ b/examples/echo/client.go @@ -0,0 +1,86 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +//go:build ignore +// +build ignore + +package main + +import ( + "flag" + "log" + "net/url" + "os" + "os/signal" + "time" + + "github.com/gorilla/websocket" +) + +var addr = flag.String("addr", "localhost:8080", "http service address") + +func main() { + flag.Parse() + log.SetFlags(0) + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + + u := url.URL{Scheme: "ws", Host: *addr, Path: "/echo"} + log.Printf("connecting to %s", u.String()) + + c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Fatal("dial:", err) + } + defer c.Close() + + done := make(chan struct{}) + + go func() { + defer close(done) + for { + _, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + return + } + log.Printf("recv: %s", message) + } + }() + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case <-done: + return + case t := <-ticker.C: + err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) + if err != nil { + log.Println("write:", err) + return + } + case <-interrupt: + log.Println("interrupt") + + // Cleanly close the connection by sending a close message and then + // waiting (with timeout) for the server to close the connection. + err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + log.Println("write close:", err) + return + } + select { + case <-done: + case <-time.After(time.Second): + } + return + } + } +} diff --git a/examples/echo/server.go b/examples/echo/server.go new file mode 100644 index 0000000..8f6a3f9 --- /dev/null +++ b/examples/echo/server.go @@ -0,0 +1,142 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +//go:build ignore +// +build ignore + +package main + +import ( + "context" + "flag" + "html/template" + "log" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/server" + "github.com/hertz-contrib/websocket" +) + +var addr = flag.String("addr", "localhost:8080", "http service address") + +var upgrader = websocket.HertzUpgrader{} // use default options + +func echo(_ context.Context, c *app.RequestContext) { + err := upgrader.Upgrade(c, func(conn *websocket.Conn) { + for { + mt, message, err := conn.ReadMessage() + if err != nil { + log.Println("read:", err) + break + } + log.Printf("recv: %s", message) + err = conn.WriteMessage(mt, message) + if err != nil { + log.Println("write:", err) + break + } + } + }) + if err != nil { + log.Print("upgrade:", err) + return + } +} + +func home(_ context.Context, c *app.RequestContext) { + homeTemplate.Execute(c.GetConn(), "ws://"+string(c.Host())+"/echo") +} + +func main() { + flag.Parse() + h := server.Default(server.WithHostPorts(*addr)) + // https://github.com/cloudwego/hertz/issues/121 + h.NoHijackConnPool = true + h.GET("/", home) + h.GET("/echo", echo) + h.Spin() +} + +var homeTemplate = template.Must(template.New("").Parse(` + + + + + + + + +
+

Click "Open" to create a connection to the server, +"Send" to send a message to the server and "Close" to close the connection. +You can change the message and send multiple times. +

+

+ + +

+ +

+
+
+
+ + +`)) diff --git a/examples/filewatch/README.md b/examples/filewatch/README.md new file mode 100644 index 0000000..49bc779 --- /dev/null +++ b/examples/filewatch/README.md @@ -0,0 +1,7 @@ +# File Watch example. + +This example sends a file to the browser client for display whenever the file is modified. + + $ go run main.go + # Open http://localhost:8080/ . + # Modify the file to see it update in the browser. diff --git a/examples/filewatch/main.go b/examples/filewatch/main.go new file mode 100644 index 0000000..e089b4b --- /dev/null +++ b/examples/filewatch/main.go @@ -0,0 +1,204 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package main + +import ( + "context" + "flag" + "html/template" + "io/ioutil" + "log" + "os" + "strconv" + "time" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/app/server" + "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/hertz-contrib/websocket" +) + +const ( + // Time allowed to write the file to the client. + writeWait = 10 * time.Second + + // Time allowed to read the next pong message from the client. + pongWait = 60 * time.Second + + // Send pings to client with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 + + // Poll file for changes with this period. + filePeriod = 10 * time.Second +) + +var ( + addr = flag.String("addr", ":8080", "http service address") + homeTempl = template.Must(template.New("").Parse(homeHTML)) + filename string + upgrader = websocket.HertzUpgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + } +) + +func readFileIfModified(lastMod time.Time) ([]byte, time.Time, error) { + fi, err := os.Stat(filename) + if err != nil { + return nil, lastMod, err + } + if !fi.ModTime().After(lastMod) { + return nil, lastMod, nil + } + p, err := ioutil.ReadFile(filename) + if err != nil { + return nil, fi.ModTime(), err + } + return p, fi.ModTime(), nil +} + +func reader(ws *websocket.Conn) { + defer ws.Close() + ws.SetReadLimit(512) + ws.SetReadDeadline(time.Now().Add(pongWait)) + ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, _, err := ws.ReadMessage() + if err != nil { + break + } + } +} + +func writer(ws *websocket.Conn, lastMod time.Time) { + lastError := "" + pingTicker := time.NewTicker(pingPeriod) + fileTicker := time.NewTicker(filePeriod) + defer func() { + pingTicker.Stop() + fileTicker.Stop() + ws.Close() + }() + for { + select { + case <-fileTicker.C: + var p []byte + var err error + + p, lastMod, err = readFileIfModified(lastMod) + + if err != nil { + if s := err.Error(); s != lastError { + lastError = s + p = []byte(lastError) + } + } else { + lastError = "" + } + + if p != nil { + ws.SetWriteDeadline(time.Now().Add(writeWait)) + if err := ws.WriteMessage(websocket.TextMessage, p); err != nil { + return + } + } + case <-pingTicker.C: + ws.SetWriteDeadline(time.Now().Add(writeWait)) + if err := ws.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + return + } + } + } +} + +func serveWs(c context.Context, ctx *app.RequestContext) { + err := upgrader.Upgrade(ctx, func(ws *websocket.Conn) { + var lastMod time.Time + if n, err := strconv.ParseInt(string(ctx.FormValue("lastMod")), 16, 64); err == nil { + lastMod = time.Unix(0, n) + } + + go writer(ws, lastMod) + reader(ws) + }) + if err != nil { + if _, ok := err.(websocket.HandshakeError); ok { + log.Println(err) + } + return + } +} + +func serveHome(c context.Context, ctx *app.RequestContext) { + if !ctx.IsGet() { + ctx.AbortWithMsg("Method not allowed", consts.StatusMethodNotAllowed) + return + } + + ctx.SetContentType("text/html; charset=utf-8") + + p, lastMod, err := readFileIfModified(time.Time{}) + if err != nil { + p = []byte(err.Error()) + lastMod = time.Unix(0, 0) + } + v := struct { + Host string + Data string + LastMod string + }{ + string(ctx.Host()), + string(p), + strconv.FormatInt(lastMod.UnixNano(), 16), + } + homeTempl.Execute(ctx, &v) +} + +func main() { + flag.Parse() + if flag.NArg() != 1 { + log.Fatal("filename not specified") + } + filename = flag.Args()[0] + + h := server.New(server.WithHostPorts(*addr)) + + h.GET("/", serveHome) + + h.GET("/ws", serveWs) + + h.NoRoute(func(c context.Context, ctx *app.RequestContext) { + ctx.AbortWithMsg("Unsupported path", consts.StatusNotFound) + }) + + h.Spin() +} + +const homeHTML = ` + + + WebSocket Example + + +
{{.Data}}
+ + + +` diff --git a/go.mod b/go.mod index e5f7088..56749f7 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,7 @@ module github.com/hertz-contrib/websocket -go 1.18 +go 1.16 + +require github.com/cloudwego/hertz v0.3.0 + +require github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..84e67e4 --- /dev/null +++ b/go.sum @@ -0,0 +1,77 @@ +github.com/bytedance/go-tagexpr/v2 v2.9.2 h1:QySJaAIQgOEDQBLS3x9BxOWrnhqu5sQ+f6HaZIxD39I= +github.com/bytedance/go-tagexpr/v2 v2.9.2/go.mod h1:5qsx05dYOiUXOUgnQ7w3Oz8BYs2qtM/bJokdLb79wRM= +github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7 h1:PtwsQyQJGxf8iaPptPNaduEIu9BnrNms+pcRdHAxZaM= +github.com/bytedance/gopkg v0.0.0-20220413063733-65bf48ffb3a7/go.mod h1:2ZlV9BaUH4+NXIBF0aMdKKAnHTzqH+iMU4KUjAbL23Q= +github.com/bytedance/sonic v1.3.0 h1:T2rlvNytw6bTmczlAXvGqmuMzIqGJBOsJKYwRPWR7Y8= +github.com/bytedance/sonic v1.3.0/go.mod h1:V973WhNhGmvHxW6nQmsHEfHaoU9F3zTF+93rH03hcUQ= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06 h1:1sDoSuDPWzhkdzNVxCxtIaKiAe96ESVPv8coGwc1gZ4= +github.com/chenzhuoyu/base64x v0.0.0-20211019084208-fb5309c8db06/go.mod h1:DH46F32mSOjUmXrMHnKwZdA8wcEefY7UVqBKYGjpdQY= +github.com/cloudwego/hertz v0.3.0 h1:gyKaw/uc7I2YLuCzqWHD3G2W8MVKRZPZ5u8k8tiDBbk= +github.com/cloudwego/hertz v0.3.0/go.mod h1:GWWYlAVkq1gDu6vJd/XNciWsP6q0d4TrEKk5fpJYF04= +github.com/cloudwego/netpoll v0.2.4 h1:Kbo2HA1cXEgoy/bu1jSNrjcqZj2diENcJqLy6vKiROU= +github.com/cloudwego/netpoll v0.2.4/go.mod h1:1T2WVuQ+MQw6h6DpE45MohSvDTKdy2DlzCx2KsnPI4E= +github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= +github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/goccy/go-json v0.9.4 h1:L8MLKG2mvVXiQu07qB6hmfqeSYQdOnqPot2GhsIwIaI= +github.com/goccy/go-json v0.9.4/go.mod h1:6MelG93GURQebXPDq3khkgXZkazVtN9CRI+MGFi0w8I= +github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U= +github.com/golang/protobuf v1.5.0 h1:LUVKkCeviFUMKqHa4tXIIij/lbhnMbP7Fn5wKdKkRh4= +github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk= +github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU= +github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE= +github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg= +github.com/google/uuid v1.3.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= +github.com/henrylee2cn/ameda v1.4.8/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= +github.com/henrylee2cn/ameda v1.4.10 h1:JdvI2Ekq7tapdPsuhrc4CaFiqw6QXFvZIULWJgQyCAk= +github.com/henrylee2cn/ameda v1.4.10/go.mod h1:liZulR8DgHxdK+MEwvZIylGnmcjzQ6N6f2PlWe7nEO4= +github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8 h1:yE9ULgp02BhYIrO6sdV/FPe0xQM6fNHkVQW2IAymfM0= +github.com/henrylee2cn/goutil v0.0.0-20210127050712-89660552f6f8/go.mod h1:Nhe/DM3671a5udlv2AdV2ni/MZzgfv2qrPL5nIi3EGQ= +github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnrnM= +github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo= +github.com/klauspost/cpuid/v2 v2.0.9 h1:lgaqFMSdTdQYdZ04uHyN2d/eKdOMyi2YLSvlQIBFYa4= +github.com/klauspost/cpuid/v2 v2.0.9/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421 h1:ZqeYNhU3OHLH3mGKHDcjJRFFRrJa6eAM5H+CtDdOsPc= +github.com/modern-go/concurrent v0.0.0-20180228061459-e0a39a4cb421/go.mod h1:6dJC0mAP4ikYIbvyc7fijjWJddQyLn8Ig3JB5CqoB9Q= +github.com/modern-go/reflect2 v1.0.2 h1:xBagoLtFs94CBntxluKeaWgTMpvLxC4ur3nMaC9Gz0M= +github.com/modern-go/reflect2 v1.0.2/go.mod h1:yWuevngMOJpCy52FWWMvUC8ws7m/LJsjYzDa0/r8luk= +github.com/nyaruka/phonenumbers v1.0.55 h1:bj0nTO88Y68KeUQ/n3Lo2KgK7lM1hF7L9NFuwcCl3yg= +github.com/nyaruka/phonenumbers v1.0.55/go.mod h1:sDaTZ/KPX5f8qyV9qN+hIm+4ZBARJrupC6LuhshJq1U= +github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= +github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= +github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d h1:Q+gqLBOPkFGHyCJxXMRqtUgUbTjI8/Ze8vu8GGyNFwo= +github.com/savsgio/gotils v0.0.0-20220530130905-52f3993e8d6d/go.mod h1:Gy+0tqhJvgGlqnTF8CVGP0AaGRjwBtXs/a5PA0Y3+A4= +github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= +github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI= +github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4= +github.com/stretchr/testify v1.5.1/go.mod h1:5W2xD1RspED5o8YsWQXVCued0rvSQ+mT+I5cxcmMvtA= +github.com/stretchr/testify v1.7.0 h1:nwc3DEeHmmLAfoZucVR881uASk0Mfjw8xYJ99tb5CcY= +github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= +github.com/tidwall/gjson v1.9.3/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.12.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/gjson v1.13.0 h1:3TFY9yxOQShrvmjdM76K+jc66zJeT6D3/VFFYCGQf7M= +github.com/tidwall/gjson v1.13.0/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= +github.com/tidwall/match v1.1.1 h1:+Ho715JplO36QYgwN9PGYNhgZvoUSc9X2c80KVTi+GA= +github.com/tidwall/match v1.1.1/go.mod h1:eRSPERbgtNPcGhD8UCthc6PmLEQXEWd3PRB5JTxsfmM= +github.com/tidwall/pretty v1.2.0 h1:RWIZEg2iJ8/g6fDDYzMpobmaoGh5OLl4AXtGUGPcqCs= +github.com/tidwall/pretty v1.2.0/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= +github.com/tidwall/sjson v1.2.4 h1:cuiLzLnaMeBhRmEv00Lpk3tkYrcxpmbU81tAY4Dw0tc= +github.com/tidwall/sjson v1.2.4/go.mod h1:098SZ494YoMWPmMO6ct4dcFnqxwj9r/gF0Etp19pSNM= +github.com/twitchyliquid64/golang-asm v0.15.1 h1:SU5vSMR7hnwNxj24w34ZyCi/FmDZTkS4MhqMhdFk5YI= +github.com/twitchyliquid64/golang-asm v0.15.1/go.mod h1:a1lVb/DtPvCB8fslRZhAngC2+aY1QWCk3Cedj/Gdt08= +github.com/valyala/bytebufferpool v1.0.0/go.mod h1:6bBcMArwyJ5K/AmCkWv1jt77kVWyCJ6HpOuEn7z0Csc= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670 h1:18EFjUmQOcUvxNYSkA6jO9VAiXCnxFY6NyDX0bHDmkU= +golang.org/x/arch v0.0.0-20210923205945-b76863e36670/go.mod h1:5om86z9Hs0C8fWVUuoMHwpExlXzs5Tkyp9hOrfG7pp8= +golang.org/x/sync v0.0.0-20210220032951-036812b2e83c/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= +golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe h1:W8vbETX/n8S6EmY0Pu4Ix7VvpsJUESTwl0oCK8MJOgk= +golang.org/x/sys v0.0.0-20220110181412-a018aaa089fe/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4= +golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= +google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw= +google.golang.org/protobuf v1.27.1 h1:SnqbnDw1V7RiZcXPx5MEeqPv2s79L9i7BJUlG/+RurQ= +google.golang.org/protobuf v1.27.1/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc= +gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= +gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo= +gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +rsc.io/pdf v0.1.1/go.mod h1:n8OzWcQ6Sp37PL01nO98y4iUCRdTGarVfzxY20ICaU4= diff --git a/licenses/LICENSE-gotils b/licenses/LICENSE-gotils new file mode 100644 index 0000000..8edd0bc --- /dev/null +++ b/licenses/LICENSE-gotils @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2020-present Sergio Andres Virviescas Santana + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. \ No newline at end of file diff --git a/licenses/LICENSE-websocket b/licenses/LICENSE-websocket new file mode 100644 index 0000000..7590376 --- /dev/null +++ b/licenses/LICENSE-websocket @@ -0,0 +1,22 @@ +Copyright (c) 2013 The Gorilla WebSocket Authors. All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + + Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. \ No newline at end of file diff --git a/mask.go b/mask.go new file mode 100644 index 0000000..d245e8f --- /dev/null +++ b/mask.go @@ -0,0 +1,58 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +//go:build !appengine +// +build !appengine + +package websocket + +import "unsafe" + +const wordSize = int(unsafe.Sizeof(uintptr(0))) + +func maskBytes(key [4]byte, pos int, b []byte) int { + // Mask one byte at a time for small buffers. + if len(b) < 2*wordSize { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 + } + + // Mask one byte at a time to word boundary. + if n := int(uintptr(unsafe.Pointer(&b[0]))) % wordSize; n != 0 { + n = wordSize - n + for i := range b[:n] { + b[i] ^= key[pos&3] + pos++ + } + b = b[n:] + } + + // Create aligned word size key. + var k [wordSize]byte + for i := range k { + k[i] = key[(pos+i)&3] + } + kw := *(*uintptr)(unsafe.Pointer(&k)) + + // Mask one word at a time. + n := (len(b) / wordSize) * wordSize + for i := 0; i < n; i += wordSize { + *(*uintptr)(unsafe.Pointer(uintptr(unsafe.Pointer(&b[0])) + uintptr(i))) ^= kw + } + + // Mask one byte at a time for remaining bytes. + b = b[n:] + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + + return pos & 3 +} diff --git a/mask_safe.go b/mask_safe.go new file mode 100644 index 0000000..9062b0d --- /dev/null +++ b/mask_safe.go @@ -0,0 +1,19 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +//go:build appengine +// +build appengine + +package websocket + +func maskBytes(key [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 +} diff --git a/mask_test.go b/mask_test.go new file mode 100644 index 0000000..414e714 --- /dev/null +++ b/mask_test.go @@ -0,0 +1,75 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +// !appengine + +package websocket + +import ( + "fmt" + "testing" +) + +func maskBytesByByte(key [4]byte, pos int, b []byte) int { + for i := range b { + b[i] ^= key[pos&3] + pos++ + } + return pos & 3 +} + +func notzero(b []byte) int { + for i := range b { + if b[i] != 0 { + return i + } + } + return -1 +} + +func TestMaskBytes(t *testing.T) { + key := [4]byte{1, 2, 3, 4} + for size := 1; size <= 1024; size++ { + for align := 0; align < wordSize; align++ { + for pos := 0; pos < 4; pos++ { + b := make([]byte, size+align)[align:] + maskBytes(key, pos, b) + maskBytesByByte(key, pos, b) + if i := notzero(b); i >= 0 { + t.Errorf("size:%d, align:%d, pos:%d, offset:%d", size, align, pos, i) + } + } + } + } +} + +func BenchmarkMaskBytes(b *testing.B) { + for _, size := range []int{2, 4, 8, 16, 32, 512, 1024} { + b.Run(fmt.Sprintf("size-%d", size), func(b *testing.B) { + for _, align := range []int{wordSize / 2} { + b.Run(fmt.Sprintf("align-%d", align), func(b *testing.B) { + for _, fn := range []struct { + name string + fn func(key [4]byte, pos int, b []byte) int + }{ + {"byte", maskBytesByByte}, + {"word", maskBytes}, + } { + b.Run(fn.name, func(b *testing.B) { + key := newMaskKey() + data := make([]byte, size+align)[align:] + for i := 0; i < b.N; i++ { + fn.fn(key, 0, data) + } + b.SetBytes(int64(len(data))) + }) + } + }) + } + }) + } +} diff --git a/prepared.go b/prepared.go new file mode 100644 index 0000000..8979091 --- /dev/null +++ b/prepared.go @@ -0,0 +1,105 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package websocket + +import ( + "bytes" + "net" + "sync" + "time" +) + +// PreparedMessage caches on the wire representations of a message payload. +// Use PreparedMessage to efficiently send a message payload to multiple +// connections. PreparedMessage is especially useful when compression is used +// because the CPU and memory expensive compression operation can be executed +// once for a given set of compression options. +type PreparedMessage struct { + messageType int + data []byte + mu sync.Mutex + frames map[prepareKey]*preparedFrame +} + +// prepareKey defines a unique set of options to cache prepared frames in PreparedMessage. +type prepareKey struct { + isServer bool + compress bool + compressionLevel int +} + +// preparedFrame contains data in wire representation. +type preparedFrame struct { + once sync.Once + data []byte +} + +// NewPreparedMessage returns an initialized PreparedMessage. You can then send +// it to connection using WritePreparedMessage method. Valid wire +// representation will be calculated lazily only once for a set of current +// connection options. +func NewPreparedMessage(messageType int, data []byte) (*PreparedMessage, error) { + pm := &PreparedMessage{ + messageType: messageType, + frames: make(map[prepareKey]*preparedFrame), + data: data, + } + + // Prepare a plain server frame. + _, frameData, err := pm.frame(prepareKey{isServer: true, compress: false}) + if err != nil { + return nil, err + } + + // To protect against caller modifying the data argument, remember the data + // copied to the plain server frame. + pm.data = frameData[len(frameData)-len(data):] + return pm, nil +} + +func (pm *PreparedMessage) frame(key prepareKey) (int, []byte, error) { + pm.mu.Lock() + frame, ok := pm.frames[key] + if !ok { + frame = &preparedFrame{} + pm.frames[key] = frame + } + pm.mu.Unlock() + + var err error + frame.once.Do(func() { + // Prepare a frame using a 'fake' connection. + // TODO: Refactor code in conn.go to allow more direct construction of + // the frame. + mu := make(chan struct{}, 1) + mu <- struct{}{} + var nc prepareConn + c := &Conn{ + conn: &nc, + mu: mu, + isServer: key.isServer, + compressionLevel: key.compressionLevel, + enableWriteCompression: true, + writeBuf: make([]byte, defaultWriteBufferSize+maxFrameHeaderSize), + } + if key.compress { + c.newCompressionWriter = compressNoContextTakeover + } + err = c.WriteMessage(pm.messageType, pm.data) + frame.data = nc.buf.Bytes() + }) + return pm.messageType, frame.data, err +} + +type prepareConn struct { + buf bytes.Buffer + net.Conn +} + +func (pc *prepareConn) Write(p []byte) (int, error) { return pc.buf.Write(p) } +func (pc *prepareConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/prepared_test.go b/prepared_test.go new file mode 100644 index 0000000..9f69f88 --- /dev/null +++ b/prepared_test.go @@ -0,0 +1,77 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package websocket + +import ( + "bytes" + "compress/flate" + "math/rand" + "testing" +) + +var preparedMessageTests = []struct { + messageType int + isServer bool + enableWriteCompression bool + compressionLevel int +}{ + // Server + {TextMessage, true, false, flate.BestSpeed}, + {TextMessage, true, true, flate.BestSpeed}, + {TextMessage, true, true, flate.BestCompression}, + {PingMessage, true, false, flate.BestSpeed}, + {PingMessage, true, true, flate.BestSpeed}, + + // Client + {TextMessage, false, false, flate.BestSpeed}, + {TextMessage, false, true, flate.BestSpeed}, + {TextMessage, false, true, flate.BestCompression}, + {PingMessage, false, false, flate.BestSpeed}, + {PingMessage, false, true, flate.BestSpeed}, +} + +func TestPreparedMessage(t *testing.T) { + for _, tt := range preparedMessageTests { + data := []byte("this is a test") + var buf bytes.Buffer + c := newTestConn(nil, &buf, tt.isServer) + if tt.enableWriteCompression { + c.newCompressionWriter = compressNoContextTakeover + } + c.SetCompressionLevel(tt.compressionLevel) + + // Seed random number generator for consistent frame mask. + rand.Seed(1234) + + if err := c.WriteMessage(tt.messageType, data); err != nil { + t.Fatal(err) + } + want := buf.String() + + pm, err := NewPreparedMessage(tt.messageType, data) + if err != nil { + t.Fatal(err) + } + + // Scribble on data to ensure that NewPreparedMessage takes a snapshot. + copy(data, "hello world") + + // Seed random number generator for consistent frame mask. + rand.Seed(1234) + + buf.Reset() + if err := c.WritePreparedMessage(pm); err != nil { + t.Fatal(err) + } + got := buf.String() + + if got != want { + t.Errorf("write message != prepared message for %+v", tt) + } + } +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..75c643d --- /dev/null +++ b/server.go @@ -0,0 +1,232 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package websocket + +import ( + "bytes" + "fmt" + "net/url" + "sync" + "time" + + "github.com/cloudwego/hertz/pkg/app" + "github.com/cloudwego/hertz/pkg/network" + "github.com/cloudwego/hertz/pkg/protocol/consts" + "github.com/savsgio/gotils/strconv" +) + +const badHandshake = "websocket: the client is not using the websocket protocol: " + +var strPermessageDeflate = []byte("permessage-deflate") + +// HandshakeError describes an error with the handshake from the peer. +type HandshakeError struct { + message string +} + +func (e HandshakeError) Error() string { return e.message } + +var poolWriteBuffer = sync.Pool{ + New: func() interface{} { + var buf []byte + return buf + }, +} + +// HertzHandler receives a websocket connection after the handshake has been +// completed. This must be provided. +type HertzHandler func(*Conn) + +// HertzUpgrader specifies parameters for upgrading an HTTP connection to a +// WebSocket connection. +type HertzUpgrader struct { + // HandshakeTimeout specifies the duration for the handshake to complete. + HandshakeTimeout time.Duration + + // ReadBufferSize and WriteBufferSize specify I/O buffer sizes in bytes. If a buffer + // size is zero, then buffers allocated by the HTTP server are used. The + // I/O buffer sizes do not limit the size of the messages that can be sent + // or received. + ReadBufferSize, WriteBufferSize int + + // WriteBufferPool is a pool of buffers for write operations. If the value + // is not set, then write buffers are allocated to the connection for the + // lifetime of the connection. + // + // A pool is most useful when the application has a modest volume of writes + // across a large number of connections. + // + // Applications should use a single pool for each unique value of + // WriteBufferSize. + WriteBufferPool BufferPool + + // Subprotocols specifies the server's supported protocols in order of + // preference. If this field is not nil, then the Upgrade method negotiates a + // subprotocol by selecting the first match in this list with a protocol + // requested by the client. If there's no match, then no protocol is + // negotiated (the Sec-Websocket-Protocol header is not included in the + // handshake response). + Subprotocols []string + + // Error specifies the function for generating HTTP error responses. If Error + // is nil, then http.Error is used to generate the HTTP response. + Error func(ctx *app.RequestContext, status int, reason error) + + // CheckOrigin returns true if the request Origin header is acceptable. If + // CheckOrigin is nil, then a safe default is used: return false if the + // Origin request header is present and the origin host is not equal to + // request Host header. + // + // A CheckOrigin function should carefully validate the request origin to + // prevent cross-site request forgery. + CheckOrigin func(ctx *app.RequestContext) bool + + // EnableCompression specify if the server should attempt to negotiate per + // message compression (RFC 7692). Setting this value to true does not + // guarantee that compression will be supported. Currently only "no context + // takeover" modes are supported. + EnableCompression bool +} + +func (u *HertzUpgrader) returnError(ctx *app.RequestContext, status int, reason string) error { + err := HandshakeError{reason} + if u.Error != nil { + u.Error(ctx, status, err) + } else { + ctx.Response.Header.Set("Sec-Websocket-Version", "13") + ctx.AbortWithMsg(consts.StatusMessage(status), status) + } + + return err +} + +func (u *HertzUpgrader) selectSubprotocol(ctx *app.RequestContext) []byte { + if u.Subprotocols != nil { + clientProtocols := parseDataHeader(ctx.Request.Header.Peek("Sec-Websocket-Protocol")) + + for _, serverProtocol := range u.Subprotocols { + for _, clientProtocol := range clientProtocols { + if strconv.B2S(clientProtocol) == serverProtocol { + return clientProtocol + } + } + } + } else if ctx.Response.Header.Len() > 0 { + return ctx.Response.Header.Peek("Sec-Websocket-Protocol") + } + + return nil +} + +func (u *HertzUpgrader) isCompressionEnable(ctx *app.RequestContext) bool { + extensions := parseDataHeader(ctx.Request.Header.Peek("Sec-WebSocket-Extensions")) + + // Negotiate PMCE + if u.EnableCompression { + for _, ext := range extensions { + if bytes.HasPrefix(ext, strPermessageDeflate) { + return true + } + } + } + + return false +} + +// Upgrade upgrades the HTTP server connection to the WebSocket protocol. +// +// The responseHeader is included in the response to the client's upgrade +// request. Use the responseHeader to specify cookies (Set-Cookie) and the +// application negotiated subprotocol (Sec-WebSocket-Protocol). +// +// If the upgrade fails, then Upgrade replies to the client with an HTTP error +// response. +func (u *HertzUpgrader) Upgrade(ctx *app.RequestContext, handler HertzHandler) error { + if !ctx.IsGet() { + return u.returnError(ctx, consts.StatusMethodNotAllowed, fmt.Sprintf("%s request method is not GET", badHandshake)) + } + + if !tokenContainsValue(strconv.B2S(ctx.Request.Header.Peek("Connection")), "Upgrade") { + return u.returnError(ctx, consts.StatusBadRequest, fmt.Sprintf("%s 'upgrade' token not found in 'Connection' header", badHandshake)) + } + + if !tokenContainsValue(strconv.B2S(ctx.Request.Header.Peek("Upgrade")), "Websocket") { + return u.returnError(ctx, consts.StatusBadRequest, fmt.Sprintf("%s 'websocket' token not found in 'Upgrade' header", badHandshake)) + } + + if !tokenContainsValue(strconv.B2S(ctx.Request.Header.Peek("Sec-Websocket-Version")), "13") { + return u.returnError(ctx, consts.StatusBadRequest, "websocket: unsupported version: 13 not found in 'Sec-Websocket-Version' header") + } + + if len(ctx.Response.Header.Peek("Sec-Websocket-Extensions")) > 0 { + return u.returnError(ctx, consts.StatusInternalServerError, "websocket: application specific 'Sec-WebSocket-Extensions' headers are unsupported") + } + + checkOrigin := u.CheckOrigin + if checkOrigin == nil { + checkOrigin = fastHTTPCheckSameOrigin + } + if !checkOrigin(ctx) { + return u.returnError(ctx, consts.StatusForbidden, "websocket: request origin not allowed by HertzUpgrader.CheckOrigin") + } + + challengeKey := ctx.Request.Header.Peek("Sec-Websocket-Key") + if len(challengeKey) == 0 { + return u.returnError(ctx, consts.StatusBadRequest, "websocket: not a websocket handshake: `Sec-WebSocket-Key' header is missing or blank") + } + + subprotocol := u.selectSubprotocol(ctx) + compress := u.isCompressionEnable(ctx) + + ctx.SetStatusCode(consts.StatusSwitchingProtocols) + ctx.Response.Header.Set("Upgrade", "websocket") + ctx.Response.Header.Set("Connection", "Upgrade") + ctx.Response.Header.Set("Sec-WebSocket-Accept", computeAcceptKeyBytes(challengeKey)) + if compress { + ctx.Response.Header.Set("Sec-WebSocket-Extensions", "permessage-deflate; server_no_context_takeover; client_no_context_takeover") + } + if subprotocol != nil { + ctx.Response.Header.SetBytesV("Sec-WebSocket-Protocol", subprotocol) + } + + ctx.Hijack(func(netConn network.Conn) { + writeBuf := poolWriteBuffer.Get().([]byte) + c := newConn(netConn, true, u.ReadBufferSize, u.WriteBufferSize, u.WriteBufferPool, nil, writeBuf) + if subprotocol != nil { + c.subprotocol = strconv.B2S(subprotocol) + } + + if compress { + c.newCompressionWriter = compressNoContextTakeover + c.newDecompressionReader = decompressNoContextTakeover + } + + // Clear deadlines set by HTTP server. + netConn.SetDeadline(time.Time{}) + + handler(c) + + writeBuf = writeBuf[0:0] + poolWriteBuffer.Put(writeBuf) + }) + + return nil +} + +// fastHTTPCheckSameOrigin returns true if the origin is not set or is equal to the request host. +func fastHTTPCheckSameOrigin(ctx *app.RequestContext) bool { + origin := ctx.Request.Header.Peek("Origin") + if len(origin) == 0 { + return true + } + u, err := url.Parse(strconv.B2S(origin)) + if err != nil { + return false + } + return equalASCIIFold(u.Host, strconv.B2S(ctx.Host())) +} diff --git a/util.go b/util.go new file mode 100644 index 0000000..7f21c28 --- /dev/null +++ b/util.go @@ -0,0 +1,191 @@ +// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. +// +// This file may have been modified by CloudWeGo authors. All CloudWeGo +// Modifications are Copyright 2022 CloudWeGo Authors. + +package websocket + +import ( + "bytes" + "crypto/sha1" + "encoding/base64" + "unicode/utf8" +) + +var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") + +// Token octets per RFC 2616. +var isTokenOctet = [256]bool{ + '!': true, + '#': true, + '$': true, + '%': true, + '&': true, + '\'': true, + '*': true, + '+': true, + '-': true, + '.': true, + '0': true, + '1': true, + '2': true, + '3': true, + '4': true, + '5': true, + '6': true, + '7': true, + '8': true, + '9': true, + 'A': true, + 'B': true, + 'C': true, + 'D': true, + 'E': true, + 'F': true, + 'G': true, + 'H': true, + 'I': true, + 'J': true, + 'K': true, + 'L': true, + 'M': true, + 'N': true, + 'O': true, + 'P': true, + 'Q': true, + 'R': true, + 'S': true, + 'T': true, + 'U': true, + 'W': true, + 'V': true, + 'X': true, + 'Y': true, + 'Z': true, + '^': true, + '_': true, + '`': true, + 'a': true, + 'b': true, + 'c': true, + 'd': true, + 'e': true, + 'f': true, + 'g': true, + 'h': true, + 'i': true, + 'j': true, + 'k': true, + 'l': true, + 'm': true, + 'n': true, + 'o': true, + 'p': true, + 'q': true, + 'r': true, + 's': true, + 't': true, + 'u': true, + 'v': true, + 'w': true, + 'x': true, + 'y': true, + 'z': true, + '|': true, + '~': true, +} + +// skipSpace returns a slice of the string s with all leading RFC 2616 linear +// whitespace removed. +func skipSpace(s string) (rest string) { + i := 0 + for ; i < len(s); i++ { + if b := s[i]; b != ' ' && b != '\t' { + break + } + } + return s[i:] +} + +// nextToken returns the leading RFC 2616 token of s and the string following +// the token. +func nextToken(s string) (token, rest string) { + i := 0 + for ; i < len(s); i++ { + if !isTokenOctet[s[i]] { + break + } + } + return s[:i], s[i:] +} + +// equalASCIIFold returns true if s is equal to t with ASCII case folding as +// defined in RFC 4790. +func equalASCIIFold(s, t string) bool { + for s != "" && t != "" { + sr, size := utf8.DecodeRuneInString(s) + s = s[size:] + tr, size := utf8.DecodeRuneInString(t) + t = t[size:] + if sr == tr { + continue + } + if 'A' <= sr && sr <= 'Z' { + sr = sr + 'a' - 'A' + } + if 'A' <= tr && tr <= 'Z' { + tr = tr + 'a' - 'A' + } + if sr != tr { + return false + } + } + return s == t +} + +// parseDataHeader returns a list with values if header value is comma-separated +func parseDataHeader(headerValue []byte) [][]byte { + h := bytes.TrimSpace(headerValue) + if bytes.Equal(h, []byte("")) { + return nil + } + + values := bytes.Split(h, []byte(",")) + for i := range values { + values[i] = bytes.TrimSpace(values[i]) + } + return values +} + +// tokenContainsValue returns true if the 1#token header with the given +// name contains a token equal to value with ASCII case folding. +func tokenContainsValue(s, value string) bool { + for { + var t string + t, s = nextToken(skipSpace(s)) + if t == "" { + return false + } + s = skipSpace(s) + if s != "" && s[0] != ',' { + return false + } + if equalASCIIFold(t, value) { + return true + } + if s == "" { + return false + } + + s = s[1:] + } +} + +func computeAcceptKeyBytes(challengeKey []byte) string { + h := sha1.New() + h.Write(challengeKey) + h.Write(keyGUID) + return base64.StdEncoding.EncodeToString(h.Sum(nil)) +}