-
Notifications
You must be signed in to change notification settings - Fork 14
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: websocket package for hertz (#1)
* adapted to Hertz * Upgrader: err info * remove client side code * feat: websocket package for hertz * ci: modify license header * style(.licenserc.yaml):remove unused char Co-authored-by: kinggo <[email protected]>
- Loading branch information
1 parent
985b4b1
commit 0453d14
Showing
38 changed files
with
5,049 additions
and
4 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1 +1,17 @@ | ||
.idea/ | ||
# Binaries for programs and plugins | ||
*.exe | ||
*.exe~ | ||
*.dll | ||
*.so | ||
*.dylib | ||
|
||
# Test binary, build with `go test -c` | ||
*.test | ||
|
||
# Output of the go coverage tool, specifically when used with LiteIDE | ||
*.out | ||
|
||
# goland | ||
.idea | ||
# vscode | ||
.vscode |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
header: | ||
license: | ||
spdx-id: Apache-2.0 | ||
copyright-owner: CloudWeGo Authors | ||
content: | | ||
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
// | ||
// This file may have been modified by CloudWeGo authors. All CloudWeGo | ||
// Modifications are Copyright 2022 CloudWeGo Authors. | ||
paths: | ||
- '**/*.go' | ||
- '**/*.s' | ||
|
||
comment: on-failure |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
# This is the official list of Gorilla WebSocket authors for copyright | ||
# purposes. | ||
# | ||
# Please keep the list sorted. | ||
|
||
Gary Burd <[email protected]> | ||
Google LLC (https://opensource.google.com/) | ||
Joachim Bauch <[email protected]> | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# Hertz-WebSocket(This is a community driven project) | ||
|
||
|
||
This repo is forked from [Gorilla WebSocket](https://github.com/gorilla/websocket/) and adapted to Hertz. | ||
|
||
### How to use | ||
```go | ||
package main | ||
|
||
import ( | ||
"context" | ||
"log" | ||
|
||
"github.com/cloudwego/hertz/pkg/app" | ||
"github.com/cloudwego/hertz/pkg/app/server" | ||
"github.com/hertz-contrib/websocket" | ||
) | ||
|
||
var upgrader = websocket.HertzUpgrader{} // use default options | ||
|
||
func echo(_ context.Context, c *app.RequestContext) { | ||
err := upgrader.Upgrade(c, func(conn *websocket.Conn) { | ||
for { | ||
mt, message, err := conn.ReadMessage() | ||
if err != nil { | ||
log.Println("read:", err) | ||
break | ||
} | ||
log.Printf("recv: %s", message) | ||
err = conn.WriteMessage(mt, message) | ||
if err != nil { | ||
log.Println("write:", err) | ||
break | ||
} | ||
} | ||
}) | ||
if err != nil { | ||
log.Print("upgrade:", err) | ||
return | ||
} | ||
} | ||
|
||
|
||
func main() { | ||
h := server.Default(server.WithHostPorts(addr)) | ||
// https://github.com/cloudwego/hertz/issues/121 | ||
h.NoHijackConnPool = true | ||
h.GET("/echo", echo) | ||
h.Spin() | ||
} | ||
|
||
``` | ||
|
||
### More info | ||
|
||
See [examples](examples/) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
# Typo check: https://github.com/crate-ci/typos | ||
|
||
[files] | ||
extend-exclude = ["go.sum"] | ||
|
||
[default.extend-identifiers] | ||
# *sigh* this just isn't worth the cost of fixing | ||
ConnTLSer = "ConnTLSer" | ||
flate = "flate" | ||
TestCompressFlateSerial = "TestCompressFlateSerial" | ||
testCompressFlate = "testCompressFlate" | ||
TestCompressFlateConcurrent = "TestCompressFlateConcurrent" | ||
trUe = "trUe" | ||
OPTIO = "OPTIO" | ||
contant = "contant" | ||
referer = "referer" | ||
HeaderReferer = "HeaderReferer" | ||
expectedReferer = "expectedReferer" | ||
Referer = "Referer" | ||
flateWriterPools = "flateWriterPools" | ||
flateReaderPool = "flateReaderPool" | ||
flateWriteWrapper = "flateWriteWrapper" | ||
flateReadWrapper = "flateReadWrapper" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,151 @@ | ||
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
// | ||
// This file may have been modified by CloudWeGo authors. All CloudWeGo | ||
// Modifications are Copyright 2022 CloudWeGo Authors. | ||
|
||
package websocket | ||
|
||
import ( | ||
"compress/flate" | ||
"errors" | ||
"io" | ||
"strings" | ||
"sync" | ||
) | ||
|
||
const ( | ||
minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6 | ||
maxCompressionLevel = flate.BestCompression | ||
defaultCompressionLevel = 1 | ||
) | ||
|
||
var ( | ||
flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool | ||
flateReaderPool = sync.Pool{New: func() interface{} { | ||
return flate.NewReader(nil) | ||
}} | ||
) | ||
|
||
func decompressNoContextTakeover(r io.Reader) io.ReadCloser { | ||
const tail = | ||
// Add four bytes as specified in RFC | ||
"\x00\x00\xff\xff" + | ||
// Add final block to squelch unexpected EOF error from flate reader. | ||
"\x01\x00\x00\xff\xff" | ||
|
||
fr, _ := flateReaderPool.Get().(io.ReadCloser) | ||
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil) | ||
return &flateReadWrapper{fr} | ||
} | ||
|
||
func isValidCompressionLevel(level int) bool { | ||
return minCompressionLevel <= level && level <= maxCompressionLevel | ||
} | ||
|
||
func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser { | ||
p := &flateWriterPools[level-minCompressionLevel] | ||
tw := &truncWriter{w: w} | ||
fw, _ := p.Get().(*flate.Writer) | ||
if fw == nil { | ||
fw, _ = flate.NewWriter(tw, level) | ||
} else { | ||
fw.Reset(tw) | ||
} | ||
return &flateWriteWrapper{fw: fw, tw: tw, p: p} | ||
} | ||
|
||
// truncWriter is an io.Writer that writes all but the last four bytes of the | ||
// stream to another io.Writer. | ||
type truncWriter struct { | ||
w io.WriteCloser | ||
n int | ||
p [4]byte | ||
} | ||
|
||
func (w *truncWriter) Write(p []byte) (int, error) { | ||
n := 0 | ||
|
||
// fill buffer first for simplicity. | ||
if w.n < len(w.p) { | ||
n = copy(w.p[w.n:], p) | ||
p = p[n:] | ||
w.n += n | ||
if len(p) == 0 { | ||
return n, nil | ||
} | ||
} | ||
|
||
m := len(p) | ||
if m > len(w.p) { | ||
m = len(w.p) | ||
} | ||
|
||
if nn, err := w.w.Write(w.p[:m]); err != nil { | ||
return n + nn, err | ||
} | ||
|
||
copy(w.p[:], w.p[m:]) | ||
copy(w.p[len(w.p)-m:], p[len(p)-m:]) | ||
nn, err := w.w.Write(p[:len(p)-m]) | ||
return n + nn, err | ||
} | ||
|
||
type flateWriteWrapper struct { | ||
fw *flate.Writer | ||
tw *truncWriter | ||
p *sync.Pool | ||
} | ||
|
||
func (w *flateWriteWrapper) Write(p []byte) (int, error) { | ||
if w.fw == nil { | ||
return 0, errWriteClosed | ||
} | ||
return w.fw.Write(p) | ||
} | ||
|
||
func (w *flateWriteWrapper) Close() error { | ||
if w.fw == nil { | ||
return errWriteClosed | ||
} | ||
err1 := w.fw.Flush() | ||
w.p.Put(w.fw) | ||
w.fw = nil | ||
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} { | ||
return errors.New("websocket: internal error, unexpected bytes at end of flate stream") | ||
} | ||
err2 := w.tw.w.Close() | ||
if err1 != nil { | ||
return err1 | ||
} | ||
return err2 | ||
} | ||
|
||
type flateReadWrapper struct { | ||
fr io.ReadCloser | ||
} | ||
|
||
func (r *flateReadWrapper) Read(p []byte) (int, error) { | ||
if r.fr == nil { | ||
return 0, io.ErrClosedPipe | ||
} | ||
n, err := r.fr.Read(p) | ||
if err == io.EOF { | ||
// Preemptively place the reader back in the pool. This helps with | ||
// scenarios where the application does not call NextReader() soon after | ||
// this final read. | ||
r.Close() | ||
} | ||
return n, err | ||
} | ||
|
||
func (r *flateReadWrapper) Close() error { | ||
if r.fr == nil { | ||
return io.ErrClosedPipe | ||
} | ||
err := r.fr.Close() | ||
flateReaderPool.Put(r.fr) | ||
r.fr = nil | ||
return err | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,87 @@ | ||
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved. | ||
// Use of this source code is governed by a BSD-style | ||
// license that can be found in the LICENSE file. | ||
// | ||
// This file may have been modified by CloudWeGo authors. All CloudWeGo | ||
// Modifications are Copyright 2022 CloudWeGo Authors. | ||
|
||
package websocket | ||
|
||
import ( | ||
"bytes" | ||
"fmt" | ||
"io" | ||
"io/ioutil" | ||
"testing" | ||
) | ||
|
||
type nopCloser struct{ io.Writer } | ||
|
||
func (nopCloser) Close() error { return nil } | ||
|
||
func TestTruncWriter(t *testing.T) { | ||
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321" | ||
for n := 1; n <= 10; n++ { | ||
var b bytes.Buffer | ||
w := &truncWriter{w: nopCloser{&b}} | ||
p := []byte(data) | ||
for len(p) > 0 { | ||
m := len(p) | ||
if m > n { | ||
m = n | ||
} | ||
w.Write(p[:m]) | ||
p = p[m:] | ||
} | ||
if b.String() != data[:len(data)-len(w.p)] { | ||
t.Errorf("%d: %q", n, b.String()) | ||
} | ||
} | ||
} | ||
|
||
func textMessages(num int) [][]byte { | ||
messages := make([][]byte, num) | ||
for i := 0; i < num; i++ { | ||
msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i) | ||
messages[i] = []byte(msg) | ||
} | ||
return messages | ||
} | ||
|
||
func BenchmarkWriteNoCompression(b *testing.B) { | ||
w := ioutil.Discard | ||
c := newTestConn(nil, w, false) | ||
messages := textMessages(100) | ||
b.ResetTimer() | ||
for i := 0; i < b.N; i++ { | ||
c.WriteMessage(TextMessage, messages[i%len(messages)]) | ||
} | ||
b.ReportAllocs() | ||
} | ||
|
||
func BenchmarkWriteWithCompression(b *testing.B) { | ||
w := ioutil.Discard | ||
c := newTestConn(nil, w, false) | ||
messages := textMessages(100) | ||
c.enableWriteCompression = true | ||
c.newCompressionWriter = compressNoContextTakeover | ||
b.ResetTimer() | ||
for i := 0; i < b.N; i++ { | ||
c.WriteMessage(TextMessage, messages[i%len(messages)]) | ||
} | ||
b.ReportAllocs() | ||
} | ||
|
||
func TestValidCompressionLevel(t *testing.T) { | ||
c := newTestConn(nil, nil, false) | ||
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} { | ||
if err := c.SetCompressionLevel(level); err == nil { | ||
t.Errorf("no error for level %d", level) | ||
} | ||
} | ||
for _, level := range []int{minCompressionLevel, maxCompressionLevel} { | ||
if err := c.SetCompressionLevel(level); err != nil { | ||
t.Errorf("error for level %d", level) | ||
} | ||
} | ||
} |
Oops, something went wrong.