Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RSDK-9316 Disallow client stream creation when channel has been closed #394

Merged
merged 4 commits into from
Nov 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions rpc/wrtc_client_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,23 +265,25 @@ func (ch *webrtcClientChannel) newStream(
) (*webrtcClientStream, error) {
id := stream.GetId()
ch.mu.Lock()
defer ch.mu.Unlock()
activeStream, ok := ch.streams[id]
if !ok {
if len(ch.streams) == WebRTCMaxStreamCount {
ch.mu.Unlock()
return nil, errWebRTCMaxStreams
}
clientStream := newWebRTCClientStream(
clientStream, err := newWebRTCClientStream(
ctx,
ch,
stream,
ch.removeStreamByID,
utils.AddFieldsToLogger(ch.webrtcBaseChannel.logger, "id", id),
)
if err != nil {
return nil, err
}
activeStream = activeWebRTCClientStream{clientStream}
ch.streams[id] = activeStream
}
ch.mu.Unlock()
return activeStream.cs, nil
}

Expand Down
25 changes: 22 additions & 3 deletions rpc/wrtc_client_stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,12 @@ import (
webrtcpb "go.viam.com/utils/proto/rpc/webrtc/v1"
)

var _ = grpc.ClientStream(&webrtcClientStream{})
var (
_ = grpc.ClientStream(&webrtcClientStream{})
// ErrDisconnected indicates that the channel underlying the client stream
// has been closed, and the client is therefore disconnected.
ErrDisconnected = errors.New("client disconnected; underlying channel closed")
)

// A webrtcClientStream is the high level gRPC streaming interface used for both
// unary and streaming call requests.
Expand Down Expand Up @@ -45,7 +50,21 @@ func newWebRTCClientStream(
stream *webrtcpb.Stream,
onDone func(id uint64),
logger utils.ZapCompatibleLogger,
) *webrtcClientStream {
) (*webrtcClientStream, error) {
// Assume that cancelation of the client channel's context means the peer
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a good start to a comment. But I feel like there's a bit of "what's the consequence of removing this"? This just looks like an optimization to inform the caller sooner that their invocation/stream isn't going to work.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added some elaboration there; let me know what you think. Thanks for asking for more, will be helpful for posterity.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Much better!

I'd also note (because it's very not obvious in this case) the caller is holding the channel mutex that's also acquired in the "close" code path that's canceling this context.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For sure; included (almost) exactly what you said there.

// connection and base channel have both closed, and the client is
// disconnected.
//
// We could rely on eventual reads/writes from/to the stream failing with a
// `io.ErrClosedPipe`, but not checking the channel's context here will mean
// we can create a stream _while_ the channel is closing/closed, which can
// result in data races and undefined behavior. The caller to this function
// is holding the channel mutex that's also acquired in the "close" path that
// will cancel `channel.ctx`.
if channel.ctx.Err() != nil {
return nil, ErrDisconnected
}

ctx, cancel := utils.MergeContext(channel.ctx, ctx)
bs := newWebRTCBaseStream(ctx, cancel, stream, onDone, logger)
s := &webrtcClientStream{
Expand All @@ -65,7 +84,7 @@ func newWebRTCClientStream(
}
}
})
return s
return s, nil
}

// SendMsg is generally called by generated code. On error, SendMsg aborts
Expand Down
63 changes: 63 additions & 0 deletions rpc/wrtc_client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -618,3 +618,66 @@ func TestWebRTCClientSubsequentStreams(t *testing.T) {
err = <-errChan
test.That(t, err, test.ShouldBeNil)
}

func TestErrDisconnected(t *testing.T) {
logger := golog.NewTestLogger(t)
serverOpts := []ServerOption{
WithWebRTCServerOptions(WebRTCServerOptions{
Enable: true,
}),
WithUnauthenticated(),
}
rpcServer, err := NewServer(
logger,
serverOpts...,
)
test.That(t, err, test.ShouldBeNil)

es := echoserver.Server{}
err = rpcServer.RegisterServiceServer(
context.Background(),
&echopb.EchoService_ServiceDesc,
&es,
echopb.RegisterEchoServiceHandlerFromEndpoint,
)
test.That(t, err, test.ShouldBeNil)

listener, err := net.Listen("tcp", "localhost:0")
test.That(t, err, test.ShouldBeNil)

errChan := make(chan error)
go func() {
errChan <- rpcServer.Serve(listener)
}()

rtcConn, err := DialWebRTC(
context.Background(),
listener.Addr().String(),
rpcServer.InstanceNames()[0],
logger,
WithDialDebug(),
WithInsecure(),
)
test.That(t, err, test.ShouldBeNil)

client := echopb.NewEchoServiceClient(rtcConn)

msg := "these-are-not-the-droids-you're-looking-for"
echoResp, err := client.Echo(context.Background(), &echopb.EchoRequest{Message: msg})
test.That(t, err, test.ShouldBeNil)
test.That(t, echoResp.GetMessage(), test.ShouldEqual, msg)

// Close underlying ClientConn and expect that further usages of the gRPC
// client will result in `ErrDisconnected`.
test.That(t, rtcConn.Close(), test.ShouldBeNil)
for i := 0; i < 2; i++ {
echoResp, err = client.Echo(context.Background(), &echopb.EchoRequest{Message: msg})
test.That(t, echoResp, test.ShouldBeNil)
test.That(t, err, test.ShouldNotBeNil)
test.That(t, err, test.ShouldBeError, ErrDisconnected)
}

test.That(t, rpcServer.Stop(), test.ShouldBeNil)
err = <-errChan
test.That(t, err, test.ShouldBeNil)
}
Loading