Skip to content

Commit

Permalink
Merge pull request #193 from nhooyr/ensure-close
Browse files Browse the repository at this point in the history
Ensure connection is closed at all error points
  • Loading branch information
nhooyr authored Feb 21, 2020
2 parents 43c4dc0 + 2e0dd1c commit c62c0dc
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 26 deletions.
26 changes: 12 additions & 14 deletions read.go
Original file line number Diff line number Diff line change
Expand Up @@ -304,7 +304,9 @@ func (c *Conn) reader(ctx context.Context) (_ MessageType, _ io.Reader, err erro
defer c.readMu.unlock()

if !c.msgReader.fin {
return 0, nil, errors.New("previous message not read to completion")
err = errors.New("previous message not read to completion")
c.close(fmt.Errorf("failed to get reader: %w", err))
return 0, nil, err
}

h, err := c.readLoop(ctx)
Expand Down Expand Up @@ -361,21 +363,9 @@ func (mr *msgReader) setFrame(h header) {
}

func (mr *msgReader) Read(p []byte) (n int, err error) {
defer func() {
if errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
err = io.EOF
}
if errors.Is(err, io.EOF) {
err = io.EOF
mr.putFlateReader()
return
}
errd.Wrap(&err, "failed to read")
}()

err = mr.c.readMu.lock(mr.ctx)
if err != nil {
return 0, err
return 0, fmt.Errorf("failed to read: %w", err)
}
defer mr.c.readMu.unlock()

Expand All @@ -384,6 +374,14 @@ func (mr *msgReader) Read(p []byte) (n int, err error) {
p = p[:n]
mr.dict.write(p)
}
if errors.Is(err, io.EOF) || errors.Is(err, io.ErrUnexpectedEOF) && mr.fin && mr.flate {
mr.putFlateReader()
return n, io.EOF
}
if err != nil {
err = fmt.Errorf("failed to read: %w", err)
mr.c.close(err)
}
return n, err
}

Expand Down
42 changes: 30 additions & 12 deletions write.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@ import (
"errors"
"fmt"
"io"
"sync"
"time"

"github.com/klauspost/compress/flate"
Expand Down Expand Up @@ -71,7 +70,7 @@ type msgWriterState struct {
c *Conn

mu *mu
writeMu sync.Mutex
writeMu *mu

ctx context.Context
opcode opcode
Expand All @@ -83,8 +82,9 @@ type msgWriterState struct {

func newMsgWriterState(c *Conn) *msgWriterState {
mw := &msgWriterState{
c: c,
mu: newMu(c),
c: c,
mu: newMu(c),
writeMu: newMu(c),
}
return mw
}
Expand Down Expand Up @@ -155,10 +155,18 @@ func (mw *msgWriterState) reset(ctx context.Context, typ MessageType) error {

// Write writes the given bytes to the WebSocket connection.
func (mw *msgWriterState) Write(p []byte) (_ int, err error) {
defer errd.Wrap(&err, "failed to write")
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return 0, fmt.Errorf("failed to write: %w", err)
}
defer mw.writeMu.unlock()

mw.writeMu.Lock()
defer mw.writeMu.Unlock()
defer func() {
if err != nil {
err = fmt.Errorf("failed to write: %w", err)
mw.c.close(err)
}
}()

if mw.c.flate() {
// Only enables flate if the length crosses the
Expand Down Expand Up @@ -193,8 +201,11 @@ func (mw *msgWriterState) write(p []byte) (int, error) {
func (mw *msgWriterState) Close() (err error) {
defer errd.Wrap(&err, "failed to close writer")

mw.writeMu.Lock()
defer mw.writeMu.Unlock()
err = mw.writeMu.lock(mw.ctx)
if err != nil {
return err
}
defer mw.writeMu.unlock()

_, err = mw.c.writeFrame(mw.ctx, true, mw.flate, mw.opcode, nil)
if err != nil {
Expand All @@ -214,7 +225,7 @@ func (mw *msgWriterState) close() {
putBufioWriter(mw.c.bw)
}

mw.writeMu.Lock()
mw.writeMu.forceLock()
mw.dict.close()
}

Expand All @@ -230,8 +241,8 @@ func (c *Conn) writeControl(ctx context.Context, opcode opcode, p []byte) error
}

// frame handles all writes to the connection.
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (int, error) {
err := c.writeFrameMu.lock(ctx)
func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opcode, p []byte) (_ int, err error) {
err = c.writeFrameMu.lock(ctx)
if err != nil {
return 0, err
}
Expand All @@ -243,6 +254,13 @@ func (c *Conn) writeFrame(ctx context.Context, fin bool, flate bool, opcode opco
case c.writeTimeout <- ctx:
}

defer func() {
if err != nil {
err = fmt.Errorf("failed to write frame: %w", err)
c.close(err)
}
}()

c.writeHeader.fin = fin
c.writeHeader.opcode = opcode
c.writeHeader.payloadLength = int64(len(p))
Expand Down

0 comments on commit c62c0dc

Please sign in to comment.