Skip to content

Commit

Permalink
[fastws] fix reading very big messages
Browse files Browse the repository at this point in the history
  • Loading branch information
ehsannm committed May 24, 2024
1 parent 72e9128 commit a6162fd
Show file tree
Hide file tree
Showing 6 changed files with 96 additions and 39 deletions.
4 changes: 3 additions & 1 deletion kit/edge.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ type EdgeServer struct {

// configs
prefork bool
reusePort bool
shutdownTimeout time.Duration

// local store
Expand All @@ -61,6 +62,7 @@ func NewServer(opts ...Option) *EdgeServer {

s.l = cfg.logger
s.prefork = cfg.prefork
s.reusePort = cfg.reusePort
s.eh = cfg.errHandler
s.gh = cfg.globalHandlers
s.cd = cfg.connDelegate
Expand Down Expand Up @@ -288,7 +290,7 @@ func (s *EdgeServer) startup(ctx context.Context) {
err := s.nb[idx].gw.Start(
ctx,
GatewayStartConfig{
ReusePort: s.prefork,
ReusePort: s.prefork || s.reusePort,
},
)
if err != nil {
Expand Down
9 changes: 9 additions & 0 deletions kit/edge_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import "time"
type edgeConfig struct {
logger Logger
prefork bool
reusePort bool
shutdownTimeout time.Duration
gateways []Gateway
cluster Cluster
Expand Down Expand Up @@ -101,3 +102,11 @@ func WithConnDelegate(d ConnDelegate) Option {
s.connDelegate = d
}
}

// ReusePort asks Gateway to listen in REUSE_PORT mode.
// default this is false
func ReusePort(t bool) Option {
return func(s *edgeConfig) {
s.reusePort = t
}
}
6 changes: 6 additions & 0 deletions kit/utils/generic.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,3 +48,9 @@ func OkOr[T any](v T, err error, fallback T) T {

return v
}

func TryCast[T any](v any) T {
t, _ := v.(T)

return t
}
75 changes: 50 additions & 25 deletions std/gateways/fastws/gateway.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,6 @@ import (
"github.com/clubpay/ronykit/kit/utils"
"github.com/clubpay/ronykit/kit/utils/buf"
"github.com/gobwas/ws"
"github.com/gobwas/ws/wsutil"
"github.com/panjf2000/gnet/v2"
)

Expand Down Expand Up @@ -88,6 +87,8 @@ func (gw *gateway) OnClose(c gnet.Conn, _ error) (action gnet.Action) {
func (gw *gateway) OnTraffic(c gnet.Conn) gnet.Action {
wsc := gw.getConnWrap(c)
if wsc == nil {
gw.b.l.Debugf("did not find ws conn for connID(%d)", utils.TryCast[int64](c.Context()))

return gnet.Close
}

Expand All @@ -96,6 +97,7 @@ func (gw *gateway) OnTraffic(c gnet.Conn) gnet.Action {
_, err := sp.Upgrade(wsc.c)
if err != nil {
wsc.Close()
gw.b.l.Debugf("faild to upgrade websocket connID(%d): %v", utils.TryCast[int64](c.Context()), err)

return gnet.Close
}
Expand All @@ -111,46 +113,69 @@ func (gw *gateway) OnTraffic(c gnet.Conn) gnet.Action {
hdr ws.Header
)

hdr, err = wsc.r.NextFrame()
if err != nil {
if builtinErr.Is(err, io.EOF) {
for {
hdr, err = wsc.r.NextFrame()
if err != nil {
if builtinErr.Is(err, io.EOF) {
return gnet.None
}
gw.b.l.Debugf("failed to read next frame of connID(%d): %v", utils.TryCast[int64](c.Context()), err)

return gnet.Close
}

if hdr.OpCode.IsControl() {
if err = wsc.r.OnIntermediate(hdr, wsc.r); err != nil {
gw.b.l.Debugf(
"failed to handle control message of connID(%d), opCode(%d): %v",
utils.TryCast[int64](c.Context()), hdr.OpCode, err,
)

return gnet.Close
}

if err = wsc.r.Discard(); err != nil {
gw.b.l.Debugf(
"failed to discard on control message connID(%d): %v",
utils.TryCast[int64](c.Context()), err,
)

return gnet.Close
}

return gnet.None
}

return gnet.Close
if hdr.OpCode&(ws.OpText|ws.OpBinary) != hdr.OpCode {
if err = wsc.r.Discard(); err != nil {
return gnet.Close
}

continue
}

break
}

var p []byte
var pBuff *buf.Bytes
if hdr.Fin {
// No more frames will be read. Use fixed sized buffer to read payload.
p = make([]byte, hdr.Length)
// It is not possible to receive io.EOF here because Reader does not
// return EOF if frame payload was successfully fetched.
_, err = io.ReadFull(wsc.r, p)
pBuff = buf.GetLen(int(hdr.Length))
_, err = io.ReadFull(wsc.r, *pBuff.Bytes())
} else {
// Frame is fragmented, thus use io.ReadAll behavior.
var buff bytes.Buffer
// create a default buffer cap, since we don't know the exact size of payload
pBuff = buf.GetCap(8192)
buff := bytes.NewBuffer(*pBuff.Bytes())
_, err = buff.ReadFrom(wsc.r)
p = buff.Bytes()
pBuff.SetBytes(utils.ValPtr(buff.Bytes()))
}
if err != nil {
return gnet.Close
}

wsc.msgs = append(wsc.msgs, wsutil.Message{OpCode: hdr.OpCode, Payload: p})

for _, msg := range wsc.msgs {
if msg.OpCode&(ws.OpText|ws.OpBinary) == 0 {
continue
}

payloadBuffer := buf.GetLen(len(msg.Payload))
payloadBuffer.CopyFrom(msg.Payload)

go gw.reactFunc(wsc, payloadBuffer, len(msg.Payload))
}

wsc.msgs = wsc.msgs[:0]
go gw.reactFunc(wsc, pBuff, pBuff.Len())

return gnet.None
}
Expand Down
33 changes: 21 additions & 12 deletions testenv/fastws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@ func invokeEdgeServerWithFastWS(port int, desc ...kit.ServiceDescriptor) fx.Opti
return fx.Invoke(
func(lc fx.Lifecycle) {
edge := kit.NewServer(
kit.ReusePort(true),
kit.WithLogger(&stdLogger{}),
kit.WithErrorHandler(
func(ctx *kit.Context, err error) {
Expand All @@ -47,6 +48,7 @@ func invokeEdgeServerWithFastWS(port int, desc ...kit.ServiceDescriptor) fx.Opti
fastws.MustNew(
fastws.WithPredicateKey("cmd"),
fastws.Listen(fmt.Sprintf("tcp4://0.0.0.0:%d", port)),
fastws.WithLogger(&stdLogger{}),
),
),
kit.WithServiceDesc(desc...),
Expand Down Expand Up @@ -81,23 +83,30 @@ func fastwsWithHugePayload(t *testing.T, opt fx.Option) func(c C) {
),
)

time.Sleep(time.Second * 5)
time.Sleep(time.Second * 2)

wsCtx := stub.New("localhost:8082").
wsCtx := stub.New(
"localhost:8082",
//stub.WithLogger(&stdLogger{}),
).
Websocket(
stub.WithPredicateKey("cmd"),
stub.WithPingTime(time.Second),
)
c.So(wsCtx.Connect(ctx, "/"), ShouldBeNil)

req := &services.EchoRequest{Input: utils.RandomID(12000)}
res := &services.EchoResponse{}
err := wsCtx.BinaryMessage(
ctx, "echo", req, res,
func(ctx context.Context, msg kit.Message, hdr stub.Header, err error) {
c.So(err, ShouldBeNil)
c.So(msg.(*services.EchoResponse).Output, ShouldEqual, req.Input) //nolint:forcetypeassert
},
)
c.So(err, ShouldBeNil)
for i := 0; i < 10; i++ {
req := &services.EchoRequest{Input: utils.RandomID(1024)}
res := &services.EchoResponse{}
err := wsCtx.BinaryMessage(
ctx, "echo", req, res,
func(ctx context.Context, msg kit.Message, hdr stub.Header, err error) {
c.So(err, ShouldBeNil)
c.So(msg.(*services.EchoResponse).Output, ShouldEqual, req.Input) //nolint:forcetypeassert
},
)
c.So(err, ShouldBeNil)
time.Sleep(time.Second)
}
}
}
8 changes: 7 additions & 1 deletion testenv/utils.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,15 @@
package testenv

import "fmt"
import (
"fmt"

"github.com/clubpay/ronykit/kit"
)

type stdLogger struct{}

var _ kit.Logger = (*stdLogger)(nil)

func (s stdLogger) Debugf(format string, args ...any) {
fmt.Printf("DEBUG: %s\n", fmt.Sprintf(format, args...))
}
Expand Down

0 comments on commit a6162fd

Please sign in to comment.