diff --git a/internal/rpc/client_test.go b/internal/rpc/client_test.go index 79cfebed0..209133f96 100644 --- a/internal/rpc/client_test.go +++ b/internal/rpc/client_test.go @@ -266,7 +266,7 @@ func TestProxySockets(t *testing.T) { LoginService: &mockLoginService{}, } err := rpc.ProxySockets(ctx, proxyHelpers) - c.Check(err, qt.ErrorMatches, "error reading from (client|controller).*") + c.Check(err, qt.IsNil) errChan <- err return err }) @@ -298,6 +298,68 @@ func TestProxySockets(t *testing.T) { <-errChan // Ensure go routines are cleaned up } +func TestProxySocketsControllerConnectionFails(t *testing.T) { + c := qt.New(t) + ctx := context.Background() + + srvController := newServer(echo) + + var connController *websocket.Conn + errChan := make(chan error) + srvJIMM := newServer(func(connClient *websocket.Conn) error { + testTokenGen := testTokenGenerator{} + f := func(context.Context) (rpc.WebsocketConnectionWithMetadata, error) { + var err error + connController, err = srvController.dialer.DialWebsocket(ctx, srvController.URL) + c.Check(err, qt.IsNil) + return rpc.WebsocketConnectionWithMetadata{ + Conn: connController, + ModelName: "TestName", + }, nil + } + auditLogger := func(ale *dbmodel.AuditLogEntry) {} + proxyHelpers := rpc.ProxyHelpers{ + ConnClient: connClient, + TokenGen: &testTokenGen, + ConnectController: f, + AuditLog: auditLogger, + LoginService: &mockLoginService{}, + } + err := rpc.ProxySockets(ctx, proxyHelpers) + c.Check(err, qt.IsNil) + errChan <- err + return err + }) + + defer srvController.Close() + defer srvJIMM.Close() + ws, err := srvJIMM.dialer.DialWebsocket(ctx, srvJIMM.URL) + c.Assert(err, qt.IsNil) + defer ws.Close() + + p := json.RawMessage(`{"Key":"TestVal"}`) + msg := rpc.Message{RequestID: 1, Type: "TestType", Request: "TestReq", Params: p} + err = ws.WriteJSON(&msg) + c.Assert(err, qt.IsNil) + resp := rpc.Message{} + receiveChan := make(chan error) + go func() { + receiveChan <- ws.ReadJSON(&resp) + }() + select { + case err := <-receiveChan: + c.Assert(err, qt.IsNil) + case <-time.After(5 * time.Second): + c.Logf("took too long to read response") + c.FailNow() + } + c.Assert(resp.Response, qt.DeepEquals, msg.Params) + + // Now close the connection to the controller and ensure the model proxy is cleaned up. + connController.Close() + <-errChan // Ensure go routines are cleaned up +} + func TestCancelProxySockets(t *testing.T) { c := qt.New(t) @@ -368,7 +430,7 @@ func TestProxySocketsAuditLogs(t *testing.T) { LoginService: &mockLoginService{}, } err := rpc.ProxySockets(ctx, proxyHelpers) - c.Check(err, qt.ErrorMatches, `error reading from (client|controller).*`) + c.Check(err, qt.IsNil) errChan <- err return err }) diff --git a/internal/rpc/proxy.go b/internal/rpc/proxy.go index 60d7db1ce..79d5a09bb 100644 --- a/internal/rpc/proxy.go +++ b/internal/rpc/proxy.go @@ -9,6 +9,7 @@ import ( "sync" "time" + "github.com/gorilla/websocket" "github.com/juju/juju/rpc/params" "github.com/juju/names/v5" "github.com/juju/zaputil/zapctx" @@ -121,22 +122,18 @@ func ProxySockets(ctx context.Context, helpers ProxyHelpers) error { }() var err error select { - // No cleanup is needed on error, when the client closes the connection - // all go routines will proceed to error and exit. case err = <-errChan: - zapctx.Debug(ctx, "Proxy error", zap.Error(err)) + if err != nil { + zapctx.Debug(ctx, "Proxy error", zap.Error(err)) + } case <-ctx.Done(): err = errors.E(op, "Context cancelled") zapctx.Debug(ctx, "Context cancelled") - helpers.ConnClient.Close() - clProxy.mu.Lock() - clProxy.closed = true - // TODO(Kian): Test removing close on dst below. The client connection should do it. - if clProxy.dst != nil { - clProxy.dst.conn.Close() - } - clProxy.mu.Unlock() } + // Close the client connection to ensure everything is cleaned up. + // Normally the client would do this but we also do it here in case the + // connection to the controller fails and we want to trigger cleanup. + helpers.ConnClient.Close() clProxy.wg.Wait() return err } @@ -316,16 +313,22 @@ func (p *modelProxy) auditLogMessage(msg *message, isResponse bool) error { return nil } +func unexpectedReadError(err error) bool { + closeError := websocket.IsUnexpectedCloseError(err, + websocket.CloseNormalClosure, + websocket.CloseNoStatusReceived, + websocket.CloseAbnormalClosure) + _, unmarshalError := err.(*json.InvalidUnmarshalError) + return closeError || unmarshalError +} + // clientProxy proxies messages from client->controller. type clientProxy struct { modelProxy wg sync.WaitGroup errChan chan error createControllerConn func(context.Context) (WebsocketConnectionWithMetadata, error) - // mu synchronises changes to closed and modelproxy.dst, dst is is only created - // at some unspecified point in the future after a client request. - mu sync.Mutex - closed bool + connectController sync.Once } // start begins the client->controller proxier. @@ -339,8 +342,11 @@ func (p *clientProxy) start(ctx context.Context) error { zapctx.Debug(ctx, "Reading on client connection") msg := new(message) if err := p.src.readJson(&msg); err != nil { - // Error reading on the socket implies it is closed, simply return. - return fmt.Errorf("error reading from client: %w", err) + if unexpectedReadError(err) { + zapctx.Error(ctx, "unexpected client read error", zap.Error(err)) + return err + } + return nil } zapctx.Debug(ctx, "Read message from client", zap.Any("message", msg)) err := p.makeControllerConnection(ctx) @@ -387,43 +393,35 @@ func (p *clientProxy) start(ctx context.Context) error { // proxying requests from the controller to the client. func (p *clientProxy) makeControllerConnection(ctx context.Context) error { const op = errors.Op("rpc.makeControllerConnection") - p.mu.Lock() - defer p.mu.Unlock() - if p.dst != nil { - return nil - } - // Checking closed ensures we don't have a race condition with a cancelled context. - if p.closed { - err := errors.E(op, "Client connection closed while starting controller connection") - return err - } - connWithMetadata, err := p.createControllerConn(ctx) - if err != nil { - return err - } - - p.msgs.controllerUUID = connWithMetadata.ControllerUUID + var createConnErr error + // Create the controller connection once. + p.connectController.Do(func() { + connWithMetadata, err := p.createControllerConn(ctx) + if err != nil { + createConnErr = errors.E(op, err) + } - p.modelName = connWithMetadata.ModelName - p.dst = &writeLockConn{conn: connWithMetadata.Conn} - controllerToClient := controllerProxy{ - modelProxy: modelProxy{ - src: p.dst, - dst: p.src, - msgs: p.msgs, - auditLog: p.auditLog, - tokenGen: p.tokenGen, - modelName: p.modelName, - conversationId: p.conversationId, - }, - } - p.wg.Add(1) - go func() { - defer p.wg.Done() - p.errChan <- controllerToClient.start(ctx) - }() - zapctx.Debug(ctx, "Successfully made controller connection") - return nil + p.msgs.controllerUUID = connWithMetadata.ControllerUUID + p.modelName = connWithMetadata.ModelName + p.dst = &writeLockConn{conn: connWithMetadata.Conn} + controllerToClient := controllerProxy{ + modelProxy: modelProxy{ + src: p.dst, + dst: p.src, + msgs: p.msgs, + auditLog: p.auditLog, + tokenGen: p.tokenGen, + modelName: p.modelName, + conversationId: p.conversationId, + }, + } + p.wg.Add(1) + go func() { + defer p.wg.Done() + p.errChan <- controllerToClient.start(ctx) + }() + }) + return createConnErr } // controllerProxy proxies messages from controller->client with the caveat that @@ -438,8 +436,11 @@ func (p *controllerProxy) start(ctx context.Context) error { zapctx.Debug(ctx, "Reading on controller connection") msg := new(message) if err := p.src.readJson(msg); err != nil { - // Error reading on the socket implies it is closed, simply return. - return fmt.Errorf("error reading from controller: %w", err) + if unexpectedReadError(err) { + zapctx.Error(ctx, "unexpected controller read error", zap.Error(err)) + return err + } + return nil } zapctx.Debug(ctx, "Received message from controller", zap.Any("Message", msg)) permissionsRequired, err := checkPermissionsRequired(ctx, msg)