Skip to content

Commit

Permalink
Remove the closedReason. Its only used for log lines that are defer…
Browse files Browse the repository at this point in the history
…ed. And it complicates the mutex order acquisition story.
  • Loading branch information
dgottlieb committed Nov 12, 2024
1 parent df806ea commit cd31910
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 65 deletions.
6 changes: 2 additions & 4 deletions rpc/wrtc_base_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ type webrtcBaseChannel struct {
cancel func()
ready chan struct{}
closed atomic.Bool
closedReason error
activeBackgroundWorkers sync.WaitGroup
logger utils.ZapCompatibleLogger
bufferWriteMu sync.RWMutex
Expand Down Expand Up @@ -162,7 +161,6 @@ func (ch *webrtcBaseChannel) closeWithReason(err error) error {

ch.mu.Lock()
defer ch.mu.Unlock()
ch.closedReason = err
ch.cancel()
ch.bufferWriteCond.Broadcast()

Expand All @@ -179,8 +177,8 @@ func (ch *webrtcBaseChannel) Close() error {
return ch.closeWithReason(nil)
}

func (ch *webrtcBaseChannel) Closed() (bool, error) {
return ch.closed.Load(), ch.closedReason
func (ch *webrtcBaseChannel) Closed() bool {
return ch.closed.Load()
}

func (ch *webrtcBaseChannel) Ready() <-chan struct{} {
Expand Down
24 changes: 8 additions & 16 deletions rpc/wrtc_base_channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,21 +60,17 @@ func TestWebRTCBaseChannel(t *testing.T) {
someStatus, _ := status.FromError(errors.New("ouch"))
test.That(t, bc1.write(someStatus.Proto()), test.ShouldBeNil)

isClosed, reason := bc1.Closed()
isClosed := bc1.Closed()
test.That(t, isClosed, test.ShouldBeFalse)
test.That(t, reason, test.ShouldBeNil)
isClosed, reason = bc2.Closed()
isClosed = bc2.Closed()
test.That(t, isClosed, test.ShouldBeFalse)
test.That(t, reason, test.ShouldBeNil)
test.That(t, bc1.Close(), test.ShouldBeNil)
<-peer1Done
<-peer2Done
isClosed, reason = bc1.Closed()
isClosed = bc1.Closed()
test.That(t, isClosed, test.ShouldBeTrue)
test.That(t, reason, test.ShouldBeNil)
isClosed, reason = bc2.Closed()
isClosed = bc2.Closed()
test.That(t, isClosed, test.ShouldBeTrue)
test.That(t, reason, test.ShouldEqual, errDataChannelClosed)
test.That(t, bc1.Close(), test.ShouldBeNil)
test.That(t, bc2.Close(), test.ShouldBeNil)

Expand All @@ -83,23 +79,19 @@ func TestWebRTCBaseChannel(t *testing.T) {
test.That(t, bc2.closeWithReason(err1), test.ShouldBeNil)
<-peer1Done
<-peer2Done
isClosed, reason = bc1.Closed()
isClosed = bc1.Closed()
test.That(t, isClosed, test.ShouldBeTrue)
test.That(t, reason, test.ShouldEqual, errDataChannelClosed)
isClosed, reason = bc2.Closed()
isClosed = bc2.Closed()
test.That(t, isClosed, test.ShouldBeTrue)
test.That(t, reason, test.ShouldEqual, err1)

bc1, bc2, peer1Done, peer2Done = setupWebRTCBaseChannels(t)
bc2.onChannelError(err1)
<-peer1Done
<-peer2Done
isClosed, reason = bc1.Closed()
isClosed = bc1.Closed()
test.That(t, isClosed, test.ShouldBeTrue)
test.That(t, reason, test.ShouldEqual, errDataChannelClosed)
isClosed, reason = bc2.Closed()
isClosed = bc2.Closed()
test.That(t, isClosed, test.ShouldBeTrue)
test.That(t, reason, test.ShouldEqual, err1)

test.That(t, bc2.write(someStatus.Proto()), test.ShouldEqual, io.ErrClosedPipe)
}
9 changes: 6 additions & 3 deletions rpc/wrtc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -230,7 +230,8 @@ func (srv *webrtcServer) unaryHandler(ss interface{}, handler methodHandler) han

response, err := handler(ss, ctx, s.webrtcBaseStream.RecvMsg, srv.unaryInt)
if err != nil {
return s.closeWithSendError(err)
s.closeWithSendError(err)
return err
}

err = s.SendMsg(response)
Expand All @@ -239,7 +240,8 @@ func (srv *webrtcServer) unaryHandler(ss interface{}, handler methodHandler) han
return err
}

return s.closeWithSendError(nil)
s.closeWithSendError(nil)
return nil
}
}

Expand All @@ -262,6 +264,7 @@ func (srv *webrtcServer) streamHandler(ss interface{}, method string, desc grpc.
if errors.Is(err, io.EOF) {
return nil
}
return s.closeWithSendError(err)
s.closeWithSendError(err)
return nil
}
}
60 changes: 18 additions & 42 deletions rpc/wrtc_server_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,7 @@ import (
"sync/atomic"

protov1 "github.com/golang/protobuf/proto" //nolint:staticcheck
"github.com/pion/sctp"
"github.com/pkg/errors"
"go.uber.org/multierr"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/metadata"
Expand Down Expand Up @@ -172,7 +170,7 @@ func (s *webrtcServerStream) SendMsg(m interface{}) (err error) {

defer func() {
if err != nil {
err = multierr.Combine(err, s.closeWithSendError(err))
s.closeWithSendError(err)
}
}()

Expand Down Expand Up @@ -227,29 +225,21 @@ func (s *webrtcServerStream) onRequest(request *webrtcpb.Request) {
switch r := request.Type.(type) {
case *webrtcpb.Request_Headers:
if s.headersReceived {
if err := s.closeWithSendError(status.Error(codes.InvalidArgument, "headers already received")); err != nil {
s.logger.Warnw("error closing", "error", err)
}
s.closeWithSendError(status.Error(codes.InvalidArgument, "headers already received"))
return
}
s.processHeaders(r.Headers)
case *webrtcpb.Request_Message:
if !s.headersReceived {
if err := s.closeWithSendError(status.Error(codes.InvalidArgument, "headers not yet received")); err != nil {
s.logger.Warnw("error closing", "error", err)
}
s.closeWithSendError(status.Error(codes.InvalidArgument, "headers not yet received"))
return
}
s.processMessage(r.Message)
case *webrtcpb.Request_RstStream:
if err := s.closeWithSendError(status.Error(codes.Canceled, "request cancelled")); err != nil {
s.logger.Warnw("error closing", "error", err)
}
s.closeWithSendError(status.Error(codes.Canceled, "request cancelled"))
return
default:
if err := s.closeWithSendError(status.Error(codes.InvalidArgument, fmt.Sprintf("unknown request type %T", r))); err != nil {
s.logger.Warnw("error closing", "error", err)
}
s.closeWithSendError(status.Error(codes.InvalidArgument, fmt.Sprintf("unknown request type %T", r)))
}
}

Expand All @@ -273,9 +263,7 @@ func (s *webrtcServerStream) processHeaders(headers *webrtcpb.RequestHeaders) {
if s.ch.server.unknownStreamDesc != nil {
handlerFunc = s.ch.server.streamHandler(s.ch.server, headers.Method, *s.ch.server.unknownStreamDesc)
} else {
if err := s.closeWithSendError(status.Error(codes.Unimplemented, codes.Unimplemented.String())); err != nil {
s.logger.Errorw("error closing", "error", err)
}
s.closeWithSendError(status.Error(codes.Unimplemented, codes.Unimplemented.String()))
return
}
}
Expand All @@ -295,9 +283,7 @@ func (s *webrtcServerStream) processHeaders(headers *webrtcpb.RequestHeaders) {
select {
case s.ch.server.callTickets <- struct{}{}:
default:
if err := s.closeWithSendError(status.Error(codes.ResourceExhausted, "too many in-flight requests")); err != nil {
s.logger.Errorw("error closing", "error", err)
}
s.closeWithSendError(status.Error(codes.ResourceExhausted, "too many in-flight requests"))
return
}

Expand Down Expand Up @@ -354,44 +340,34 @@ func (s *webrtcServerStream) processMessage(msg *webrtcpb.RequestMessage) {
}

// Must not be called with the `s.webrtcBaseStream.mu` mutex held.
func (s *webrtcServerStream) closeWithSendError(err error) (writeErr error) {
func (s *webrtcServerStream) closeWithSendError(err error) {
if !s.sendClosed.CompareAndSwap(false, true) {
return nil
return
}
defer func() {
if writeErr == nil || errors.Is(writeErr, sctp.ErrStreamClosed) {
writeErr = nil
}
}()
defer func() {
s.webrtcBaseStream.mu.Lock()
defer s.webrtcBaseStream.mu.Unlock()
s.close()
}()
if err != nil && (errors.Is(err, io.ErrClosedPipe)) {
return nil
}
chClosed, chClosedReason := s.ch.Closed()
if s.Closed() || chClosed {
if errors.Is(chClosedReason, errDataChannelClosed) &&
isContextCanceled(err) {
return nil
}
return errors.Wrap(err, "close called multiple times with error")
if s.ch.Closed() {
return
}
if err := s.writeHeaders(); err != nil {
return err
if headersErr := s.writeHeaders(); headersErr != nil {
s.logger.Warnw("Error writing headers", "err", headersErr)
return
}
var respStatus *status.Status
if err == nil {
respStatus = ErrorToStatus(s.ctx.Err())
} else {
respStatus = ErrorToStatus(err)
}
return s.ch.writeTrailers(s.stream, &webrtcpb.ResponseTrailers{
if trailersErr := s.ch.writeTrailers(s.stream, &webrtcpb.ResponseTrailers{
Status: respStatus.Proto(),
Metadata: metadataToProto(s.trailer),
})
}); trailersErr != nil {
s.logger.Warnw("Error writing trailers", "err", trailersErr)
}
}

func (s *webrtcServerStream) writeHeaders() error {
Expand Down

0 comments on commit cd31910

Please sign in to comment.