Skip to content

Commit

Permalink
RSDK-9316 Disallow client stream creation when channel has been closed (
Browse files Browse the repository at this point in the history
  • Loading branch information
benjirewis authored Nov 20, 2024
1 parent 86ce518 commit 96c17f9
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 6 deletions.
8 changes: 5 additions & 3 deletions rpc/wrtc_client_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,23 +265,25 @@ func (ch *webrtcClientChannel) newStream(
) (*webrtcClientStream, error) {
id := stream.GetId()
ch.mu.Lock()
defer ch.mu.Unlock()
activeStream, ok := ch.streams[id]
if !ok {
if len(ch.streams) == WebRTCMaxStreamCount {
ch.mu.Unlock()
return nil, errWebRTCMaxStreams
}
clientStream := newWebRTCClientStream(
clientStream, err := newWebRTCClientStream(
ctx,
ch,
stream,
ch.removeStreamByID,
utils.AddFieldsToLogger(ch.webrtcBaseChannel.logger, "id", id),
)
if err != nil {
return nil, err
}
activeStream = activeWebRTCClientStream{clientStream}
ch.streams[id] = activeStream
}
ch.mu.Unlock()
return activeStream.cs, nil
}

Expand Down
25 changes: 22 additions & 3 deletions rpc/wrtc_client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ import (
webrtcpb "go.viam.com/utils/proto/rpc/webrtc/v1"
)

var _ = grpc.ClientStream(&webrtcClientStream{})
var (
_ = grpc.ClientStream(&webrtcClientStream{})
// ErrDisconnected indicates that the channel underlying the client stream
// has been closed, and the client is therefore disconnected.
ErrDisconnected = errors.New("client disconnected; underlying channel closed")
)

// A webrtcClientStream is the high level gRPC streaming interface used for both
// unary and streaming call requests.
Expand Down Expand Up @@ -45,7 +50,21 @@ func newWebRTCClientStream(
stream *webrtcpb.Stream,
onDone func(id uint64),
logger utils.ZapCompatibleLogger,
) *webrtcClientStream {
) (*webrtcClientStream, error) {
// Assume that cancelation of the client channel's context means the peer
// connection and base channel have both closed, and the client is
// disconnected.
//
// We could rely on eventual reads/writes from/to the stream failing with a
// `io.ErrClosedPipe`, but not checking the channel's context here will mean
// we can create a stream _while_ the channel is closing/closed, which can
// result in data races and undefined behavior. The caller to this function
// is holding the channel mutex that's also acquired in the "close" path that
// will cancel `channel.ctx`.
if channel.ctx.Err() != nil {
return nil, ErrDisconnected
}

ctx, cancel := utils.MergeContext(channel.ctx, ctx)
bs := newWebRTCBaseStream(ctx, cancel, stream, onDone, logger)
s := &webrtcClientStream{
Expand All @@ -65,7 +84,7 @@ func newWebRTCClientStream(
}
}
})
return s
return s, nil
}

// SendMsg is generally called by generated code. On error, SendMsg aborts
Expand Down
63 changes: 63 additions & 0 deletions rpc/wrtc_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,3 +618,66 @@ func TestWebRTCClientSubsequentStreams(t *testing.T) {
err = <-errChan
test.That(t, err, test.ShouldBeNil)
}

func TestErrDisconnected(t *testing.T) {
logger := golog.NewTestLogger(t)
serverOpts := []ServerOption{
WithWebRTCServerOptions(WebRTCServerOptions{
Enable: true,
}),
WithUnauthenticated(),
}
rpcServer, err := NewServer(
logger,
serverOpts...,
)
test.That(t, err, test.ShouldBeNil)

es := echoserver.Server{}
err = rpcServer.RegisterServiceServer(
context.Background(),
&echopb.EchoService_ServiceDesc,
&es,
echopb.RegisterEchoServiceHandlerFromEndpoint,
)
test.That(t, err, test.ShouldBeNil)

listener, err := net.Listen("tcp", "localhost:0")
test.That(t, err, test.ShouldBeNil)

errChan := make(chan error)
go func() {
errChan <- rpcServer.Serve(listener)
}()

rtcConn, err := DialWebRTC(
context.Background(),
listener.Addr().String(),
rpcServer.InstanceNames()[0],
logger,
WithDialDebug(),
WithInsecure(),
)
test.That(t, err, test.ShouldBeNil)

client := echopb.NewEchoServiceClient(rtcConn)

msg := "these-are-not-the-droids-you're-looking-for"
echoResp, err := client.Echo(context.Background(), &echopb.EchoRequest{Message: msg})
test.That(t, err, test.ShouldBeNil)
test.That(t, echoResp.GetMessage(), test.ShouldEqual, msg)

// Close underlying ClientConn and expect that further usages of the gRPC
// client will result in `ErrDisconnected`.
test.That(t, rtcConn.Close(), test.ShouldBeNil)
for i := 0; i < 2; i++ {
echoResp, err = client.Echo(context.Background(), &echopb.EchoRequest{Message: msg})
test.That(t, echoResp, test.ShouldBeNil)
test.That(t, err, test.ShouldNotBeNil)
test.That(t, err, test.ShouldBeError, ErrDisconnected)
}

test.That(t, rpcServer.Stop(), test.ShouldBeNil)
err = <-errChan
test.That(t, err, test.ShouldBeNil)
}

0 comments on commit 96c17f9

Please sign in to comment.