From 96c17f94da3c2997ed68ade41fa520986f94e93e Mon Sep 17 00:00:00 2001 From: Benjamin Rewis <32186188+benjirewis@users.noreply.github.com> Date: Wed, 20 Nov 2024 16:48:24 -0500 Subject: [PATCH] RSDK-9316 Disallow client stream creation when channel has been closed (#394) --- rpc/wrtc_client_channel.go | 8 +++-- rpc/wrtc_client_stream.go | 25 +++++++++++++-- rpc/wrtc_client_test.go | 63 ++++++++++++++++++++++++++++++++++++++ 3 files changed, 90 insertions(+), 6 deletions(-) diff --git a/rpc/wrtc_client_channel.go b/rpc/wrtc_client_channel.go index bed463e9..b4fe9964 100644 --- a/rpc/wrtc_client_channel.go +++ b/rpc/wrtc_client_channel.go @@ -265,23 +265,25 @@ func (ch *webrtcClientChannel) newStream( ) (*webrtcClientStream, error) { id := stream.GetId() ch.mu.Lock() + defer ch.mu.Unlock() activeStream, ok := ch.streams[id] if !ok { if len(ch.streams) == WebRTCMaxStreamCount { - ch.mu.Unlock() return nil, errWebRTCMaxStreams } - clientStream := newWebRTCClientStream( + clientStream, err := newWebRTCClientStream( ctx, ch, stream, ch.removeStreamByID, utils.AddFieldsToLogger(ch.webrtcBaseChannel.logger, "id", id), ) + if err != nil { + return nil, err + } activeStream = activeWebRTCClientStream{clientStream} ch.streams[id] = activeStream } - ch.mu.Unlock() return activeStream.cs, nil } diff --git a/rpc/wrtc_client_stream.go b/rpc/wrtc_client_stream.go index 11838b83..97f7ea54 100644 --- a/rpc/wrtc_client_stream.go +++ b/rpc/wrtc_client_stream.go @@ -16,7 +16,12 @@ import ( webrtcpb "go.viam.com/utils/proto/rpc/webrtc/v1" ) -var _ = grpc.ClientStream(&webrtcClientStream{}) +var ( + _ = grpc.ClientStream(&webrtcClientStream{}) + // ErrDisconnected indicates that the channel underlying the client stream + // has been closed, and the client is therefore disconnected. + ErrDisconnected = errors.New("client disconnected; underlying channel closed") +) // A webrtcClientStream is the high level gRPC streaming interface used for both // unary and streaming call requests. @@ -45,7 +50,21 @@ func newWebRTCClientStream( stream *webrtcpb.Stream, onDone func(id uint64), logger utils.ZapCompatibleLogger, -) *webrtcClientStream { +) (*webrtcClientStream, error) { + // Assume that cancelation of the client channel's context means the peer + // connection and base channel have both closed, and the client is + // disconnected. + // + // We could rely on eventual reads/writes from/to the stream failing with a + // `io.ErrClosedPipe`, but not checking the channel's context here will mean + // we can create a stream _while_ the channel is closing/closed, which can + // result in data races and undefined behavior. The caller to this function + // is holding the channel mutex that's also acquired in the "close" path that + // will cancel `channel.ctx`. + if channel.ctx.Err() != nil { + return nil, ErrDisconnected + } + ctx, cancel := utils.MergeContext(channel.ctx, ctx) bs := newWebRTCBaseStream(ctx, cancel, stream, onDone, logger) s := &webrtcClientStream{ @@ -65,7 +84,7 @@ func newWebRTCClientStream( } } }) - return s + return s, nil } // SendMsg is generally called by generated code. On error, SendMsg aborts diff --git a/rpc/wrtc_client_test.go b/rpc/wrtc_client_test.go index 4325adce..e83772f0 100644 --- a/rpc/wrtc_client_test.go +++ b/rpc/wrtc_client_test.go @@ -618,3 +618,66 @@ func TestWebRTCClientSubsequentStreams(t *testing.T) { err = <-errChan test.That(t, err, test.ShouldBeNil) } + +func TestErrDisconnected(t *testing.T) { + logger := golog.NewTestLogger(t) + serverOpts := []ServerOption{ + WithWebRTCServerOptions(WebRTCServerOptions{ + Enable: true, + }), + WithUnauthenticated(), + } + rpcServer, err := NewServer( + logger, + serverOpts..., + ) + test.That(t, err, test.ShouldBeNil) + + es := echoserver.Server{} + err = rpcServer.RegisterServiceServer( + context.Background(), + &echopb.EchoService_ServiceDesc, + &es, + echopb.RegisterEchoServiceHandlerFromEndpoint, + ) + test.That(t, err, test.ShouldBeNil) + + listener, err := net.Listen("tcp", "localhost:0") + test.That(t, err, test.ShouldBeNil) + + errChan := make(chan error) + go func() { + errChan <- rpcServer.Serve(listener) + }() + + rtcConn, err := DialWebRTC( + context.Background(), + listener.Addr().String(), + rpcServer.InstanceNames()[0], + logger, + WithDialDebug(), + WithInsecure(), + ) + test.That(t, err, test.ShouldBeNil) + + client := echopb.NewEchoServiceClient(rtcConn) + + msg := "these-are-not-the-droids-you're-looking-for" + echoResp, err := client.Echo(context.Background(), &echopb.EchoRequest{Message: msg}) + test.That(t, err, test.ShouldBeNil) + test.That(t, echoResp.GetMessage(), test.ShouldEqual, msg) + + // Close underlying ClientConn and expect that further usages of the gRPC + // client will result in `ErrDisconnected`. + test.That(t, rtcConn.Close(), test.ShouldBeNil) + for i := 0; i < 2; i++ { + echoResp, err = client.Echo(context.Background(), &echopb.EchoRequest{Message: msg}) + test.That(t, echoResp, test.ShouldBeNil) + test.That(t, err, test.ShouldNotBeNil) + test.That(t, err, test.ShouldBeError, ErrDisconnected) + } + + test.That(t, rpcServer.Stop(), test.ShouldBeNil) + err = <-errChan + test.That(t, err, test.ShouldBeNil) +}