Skip to content

Commit

Permalink
RSDK-8941: Remove mutex dependence on base channel closed state varia…
Browse files Browse the repository at this point in the history
…ble. (#391)
  • Loading branch information
dgottlieb authored Nov 19, 2024
1 parent eef94a9 commit 7633571
Show file tree
Hide file tree
Showing 11 changed files with 106 additions and 139 deletions.
7 changes: 3 additions & 4 deletions rpc/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -614,7 +614,7 @@ func NewServer(logger utils.ZapCompatibleLogger, opts ...ServerOption) (Server,

if sOpts.webrtcOpts.ExternalSignalingAddress != "" {
logger.Infow(
"will run external signaling answerer",
"Running external signaling",
"signaling_address", sOpts.webrtcOpts.ExternalSignalingAddress,
"for_hosts", externalSignalingHosts,
)
Expand All @@ -631,7 +631,6 @@ func NewServer(logger utils.ZapCompatibleLogger, opts ...ServerOption) (Server,
}

if sOpts.webrtcOpts.EnableInternalSignaling {
logger.Debug("will run internal signaling service")
signalingCallQueue := NewMemoryWebRTCCallQueue(logger)
server.signalingCallQueue = signalingCallQueue
server.signalingServer = NewWebRTCSignalingServer(signalingCallQueue, nil, logger,
Expand All @@ -646,8 +645,8 @@ func NewServer(logger utils.ZapCompatibleLogger, opts ...ServerOption) (Server,
}

address := grpcListener.Addr().String()
logger.Debugw(
"will run internal signaling answerer",
logger.Infow(
"Running internal signaling",
"signaling_address", address,
"for_hosts", internalSignalingHosts,
)
Expand Down
65 changes: 24 additions & 41 deletions rpc/wrtc_base_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@ import (
"io"
"strings"
"sync"
"sync/atomic"

"github.com/pion/dtls/v2"
"github.com/pion/sctp"
"github.com/viamrobotics/webrtc/v3"
"google.golang.org/protobuf/proto"
Expand All @@ -22,8 +22,7 @@ type webrtcBaseChannel struct {
ctx context.Context
cancel func()
ready chan struct{}
closed bool
closedReason error
closed atomic.Bool
activeBackgroundWorkers sync.WaitGroup
logger utils.ZapCompatibleLogger
bufferWriteMu sync.RWMutex
Expand Down Expand Up @@ -51,7 +50,7 @@ func newBaseChannel(
}
ch.bufferWriteCond = sync.NewCond(ch.bufferWriteMu.RLocker())
dataChannel.OnOpen(ch.onChannelOpen)
dataChannel.OnClose(ch.onChannelClose)
dataChannel.OnClose(ch.Close)
dataChannel.OnError(ch.onChannelError)
dataChannel.SetBufferedAmountLowThreshold(bufferThreshold)
dataChannel.OnBufferedAmountLow(func() {
Expand All @@ -75,16 +74,17 @@ func newBaseChannel(
}
connStateChanged := func(connectionState webrtc.ICEConnectionState) {
ch.activeBackgroundWorkers.Add(1)

utils.PanicCapturingGo(func() {
defer ch.activeBackgroundWorkers.Done()

ch.mu.Lock()
defer ch.mu.Unlock()
if ch.closed {
if ch.closed.Load() {
doPeerDone()
return
}

ch.mu.Lock()
defer ch.mu.Unlock()
switch connectionState {
case webrtc.ICEConnectionStateDisconnected,
webrtc.ICEConnectionStateFailed,
Expand Down Expand Up @@ -154,43 +154,35 @@ func newBaseChannel(
return ch
}

func (ch *webrtcBaseChannel) closeWithReason(err error) error {
ch.mu.Lock()
defer ch.mu.Unlock()
if ch.closed {
return nil
}
// Close will always wait for background goroutines to exit before returning. It is safe to
// concurrently call `Close`.
//
// RSDK-8941: The above is a statement of the expectations held by existing code, not a claim
// that it is factually correct.
func (ch *webrtcBaseChannel) Close() {
// RSDK-8941: Returning early here when `closed` is already set would cause `TestServer` to
// leak goroutines created by `dialWebRTC`.
ch.closed.CompareAndSwap(false, true)

ch.mu.Lock()
// APP-6839: We must hold the `bufferWriteMu` to avoid a "missed notification" that can happen
// when a `webrtcBaseChannel.write` happens concurrently with `Close`. Specifically, this lock
// makes the `ch.cancel` atomic with the broadcast, such that a call to write that can `Wait`
// on this condition variable must either:
// - Observe the context being canceled, or
// - Call `Wait` before the following `Broadcast` is invoked.
ch.bufferWriteMu.Lock()
ch.closed = true
ch.closedReason = err
ch.cancel()
ch.bufferWriteCond.Broadcast()
ch.bufferWriteMu.Unlock()
ch.mu.Unlock()

// Underlying connection may already be closed; ignore "conn is closed"
// errors.
if err := ch.peerConn.GracefulClose(); !errors.Is(err, dtls.ErrConnClosed) {
return err
}
return nil
utils.UncheckedError(ch.peerConn.GracefulClose())
ch.activeBackgroundWorkers.Wait()
}

func (ch *webrtcBaseChannel) Close() error {
defer ch.activeBackgroundWorkers.Wait()
return ch.closeWithReason(nil)
}

func (ch *webrtcBaseChannel) Closed() (bool, error) {
ch.mu.Lock()
defer ch.mu.Unlock()
return ch.closed, ch.closedReason
func (ch *webrtcBaseChannel) Closed() bool {
return ch.closed.Load()
}

func (ch *webrtcBaseChannel) Ready() <-chan struct{} {
Expand All @@ -201,14 +193,6 @@ func (ch *webrtcBaseChannel) onChannelOpen() {
close(ch.ready)
}

var errDataChannelClosed = errors.New("data channel closed")

func (ch *webrtcBaseChannel) onChannelClose() {
if err := ch.closeWithReason(errDataChannelClosed); err != nil {
ch.logger.Errorw("error closing channel", "error", err)
}
}

// isUserInitiatedAbortChunkErr returns true if the error is an abort chunk
// error that the user initiated through Close. Certain browsers (Safari,
// Chrome and potentially others) close RTCPeerConnections with this type of
Expand All @@ -224,9 +208,7 @@ func (ch *webrtcBaseChannel) onChannelError(err error) {
return
}
ch.logger.Errorw("channel error", "error", err)
if err := ch.closeWithReason(err); err != nil {
ch.logger.Errorw("error closing channel", "error", err)
}
ch.Close()
}

const maxDataChannelSize = 65535
Expand All @@ -239,6 +221,7 @@ func (ch *webrtcBaseChannel) write(msg proto.Message) error {
ch.bufferWriteCond.L.Lock()
for {
if ch.ctx.Err() != nil {
ch.bufferWriteCond.L.Unlock()
return io.ErrClosedPipe
}

Expand Down
33 changes: 13 additions & 20 deletions rpc/wrtc_base_channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,46 +60,39 @@ func TestWebRTCBaseChannel(t *testing.T) {
someStatus, _ := status.FromError(errors.New("ouch"))
test.That(t, bc1.write(someStatus.Proto()), test.ShouldBeNil)

isClosed, reason := bc1.Closed()
isClosed := bc1.Closed()
test.That(t, isClosed, test.ShouldBeFalse)
test.That(t, reason, test.ShouldBeNil)
isClosed, reason = bc2.Closed()
isClosed = bc2.Closed()
test.That(t, isClosed, test.ShouldBeFalse)
test.That(t, reason, test.ShouldBeNil)
test.That(t, bc1.Close(), test.ShouldBeNil)
bc1.Close()
<-peer1Done
<-peer2Done
isClosed, reason = bc1.Closed()
isClosed = bc1.Closed()
test.That(t, isClosed, test.ShouldBeTrue)
test.That(t, reason, test.ShouldBeNil)
isClosed, reason = bc2.Closed()
isClosed = bc2.Closed()
test.That(t, isClosed, test.ShouldBeTrue)
test.That(t, reason, test.ShouldEqual, errDataChannelClosed)
test.That(t, bc1.Close(), test.ShouldBeNil)
test.That(t, bc2.Close(), test.ShouldBeNil)
// Double calling close poses no problems.
bc1.Close()
bc2.Close()

bc1, bc2, peer1Done, peer2Done = setupWebRTCBaseChannels(t)
err1 := errors.New("whoops")
test.That(t, bc2.closeWithReason(err1), test.ShouldBeNil)
bc2.Close()
<-peer1Done
<-peer2Done
isClosed, reason = bc1.Closed()
isClosed = bc1.Closed()
test.That(t, isClosed, test.ShouldBeTrue)
test.That(t, reason, test.ShouldEqual, errDataChannelClosed)
isClosed, reason = bc2.Closed()
isClosed = bc2.Closed()
test.That(t, isClosed, test.ShouldBeTrue)
test.That(t, reason, test.ShouldEqual, err1)

bc1, bc2, peer1Done, peer2Done = setupWebRTCBaseChannels(t)
bc2.onChannelError(err1)
<-peer1Done
<-peer2Done
isClosed, reason = bc1.Closed()
isClosed = bc1.Closed()
test.That(t, isClosed, test.ShouldBeTrue)
test.That(t, reason, test.ShouldEqual, errDataChannelClosed)
isClosed, reason = bc2.Closed()
isClosed = bc2.Closed()
test.That(t, isClosed, test.ShouldBeTrue)
test.That(t, reason, test.ShouldEqual, err1)

test.That(t, bc2.write(someStatus.Proto()), test.ShouldEqual, io.ErrClosedPipe)
}
14 changes: 10 additions & 4 deletions rpc/wrtc_client.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"io"
"strings"
"sync"
"testing"
"time"

"github.com/pkg/errors"
Expand Down Expand Up @@ -161,8 +162,12 @@ func dialWebRTC(
)
onICEConnected := func() {
// Delay by up to 5s to allow more caller updates/better stats.
waitTime := 5 * time.Second
if testing.Testing() {
waitTime = 100 * time.Millisecond
}
select {
case <-time.After(5 * time.Second):
case <-time.After(waitTime):
case <-ctx.Done():
}

Expand Down Expand Up @@ -190,7 +195,6 @@ func dialWebRTC(
errCh := make(chan error)
sendErr := func(err error) {
if haveInit && isEOF(err) {
logger.Warnf("caller swallowing err %v", err)
return
}
if s, ok := status.FromError(err); ok && strings.Contains(s.Message(), noActiveOfferStr) {
Expand Down Expand Up @@ -376,11 +380,13 @@ func dialWebRTC(
doCall := func() error {
select {
case <-exchangeCtx.Done():
return multierr.Combine(exchangeCtx.Err(), clientCh.Close())
clientCh.close()
return exchangeCtx.Err()
case <-clientCh.Ready():
return nil
case err := <-errCh:
return multierr.Combine(err, clientCh.Close())
clientCh.close()
return err
}
}

Expand Down
12 changes: 10 additions & 2 deletions rpc/wrtc_client_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,16 @@ func (ch *webrtcClientChannel) PeerConn() *webrtc.PeerConnection {
return ch.webrtcBaseChannel.peerConn
}

// Close closes all streams and the underlying channel.
// Close returns a nil error to satisfy ClientConn. Prefer `close` for the internal API that has no
// return value. There's nothing to do when close "has an error". This choice simplifies error
// handling.
func (ch *webrtcClientChannel) Close() error {
ch.close()
return nil
}

// close closes all streams and the underlying channel.
func (ch *webrtcClientChannel) close() {
ch.mu.Lock()
streamsToClose := make(map[uint64]activeWebRTCClientStream, len(ch.streams))
for k, v := range ch.streams {
Expand All @@ -86,7 +94,7 @@ func (ch *webrtcClientChannel) Close() error {
for _, s := range streamsToClose {
s.cs.Close()
}
return ch.webrtcBaseChannel.Close()
ch.webrtcBaseChannel.Close()
}

// Invoke sends the RPC request on the wire and returns after response is
Expand Down
20 changes: 5 additions & 15 deletions rpc/wrtc_client_channel_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,9 +33,7 @@ func TestWebRTCClientChannel(t *testing.T) {
test.That(t, clientCh.Close(), test.ShouldBeNil)
}()
serverCh := newBaseChannel(context.Background(), pc2, dc2, nil, nil, logger)
defer func() {
test.That(t, serverCh.Close(), test.ShouldBeNil)
}()
defer serverCh.Close()

<-clientCh.Ready()
<-serverCh.Ready()
Expand Down Expand Up @@ -344,9 +342,7 @@ func TestWebRTCClientChannelResetStream(t *testing.T) {
test.That(t, clientCh.Close(), test.ShouldBeNil)
}()
serverCh := newBaseChannel(context.Background(), pc2, dc2, nil, nil, logger)
defer func() {
test.That(t, serverCh.Close(), test.ShouldBeNil)
}()
defer serverCh.Close()

<-clientCh.Ready()
<-serverCh.Ready()
Expand Down Expand Up @@ -468,9 +464,7 @@ func TestWebRTCClientChannelWithInterceptor(t *testing.T) {
test.That(t, clientCh.Close(), test.ShouldBeNil)
}()
serverCh := newBaseChannel(context.Background(), pc2, dc2, nil, nil, logger)
defer func() {
test.That(t, serverCh.Close(), test.ShouldBeNil)
}()
defer serverCh.Close()

<-clientCh.Ready()
<-serverCh.Ready()
Expand Down Expand Up @@ -551,9 +545,7 @@ func TestWebRTCClientChannelCanStopStreamRecvMsg(t *testing.T) {
test.That(t, clientCh.Close(), test.ShouldBeNil)
}()
serverCh := newBaseChannel(context.Background(), pc2, dc2, nil, nil, logger)
defer func() {
test.That(t, serverCh.Close(), test.ShouldBeNil)
}()
defer serverCh.Close()

<-clientCh.Ready()
<-serverCh.Ready()
Expand Down Expand Up @@ -651,9 +643,7 @@ func TestClientStreamCancel(t *testing.T) {
}, nil)

serverCh := newWebRTCServerChannel(server, pc2, dc2, nil, logger)
defer func() {
test.That(t, serverCh.Close(), test.ShouldBeNil)
}()
defer serverCh.Close()

<-clientCh.Ready()
<-serverCh.Ready()
Expand Down
9 changes: 6 additions & 3 deletions rpc/wrtc_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,8 @@ func (srv *webrtcServer) unaryHandler(ss interface{}, handler methodHandler) han

response, err := handler(ss, ctx, s.webrtcBaseStream.RecvMsg, srv.unaryInt)
if err != nil {
return s.closeWithSendError(err)
s.closeWithSendError(err)
return err
}

err = s.SendMsg(response)
Expand All @@ -285,7 +286,8 @@ func (srv *webrtcServer) unaryHandler(ss interface{}, handler methodHandler) han
return err
}

return s.closeWithSendError(nil)
s.closeWithSendError(nil)
return nil
}
}

Expand All @@ -308,6 +310,7 @@ func (srv *webrtcServer) streamHandler(ss interface{}, method string, desc grpc.
if errors.Is(err, io.EOF) {
return nil
}
return s.closeWithSendError(err)
s.closeWithSendError(err)
return nil
}
}
Loading

0 comments on commit 7633571

Please sign in to comment.