Skip to content

Commit

Permalink
feat: websocket package for hertz (#1)
Browse files Browse the repository at this point in the history
* adapted to Hertz

* Upgrader: err info

* remove client side code

* feat: websocket package for hertz

* ci: modify license header

* style(.licenserc.yaml):remove unused char

Co-authored-by: kinggo <[email protected]>
  • Loading branch information
baiyutang and li-jin-gou authored Sep 13, 2022
1 parent 985b4b1 commit 0453d14
Show file tree
Hide file tree
Showing 38 changed files with 5,049 additions and 4 deletions.
4 changes: 2 additions & 2 deletions .github/workflows/pr-check.yml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ jobs:
- name: Set up Go
uses: actions/setup-go@v3
with:
go-version: 1.18
go-version: 1.16

- uses: actions/cache@v3
with:
Expand All @@ -43,4 +43,4 @@ jobs:
# Exit with 1 when it find at least one finding.
fail_on_error: true
# Set staticcheck flags
staticcheck_flags: -checks=inherit,-SA1029
staticcheck_flags: -checks=inherit,-SA1029,-SA6002
18 changes: 17 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
@@ -1 +1,17 @@
.idea/
# Binaries for programs and plugins
*.exe
*.exe~
*.dll
*.so
*.dylib

# Test binary, build with `go test -c`
*.test

# Output of the go coverage tool, specifically when used with LiteIDE
*.out

# goland
.idea
# vscode
.vscode
17 changes: 17 additions & 0 deletions .licenserc.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
header:
license:
spdx-id: Apache-2.0
copyright-owner: CloudWeGo Authors
content: |
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// This file may have been modified by CloudWeGo authors. All CloudWeGo
// Modifications are Copyright 2022 CloudWeGo Authors.
paths:
- '**/*.go'
- '**/*.s'

comment: on-failure
9 changes: 9 additions & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
# This is the official list of Gorilla WebSocket authors for copyright
# purposes.
#
# Please keep the list sorted.

Gary Burd <[email protected]>
Google LLC (https://opensource.google.com/)
Joachim Bauch <[email protected]>

57 changes: 57 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# Hertz-WebSocket(This is a community driven project)


This repo is forked from [Gorilla WebSocket](https://github.com/gorilla/websocket/) and adapted to Hertz.

### How to use
```go
package main

import (
"context"
"log"

"github.com/cloudwego/hertz/pkg/app"
"github.com/cloudwego/hertz/pkg/app/server"
"github.com/hertz-contrib/websocket"
)

var upgrader = websocket.HertzUpgrader{} // use default options

func echo(_ context.Context, c *app.RequestContext) {
err := upgrader.Upgrade(c, func(conn *websocket.Conn) {
for {
mt, message, err := conn.ReadMessage()
if err != nil {
log.Println("read:", err)
break
}
log.Printf("recv: %s", message)
err = conn.WriteMessage(mt, message)
if err != nil {
log.Println("write:", err)
break
}
}
})
if err != nil {
log.Print("upgrade:", err)
return
}
}


func main() {
h := server.Default(server.WithHostPorts(addr))
// https://github.com/cloudwego/hertz/issues/121
h.NoHijackConnPool = true
h.GET("/echo", echo)
h.Spin()
}

```

### More info

See [examples](examples/)

23 changes: 23 additions & 0 deletions _typos.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
# Typo check: https://github.com/crate-ci/typos

[files]
extend-exclude = ["go.sum"]

[default.extend-identifiers]
# *sigh* this just isn't worth the cost of fixing
ConnTLSer = "ConnTLSer"
flate = "flate"
TestCompressFlateSerial = "TestCompressFlateSerial"
testCompressFlate = "testCompressFlate"
TestCompressFlateConcurrent = "TestCompressFlateConcurrent"
trUe = "trUe"
OPTIO = "OPTIO"
contant = "contant"
referer = "referer"
HeaderReferer = "HeaderReferer"
expectedReferer = "expectedReferer"
Referer = "Referer"
flateWriterPools = "flateWriterPools"
flateReaderPool = "flateReaderPool"
flateWriteWrapper = "flateWriteWrapper"
flateReadWrapper = "flateReadWrapper"
151 changes: 151 additions & 0 deletions compression.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,151 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// This file may have been modified by CloudWeGo authors. All CloudWeGo
// Modifications are Copyright 2022 CloudWeGo Authors.

package websocket

import (
"compress/flate"
"errors"
"io"
"strings"
"sync"
)

const (
minCompressionLevel = -2 // flate.HuffmanOnly not defined in Go < 1.6
maxCompressionLevel = flate.BestCompression
defaultCompressionLevel = 1
)

var (
flateWriterPools [maxCompressionLevel - minCompressionLevel + 1]sync.Pool
flateReaderPool = sync.Pool{New: func() interface{} {
return flate.NewReader(nil)
}}
)

func decompressNoContextTakeover(r io.Reader) io.ReadCloser {
const tail =
// Add four bytes as specified in RFC
"\x00\x00\xff\xff" +
// Add final block to squelch unexpected EOF error from flate reader.
"\x01\x00\x00\xff\xff"

fr, _ := flateReaderPool.Get().(io.ReadCloser)
fr.(flate.Resetter).Reset(io.MultiReader(r, strings.NewReader(tail)), nil)
return &flateReadWrapper{fr}
}

func isValidCompressionLevel(level int) bool {
return minCompressionLevel <= level && level <= maxCompressionLevel
}

func compressNoContextTakeover(w io.WriteCloser, level int) io.WriteCloser {
p := &flateWriterPools[level-minCompressionLevel]
tw := &truncWriter{w: w}
fw, _ := p.Get().(*flate.Writer)
if fw == nil {
fw, _ = flate.NewWriter(tw, level)
} else {
fw.Reset(tw)
}
return &flateWriteWrapper{fw: fw, tw: tw, p: p}
}

// truncWriter is an io.Writer that writes all but the last four bytes of the
// stream to another io.Writer.
type truncWriter struct {
w io.WriteCloser
n int
p [4]byte
}

func (w *truncWriter) Write(p []byte) (int, error) {
n := 0

// fill buffer first for simplicity.
if w.n < len(w.p) {
n = copy(w.p[w.n:], p)
p = p[n:]
w.n += n
if len(p) == 0 {
return n, nil
}
}

m := len(p)
if m > len(w.p) {
m = len(w.p)
}

if nn, err := w.w.Write(w.p[:m]); err != nil {
return n + nn, err
}

copy(w.p[:], w.p[m:])
copy(w.p[len(w.p)-m:], p[len(p)-m:])
nn, err := w.w.Write(p[:len(p)-m])
return n + nn, err
}

type flateWriteWrapper struct {
fw *flate.Writer
tw *truncWriter
p *sync.Pool
}

func (w *flateWriteWrapper) Write(p []byte) (int, error) {
if w.fw == nil {
return 0, errWriteClosed
}
return w.fw.Write(p)
}

func (w *flateWriteWrapper) Close() error {
if w.fw == nil {
return errWriteClosed
}
err1 := w.fw.Flush()
w.p.Put(w.fw)
w.fw = nil
if w.tw.p != [4]byte{0, 0, 0xff, 0xff} {
return errors.New("websocket: internal error, unexpected bytes at end of flate stream")
}
err2 := w.tw.w.Close()
if err1 != nil {
return err1
}
return err2
}

type flateReadWrapper struct {
fr io.ReadCloser
}

func (r *flateReadWrapper) Read(p []byte) (int, error) {
if r.fr == nil {
return 0, io.ErrClosedPipe
}
n, err := r.fr.Read(p)
if err == io.EOF {
// Preemptively place the reader back in the pool. This helps with
// scenarios where the application does not call NextReader() soon after
// this final read.
r.Close()
}
return n, err
}

func (r *flateReadWrapper) Close() error {
if r.fr == nil {
return io.ErrClosedPipe
}
err := r.fr.Close()
flateReaderPool.Put(r.fr)
r.fr = nil
return err
}
87 changes: 87 additions & 0 deletions compression_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,87 @@
// Copyright 2017 The Gorilla WebSocket Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//
// This file may have been modified by CloudWeGo authors. All CloudWeGo
// Modifications are Copyright 2022 CloudWeGo Authors.

package websocket

import (
"bytes"
"fmt"
"io"
"io/ioutil"
"testing"
)

type nopCloser struct{ io.Writer }

func (nopCloser) Close() error { return nil }

func TestTruncWriter(t *testing.T) {
const data = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijlkmnopqrstuvwxyz987654321"
for n := 1; n <= 10; n++ {
var b bytes.Buffer
w := &truncWriter{w: nopCloser{&b}}
p := []byte(data)
for len(p) > 0 {
m := len(p)
if m > n {
m = n
}
w.Write(p[:m])
p = p[m:]
}
if b.String() != data[:len(data)-len(w.p)] {
t.Errorf("%d: %q", n, b.String())
}
}
}

func textMessages(num int) [][]byte {
messages := make([][]byte, num)
for i := 0; i < num; i++ {
msg := fmt.Sprintf("planet: %d, country: %d, city: %d, street: %d", i, i, i, i)
messages[i] = []byte(msg)
}
return messages
}

func BenchmarkWriteNoCompression(b *testing.B) {
w := ioutil.Discard
c := newTestConn(nil, w, false)
messages := textMessages(100)
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)])
}
b.ReportAllocs()
}

func BenchmarkWriteWithCompression(b *testing.B) {
w := ioutil.Discard
c := newTestConn(nil, w, false)
messages := textMessages(100)
c.enableWriteCompression = true
c.newCompressionWriter = compressNoContextTakeover
b.ResetTimer()
for i := 0; i < b.N; i++ {
c.WriteMessage(TextMessage, messages[i%len(messages)])
}
b.ReportAllocs()
}

func TestValidCompressionLevel(t *testing.T) {
c := newTestConn(nil, nil, false)
for _, level := range []int{minCompressionLevel - 1, maxCompressionLevel + 1} {
if err := c.SetCompressionLevel(level); err == nil {
t.Errorf("no error for level %d", level)
}
}
for _, level := range []int{minCompressionLevel, maxCompressionLevel} {
if err := c.SetCompressionLevel(level); err != nil {
t.Errorf("error for level %d", level)
}
}
}
Loading

0 comments on commit 0453d14

Please sign in to comment.