From cd31910d792e102b4521c0a9acda4cadab6268ec Mon Sep 17 00:00:00 2001 From: Dan Gottlieb Date: Tue, 12 Nov 2024 09:36:16 -0500 Subject: [PATCH] Remove the `closedReason`. Its only used for log lines that are defered. And it complicates the mutex order acquisition story. --- rpc/wrtc_base_channel.go | 6 ++-- rpc/wrtc_base_channel_test.go | 24 +++++--------- rpc/wrtc_server.go | 9 ++++-- rpc/wrtc_server_stream.go | 60 +++++++++++------------------------ 4 files changed, 34 insertions(+), 65 deletions(-) diff --git a/rpc/wrtc_base_channel.go b/rpc/wrtc_base_channel.go index 9b4ad0ce..a5739347 100644 --- a/rpc/wrtc_base_channel.go +++ b/rpc/wrtc_base_channel.go @@ -24,7 +24,6 @@ type webrtcBaseChannel struct { cancel func() ready chan struct{} closed atomic.Bool - closedReason error activeBackgroundWorkers sync.WaitGroup logger utils.ZapCompatibleLogger bufferWriteMu sync.RWMutex @@ -162,7 +161,6 @@ func (ch *webrtcBaseChannel) closeWithReason(err error) error { ch.mu.Lock() defer ch.mu.Unlock() - ch.closedReason = err ch.cancel() ch.bufferWriteCond.Broadcast() @@ -179,8 +177,8 @@ func (ch *webrtcBaseChannel) Close() error { return ch.closeWithReason(nil) } -func (ch *webrtcBaseChannel) Closed() (bool, error) { - return ch.closed.Load(), ch.closedReason +func (ch *webrtcBaseChannel) Closed() bool { + return ch.closed.Load() } func (ch *webrtcBaseChannel) Ready() <-chan struct{} { diff --git a/rpc/wrtc_base_channel_test.go b/rpc/wrtc_base_channel_test.go index 9cf9a1e2..f75abdab 100644 --- a/rpc/wrtc_base_channel_test.go +++ b/rpc/wrtc_base_channel_test.go @@ -60,21 +60,17 @@ func TestWebRTCBaseChannel(t *testing.T) { someStatus, _ := status.FromError(errors.New("ouch")) test.That(t, bc1.write(someStatus.Proto()), test.ShouldBeNil) - isClosed, reason := bc1.Closed() + isClosed := bc1.Closed() test.That(t, isClosed, test.ShouldBeFalse) - test.That(t, reason, test.ShouldBeNil) - isClosed, reason = bc2.Closed() + isClosed = bc2.Closed() test.That(t, isClosed, test.ShouldBeFalse) - test.That(t, reason, test.ShouldBeNil) test.That(t, bc1.Close(), test.ShouldBeNil) <-peer1Done <-peer2Done - isClosed, reason = bc1.Closed() + isClosed = bc1.Closed() test.That(t, isClosed, test.ShouldBeTrue) - test.That(t, reason, test.ShouldBeNil) - isClosed, reason = bc2.Closed() + isClosed = bc2.Closed() test.That(t, isClosed, test.ShouldBeTrue) - test.That(t, reason, test.ShouldEqual, errDataChannelClosed) test.That(t, bc1.Close(), test.ShouldBeNil) test.That(t, bc2.Close(), test.ShouldBeNil) @@ -83,23 +79,19 @@ func TestWebRTCBaseChannel(t *testing.T) { test.That(t, bc2.closeWithReason(err1), test.ShouldBeNil) <-peer1Done <-peer2Done - isClosed, reason = bc1.Closed() + isClosed = bc1.Closed() test.That(t, isClosed, test.ShouldBeTrue) - test.That(t, reason, test.ShouldEqual, errDataChannelClosed) - isClosed, reason = bc2.Closed() + isClosed = bc2.Closed() test.That(t, isClosed, test.ShouldBeTrue) - test.That(t, reason, test.ShouldEqual, err1) bc1, bc2, peer1Done, peer2Done = setupWebRTCBaseChannels(t) bc2.onChannelError(err1) <-peer1Done <-peer2Done - isClosed, reason = bc1.Closed() + isClosed = bc1.Closed() test.That(t, isClosed, test.ShouldBeTrue) - test.That(t, reason, test.ShouldEqual, errDataChannelClosed) - isClosed, reason = bc2.Closed() + isClosed = bc2.Closed() test.That(t, isClosed, test.ShouldBeTrue) - test.That(t, reason, test.ShouldEqual, err1) test.That(t, bc2.write(someStatus.Proto()), test.ShouldEqual, io.ErrClosedPipe) } diff --git a/rpc/wrtc_server.go b/rpc/wrtc_server.go index 5feb2518..cd13579a 100644 --- a/rpc/wrtc_server.go +++ b/rpc/wrtc_server.go @@ -230,7 +230,8 @@ func (srv *webrtcServer) unaryHandler(ss interface{}, handler methodHandler) han response, err := handler(ss, ctx, s.webrtcBaseStream.RecvMsg, srv.unaryInt) if err != nil { - return s.closeWithSendError(err) + s.closeWithSendError(err) + return err } err = s.SendMsg(response) @@ -239,7 +240,8 @@ func (srv *webrtcServer) unaryHandler(ss interface{}, handler methodHandler) han return err } - return s.closeWithSendError(nil) + s.closeWithSendError(nil) + return nil } } @@ -262,6 +264,7 @@ func (srv *webrtcServer) streamHandler(ss interface{}, method string, desc grpc. if errors.Is(err, io.EOF) { return nil } - return s.closeWithSendError(err) + s.closeWithSendError(err) + return nil } } diff --git a/rpc/wrtc_server_stream.go b/rpc/wrtc_server_stream.go index 74485a36..fb334baa 100644 --- a/rpc/wrtc_server_stream.go +++ b/rpc/wrtc_server_stream.go @@ -8,9 +8,7 @@ import ( "sync/atomic" protov1 "github.com/golang/protobuf/proto" //nolint:staticcheck - "github.com/pion/sctp" "github.com/pkg/errors" - "go.uber.org/multierr" "google.golang.org/grpc" "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" @@ -172,7 +170,7 @@ func (s *webrtcServerStream) SendMsg(m interface{}) (err error) { defer func() { if err != nil { - err = multierr.Combine(err, s.closeWithSendError(err)) + s.closeWithSendError(err) } }() @@ -227,29 +225,21 @@ func (s *webrtcServerStream) onRequest(request *webrtcpb.Request) { switch r := request.Type.(type) { case *webrtcpb.Request_Headers: if s.headersReceived { - if err := s.closeWithSendError(status.Error(codes.InvalidArgument, "headers already received")); err != nil { - s.logger.Warnw("error closing", "error", err) - } + s.closeWithSendError(status.Error(codes.InvalidArgument, "headers already received")) return } s.processHeaders(r.Headers) case *webrtcpb.Request_Message: if !s.headersReceived { - if err := s.closeWithSendError(status.Error(codes.InvalidArgument, "headers not yet received")); err != nil { - s.logger.Warnw("error closing", "error", err) - } + s.closeWithSendError(status.Error(codes.InvalidArgument, "headers not yet received")) return } s.processMessage(r.Message) case *webrtcpb.Request_RstStream: - if err := s.closeWithSendError(status.Error(codes.Canceled, "request cancelled")); err != nil { - s.logger.Warnw("error closing", "error", err) - } + s.closeWithSendError(status.Error(codes.Canceled, "request cancelled")) return default: - if err := s.closeWithSendError(status.Error(codes.InvalidArgument, fmt.Sprintf("unknown request type %T", r))); err != nil { - s.logger.Warnw("error closing", "error", err) - } + s.closeWithSendError(status.Error(codes.InvalidArgument, fmt.Sprintf("unknown request type %T", r))) } } @@ -273,9 +263,7 @@ func (s *webrtcServerStream) processHeaders(headers *webrtcpb.RequestHeaders) { if s.ch.server.unknownStreamDesc != nil { handlerFunc = s.ch.server.streamHandler(s.ch.server, headers.Method, *s.ch.server.unknownStreamDesc) } else { - if err := s.closeWithSendError(status.Error(codes.Unimplemented, codes.Unimplemented.String())); err != nil { - s.logger.Errorw("error closing", "error", err) - } + s.closeWithSendError(status.Error(codes.Unimplemented, codes.Unimplemented.String())) return } } @@ -295,9 +283,7 @@ func (s *webrtcServerStream) processHeaders(headers *webrtcpb.RequestHeaders) { select { case s.ch.server.callTickets <- struct{}{}: default: - if err := s.closeWithSendError(status.Error(codes.ResourceExhausted, "too many in-flight requests")); err != nil { - s.logger.Errorw("error closing", "error", err) - } + s.closeWithSendError(status.Error(codes.ResourceExhausted, "too many in-flight requests")) return } @@ -354,33 +340,21 @@ func (s *webrtcServerStream) processMessage(msg *webrtcpb.RequestMessage) { } // Must not be called with the `s.webrtcBaseStream.mu` mutex held. -func (s *webrtcServerStream) closeWithSendError(err error) (writeErr error) { +func (s *webrtcServerStream) closeWithSendError(err error) { if !s.sendClosed.CompareAndSwap(false, true) { - return nil + return } - defer func() { - if writeErr == nil || errors.Is(writeErr, sctp.ErrStreamClosed) { - writeErr = nil - } - }() defer func() { s.webrtcBaseStream.mu.Lock() defer s.webrtcBaseStream.mu.Unlock() s.close() }() - if err != nil && (errors.Is(err, io.ErrClosedPipe)) { - return nil - } - chClosed, chClosedReason := s.ch.Closed() - if s.Closed() || chClosed { - if errors.Is(chClosedReason, errDataChannelClosed) && - isContextCanceled(err) { - return nil - } - return errors.Wrap(err, "close called multiple times with error") + if s.ch.Closed() { + return } - if err := s.writeHeaders(); err != nil { - return err + if headersErr := s.writeHeaders(); headersErr != nil { + s.logger.Warnw("Error writing headers", "err", headersErr) + return } var respStatus *status.Status if err == nil { @@ -388,10 +362,12 @@ func (s *webrtcServerStream) closeWithSendError(err error) (writeErr error) { } else { respStatus = ErrorToStatus(err) } - return s.ch.writeTrailers(s.stream, &webrtcpb.ResponseTrailers{ + if trailersErr := s.ch.writeTrailers(s.stream, &webrtcpb.ResponseTrailers{ Status: respStatus.Proto(), Metadata: metadataToProto(s.trailer), - }) + }); trailersErr != nil { + s.logger.Warnw("Error writing trailers", "err", trailersErr) + } } func (s *webrtcServerStream) writeHeaders() error {