Skip to content

Commit

Permalink
RSDK-8779: Remove additional cases where signaling server errors over…
Browse files Browse the repository at this point in the history
…ride successful PeerConnection attempts. (#388)
  • Loading branch information
dgottlieb authored Nov 19, 2024
1 parent 7633571 commit 96ba3e3
Show file tree
Hide file tree
Showing 7 changed files with 162 additions and 160 deletions.
4 changes: 4 additions & 0 deletions logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ type ILogger interface {
type ZapCompatibleLogger interface {
Desugar() *zap.Logger

// Not defined: Named(name string) *zap.SugaredLogger
//
// Use `Sublogger(logger, "name")` instead of calling `Named` directly.

Debug(args ...interface{})
Debugf(template string, args ...interface{})
Debugw(msg string, keysAndValues ...interface{})
Expand Down
29 changes: 14 additions & 15 deletions rpc/dial.go
Original file line number Diff line number Diff line change
Expand Up @@ -138,11 +138,11 @@ func dial(
if !dOpts.mdnsOptions.Disable && tryLocal && isJustDomain {
wg.Add(1)
go func(dOpts dialOptions) {
mdnsLogger := utils.Sublogger(logger, "mdns")
defer wg.Done()
if dOpts.debug {
logger.Debugw("trying mDNS", "address", address)
}
conn, cached, err := dialMulticastDNS(ctxParallel, address, logger, dOpts)

mdnsLogger.Debugw("trying mDNS", "address", address)
conn, cached, err := dialMulticastDNS(ctxParallel, address, mdnsLogger, dOpts)
if err != nil {
dialCh <- dialResult{err: err}
} else {
Expand All @@ -152,6 +152,7 @@ func dial(
}

if !dOpts.webrtcOpts.Disable {
webrtcLogger := utils.Sublogger(logger, "webrtc")
wg.Add(1)
go func(dOpts dialOptions) {
defer wg.Done()
Expand All @@ -168,7 +169,7 @@ func dial(
// that the direct dialing address might be different from the
// signaling address, but it seems better to fail fast and let the
// client fix any configuration issues.
logger.Errorw("failed to parse signaling address", "address", signalingAddress, "error", err)
webrtcLogger.Errorw("failed to parse signaling address", "address", signalingAddress, "error", err)
dialCh <- dialResult{err: err, skipDirect: true}
return
}
Expand All @@ -178,19 +179,17 @@ func dial(
// This path is also called by an mdns direct connection and ignores that case.
// This will skip all Authenticate/AuthenticateTo calls for the signaler.
if !dOpts.usingMDNS && dOpts.authMaterial == "" && dOpts.webrtcOpts.SignalingExternalAuthAuthMaterial != "" {
logger.Debug("using signaling's external auth as auth material")
webrtcLogger.Debug("using signaling's external auth as auth material")
dOpts.authMaterial = dOpts.webrtcOpts.SignalingExternalAuthAuthMaterial
dOpts.creds = Credentials{}
}
}

if dOpts.debug {
logger.Debugw(
"trying WebRTC",
"signaling_server", dOpts.webrtcOpts.SignalingServerAddress,
"host", originalAddress,
)
}
webrtcLogger.Debugw(
"trying WebRTC",
"signaling_server", dOpts.webrtcOpts.SignalingServerAddress,
"host", originalAddress,
)

conn, cached, err := dialFunc(
ctxParallel,
Expand All @@ -203,14 +202,14 @@ func dial(
dOpts.webrtcOpts.SignalingServerAddress,
originalAddress,
dOpts,
logger,
webrtcLogger,
)
})

switch {
case err == nil:
if dOpts.debug {
logger.Debugw("connected via WebRTC",
webrtcLogger.Debugw("connected via WebRTC",
"address", address,
"cached", cached,
"using mDNS", dOpts.usingMDNS,
Expand Down
26 changes: 18 additions & 8 deletions rpc/dial_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1150,10 +1150,11 @@ func TestDialMulticastDNS(t *testing.T) {

t.Run("fix mdns instance name", func(t *testing.T) {
rpcServer, err := NewServer(
logger,
logger.Named("server"),
WithUnauthenticated(),
WithInstanceNames("this.is.a.test.cloud"),
)
logger = logger.Named("fixmdns")
test.That(t, err, test.ShouldBeNil)
test.That(t, rpcServer.Start(), test.ShouldBeNil)
test.That(t, rpcServer.InstanceNames(), test.ShouldHaveLength, 1)
Expand Down Expand Up @@ -1182,7 +1183,7 @@ func TestDialMulticastDNS(t *testing.T) {

t.Run("unauthenticated", func(t *testing.T) {
rpcServer, err := NewServer(
logger,
logger.Named("server"),
WithUnauthenticated(),
)
test.That(t, err, test.ShouldBeNil)
Expand All @@ -1192,58 +1193,66 @@ func TestDialMulticastDNS(t *testing.T) {
conn, err := Dial(
context.Background(),
rpcServer.InstanceNames()[0],
logger,
logger.Named("unauthenticated1"),
WithInsecure(),
WithDialDebug(),
)
test.That(t, err, test.ShouldBeNil)
logger.Info("Dial1 success")
// There's no webrtc. The connection must not be backed by a PeerConn.
test.That(t, conn.PeerConn(), test.ShouldBeNil)
test.That(t, conn.Close(), test.ShouldBeNil)

ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
_, err = Dial(
ctx,
rpcServer.InstanceNames()[0],
logger,
logger.Named("unauthenticated2"),
WithInsecure(),
WithDialDebug(),
WithDialMulticastDNSOptions(DialMulticastDNSOptions{Disable: true}),
)
test.That(t, err, test.ShouldResemble, context.DeadlineExceeded)
logger.Info("Dial2 'success'")

test.That(t, rpcServer.Stop(), test.ShouldBeNil)

logger.Info("Dial2 server stopped")
rpcServer, err = NewServer(
logger,
logger.Named("server"),
WithUnauthenticated(),
WithWebRTCServerOptions(WebRTCServerOptions{Enable: true}),
)
test.That(t, err, test.ShouldBeNil)
test.That(t, rpcServer.Start(), test.ShouldBeNil)
logger.Info("Dial3 server started")

conn, err = Dial(
context.Background(),
rpcServer.InstanceNames()[0],
logger,
logger.Named("unauthenticated3"),
WithInsecure(),
WithDialDebug(),
WithDisableDirectGRPC(),
)
test.That(t, err, test.ShouldBeNil)
test.That(t, conn.PeerConn(), test.ShouldNotBeNil)
test.That(t, conn.Close(), test.ShouldBeNil)

test.That(t, rpcServer.Stop(), test.ShouldBeNil)
})

t.Run("authenticated", func(t *testing.T) {
rpcServer, err := NewServer(
logger,
logger.Named("server"),
WithAuthHandler("fake", AuthHandlerFunc(func(ctx context.Context, entity, payload string) (map[string]string, error) {
return map[string]string{}, nil
})),
)
test.That(t, err, test.ShouldBeNil)

logger = logger.Named("authenticated")
err = rpcServer.RegisterServiceServer(
context.Background(),
&pb.EchoService_ServiceDesc,
Expand Down Expand Up @@ -1309,7 +1318,7 @@ func TestDialMulticastDNS(t *testing.T) {
t.Run("authenticated with names", func(t *testing.T) {
names := []string{primitive.NewObjectID().Hex(), primitive.NewObjectID().Hex()}
rpcServer, err := NewServer(
logger,
logger.Named("server"),
WithAuthHandler("fake", AuthHandlerFunc(func(ctx context.Context, entity, payload string) (map[string]string, error) {
return map[string]string{}, nil
})),
Expand All @@ -1318,6 +1327,7 @@ func TestDialMulticastDNS(t *testing.T) {
test.That(t, err, test.ShouldBeNil)
test.That(t, rpcServer.InstanceNames(), test.ShouldResemble, names)

logger = logger.Named("authwithnames")
err = rpcServer.RegisterServiceServer(
context.Background(),
&pb.EchoService_ServiceDesc,
Expand Down
Loading

0 comments on commit 96ba3e3

Please sign in to comment.