From 896441b1f6063b3c9307339f256721544bf0164d Mon Sep 17 00:00:00 2001 From: Robin Date: Wed, 3 Jan 2024 16:44:20 +0100 Subject: [PATCH] fix: Switch server and some data races --- pkg/edition/java/netmc/connection.go | 38 ++++-- pkg/edition/java/netmc/writer.go | 1 + .../java/proto/packet/clientsettings.go | 110 +++++------------- pkg/edition/java/proto/state/states.go | 3 +- pkg/edition/java/proto/util/preader.go | 24 ++++ pkg/edition/java/proxy/bungee.go | 2 +- pkg/edition/java/proxy/player.go | 15 ++- pkg/edition/java/proxy/player/settings.go | 2 +- .../java/proxy/session_backend_config.go | 52 +++++---- .../java/proxy/session_backend_login.go | 29 +++-- .../java/proxy/session_backend_play.go | 17 +-- .../java/proxy/session_backend_transition.go | 54 ++++++--- pkg/edition/java/proxy/session_client_auth.go | 23 ++-- .../java/proxy/session_client_config.go | 49 ++++---- pkg/edition/java/proxy/session_client_play.go | 89 ++++++++------ pkg/internal/future/future.go | 60 ++++++++++ pkg/internal/future/future_test.go | 31 +++++ pkg/internal/oncetrue/oncetrue.go | 42 ------- pkg/internal/tablist/tablist.go | 7 +- pkg/internal/tablist/tablist_keyed.go | 2 +- 20 files changed, 381 insertions(+), 269 deletions(-) create mode 100644 pkg/internal/future/future.go create mode 100644 pkg/internal/future/future_test.go delete mode 100644 pkg/internal/oncetrue/oncetrue.go diff --git a/pkg/edition/java/netmc/connection.go b/pkg/edition/java/netmc/connection.go index 41449e12..9eda53af 100644 --- a/pkg/edition/java/netmc/connection.go +++ b/pkg/edition/java/netmc/connection.go @@ -71,6 +71,8 @@ type MinecraftConn interface { // TODO convert to exported struct as this interf PacketWriter Reader() Reader // Only use if you know what you are doing! + Writer() Writer + EnablePlayPacketQueue() } // Closed returns true if the connection is closed. @@ -247,6 +249,7 @@ func (c *minecraftConn) startReadLoop() { } func (c *minecraftConn) Reader() Reader { return c.rd } +func (c *minecraftConn) Writer() Writer { return c.wr } func (c *minecraftConn) SetAutoReading(enabled bool) { c.log.V(1).Info("update auto reading", "enabled", enabled) @@ -293,7 +296,7 @@ func (c *minecraftConn) bufferNoQueue(packet proto.Packet) error { return c.bufferPacket(packet, false) } -func (c *minecraftConn) bufferPacket(packet proto.Packet, queue bool) (err error) { +func (c *minecraftConn) bufferPacket(packet proto.Packet, canQueue bool) (err error) { if Closed(c) { return ErrClosedConn } @@ -302,10 +305,15 @@ func (c *minecraftConn) bufferPacket(packet proto.Packet, queue bool) (err error c.closeOnWriteErr(err, "bufferPacket", fmt.Sprintf("%T", packet)) } }() - if queue && c.playPacketQueue.Queue(packet) { - // Packet was queued, don't write it now - c.log.V(1).Info("queued packet", "packet", fmt.Sprintf("%T", packet)) - return nil + if canQueue { + c.mu.Lock() + playPacketQueue := c.playPacketQueue + c.mu.Unlock() + if playPacketQueue.Queue(packet) { + // Packet was queued, don't write it now + c.log.V(1).Info("queued packet", "packet", fmt.Sprintf("%T", packet)) + return nil + } } _, err = c.wr.WritePacket(packet) return err @@ -469,14 +477,26 @@ func (c *minecraftConn) SetState(s *state.Registry) { } } +func (c *minecraftConn) EnablePlayPacketQueue() { + if c.mu.TryLock() { + defer c.mu.Unlock() + } + c.activatePlayPacketQueue() +} + +// calling function must hold c.mu +func (c *minecraftConn) activatePlayPacketQueue() { + // Activate the play packet queue if not already + if c.playPacketQueue == nil { + c.playPacketQueue = queue.NewPlayPacketQueue(c.protocol, c.Writer().Direction()) + } +} + // ensurePlayPacketQueue ensures the play packet queue is activated or deactivated // when the connection enters or leaves the play state. See PlayPacketQueue struct for more info. func (c *minecraftConn) ensurePlayPacketQueue(newState state.State) { if newState == state.ConfigState { // state exists since 1.20.2+ - // Activate the play packet queue if not already - if c.playPacketQueue == nil { - c.playPacketQueue = queue.NewPlayPacketQueue(c.protocol, c.direction) - } + c.activatePlayPacketQueue() return } diff --git a/pkg/edition/java/netmc/writer.go b/pkg/edition/java/netmc/writer.go index c6660e8a..09b78898 100644 --- a/pkg/edition/java/netmc/writer.go +++ b/pkg/edition/java/netmc/writer.go @@ -22,6 +22,7 @@ type Writer interface { Flush() (err error) StateChanger + Direction() proto.Direction } // NewWriter returns a new packet writer. diff --git a/pkg/edition/java/proto/packet/clientsettings.go b/pkg/edition/java/proto/packet/clientsettings.go index b43558d7..226ff47b 100644 --- a/pkg/edition/java/proto/packet/clientsettings.go +++ b/pkg/edition/java/proto/packet/clientsettings.go @@ -8,108 +8,56 @@ import ( ) type ClientSettings struct { - Locale string // may be empty - ViewDistance byte - ChatVisibility int - ChatColors bool - Difficulty bool // 1.7 Protocol - SkinParts byte - MainHand int - TextFiltering bool // 1.17+ - ClientListing bool // 1.18+, overwrites server-list "anonymous" mode + Locale string // may be empty + ViewDistance byte + ChatVisibility int + ChatColors bool + Difficulty byte // 1.7 Protocol + SkinParts byte + MainHand int + ChatFilteringEnabled bool // 1.17+ + ClientListingAllowed bool // 1.18+, overwrites server-list "anonymous" mode } func (s *ClientSettings) Encode(c *proto.PacketContext, wr io.Writer) error { - err := util.WriteString(wr, s.Locale) - if err != nil { - return err - } - err = util.WriteUint8(wr, s.ViewDistance) - if err != nil { - return err - } - err = util.WriteVarInt(wr, s.ChatVisibility) - if err != nil { - return err - } - err = util.WriteBool(wr, s.ChatColors) - if err != nil { - return err - } + w := util.PanicWriter(wr) + w.String(s.Locale) + w.Byte(s.ViewDistance) + w.VarInt(s.ChatVisibility) + w.Bool(s.ChatColors) if c.Protocol.LowerEqual(version.Minecraft_1_7_6) { - err = util.WriteBool(wr, s.Difficulty) - if err != nil { - return err - } - } - err = util.WriteUint8(wr, s.SkinParts) - if err != nil { - return err + w.Byte(s.Difficulty) } + w.Byte(s.SkinParts) if c.Protocol.GreaterEqual(version.Minecraft_1_9) { - err = util.WriteVarInt(wr, s.MainHand) - if err != nil { - return err - } + w.VarInt(s.MainHand) if c.Protocol.GreaterEqual(version.Minecraft_1_17) { - err = util.WriteBool(wr, s.TextFiltering) - if err != nil { - return err - } + w.Bool(s.ChatFilteringEnabled) } if c.Protocol.GreaterEqual(version.Minecraft_1_18) { - err = util.WriteBool(wr, s.ClientListing) - if err != nil { - return err - } + w.Bool(s.ClientListingAllowed) } } return nil } func (s *ClientSettings) Decode(c *proto.PacketContext, rd io.Reader) (err error) { - s.Locale, err = util.ReadString(rd) - if err != nil { - return err - } - s.ViewDistance, err = util.ReadUint8(rd) - if err != nil { - return err - } - s.ChatVisibility, err = util.ReadVarInt(rd) - if err != nil { - return err - } - s.ChatColors, err = util.ReadBool(rd) - if err != nil { - return err - } + r := util.PanicReader(rd) + r.StringMax(&s.Locale, 16) + r.Byte(&s.ViewDistance) + r.VarInt(&s.ChatVisibility) + r.Bool(&s.ChatColors) if c.Protocol.LowerEqual(version.Minecraft_1_7_6) { - s.Difficulty, err = util.ReadBool(rd) - if err != nil { - return err - } - } - s.SkinParts, err = util.ReadByte(rd) - if err != nil { - return err + r.Byte(&s.Difficulty) } + r.Byte(&s.SkinParts) // Go bytes are unsigned already if c.Protocol.GreaterEqual(version.Minecraft_1_9) { - s.MainHand, err = util.ReadVarInt(rd) - if err != nil { - return err - } + r.VarInt(&s.MainHand) if c.Protocol.GreaterEqual(version.Minecraft_1_17) { - s.TextFiltering, err = util.ReadBool(rd) - if err != nil { - return err - } + r.Bool(&s.ChatFilteringEnabled) } if c.Protocol.GreaterEqual(version.Minecraft_1_18) { - s.ClientListing, err = util.ReadBool(rd) - if err != nil { - return err - } + r.Bool(&s.ClientListingAllowed) } } return nil diff --git a/pkg/edition/java/proto/state/states.go b/pkg/edition/java/proto/state/states.go index 1a524be8..ca210cab 100644 --- a/pkg/edition/java/proto/state/states.go +++ b/pkg/edition/java/proto/state/states.go @@ -200,6 +200,8 @@ func init() { m(0x27, version.Minecraft_1_20_2), m(0x28, version.Minecraft_1_20_3), ) + Play.ServerBound.Register(&config.FinishedUpdate{}, + m(0x0B, version.Minecraft_1_20_2)) Play.ClientBound.Register(&p.KeepAlive{}, m(0x00, version.Minecraft_1_7_2), @@ -459,5 +461,4 @@ func init() { m(0x65, version.Minecraft_1_20_2), m(0x67, version.Minecraft_1_20_3), ) - } diff --git a/pkg/edition/java/proto/util/preader.go b/pkg/edition/java/proto/util/preader.go index 0cb63d7a..6a3db3ff 100644 --- a/pkg/edition/java/proto/util/preader.go +++ b/pkg/edition/java/proto/util/preader.go @@ -18,6 +18,14 @@ func (r *PReader) String(s *string) { PReadString(r.r, s) } +func (r *PReader) StringMax(s *string, max int) { + PReadStringMax(r.r, s, max) +} + +func (r *PReader) Uint8(i *uint8) { + PReadUint8(r.r, i) +} + func (r *PReader) Bytes(b *[]byte) { PReadBytes(r.r, b) } @@ -95,6 +103,22 @@ func PReadString(rd io.Reader, s *string) { *s = v } +func PReadStringMax(rd io.Reader, s *string, max int) { + v, err := ReadStringMax(rd, max) + if err != nil { + panic(err) + } + *s = v +} + +func PReadUint8(rd io.Reader, i *uint8) { + v, err := ReadUint8(rd) + if err != nil { + panic(err) + } + *i = v +} + func PReadBytes(rd io.Reader, b *[]byte) { v, err := ReadBytes(rd) if err != nil { diff --git a/pkg/edition/java/proxy/bungee.go b/pkg/edition/java/proxy/bungee.go index 392e8656..91759c42 100644 --- a/pkg/edition/java/proxy/bungee.go +++ b/pkg/edition/java/proxy/bungee.go @@ -14,7 +14,7 @@ import ( "go.minekube.com/gate/pkg/gate/proto" ) -func bungeeCordMessageResponder( +func newBungeeCordMessageResponder( bungeePluginChannelEnabled bool, player *connectedPlayer, proxy *Proxy, diff --git a/pkg/edition/java/proxy/player.go b/pkg/edition/java/proxy/player.go index db52b700..15bfc054 100644 --- a/pkg/edition/java/proxy/player.go +++ b/pkg/edition/java/proxy/player.go @@ -706,12 +706,15 @@ func (p *connectedPlayer) config() *config.Config { // switchToConfigState switches the connection of the client into config state. func (p *connectedPlayer) switchToConfigState() { - go func() { - if err := p.WritePacket(new(cfgpacket.StartUpdate)); err != nil { - p.log.Error(err, "error writing config packet") - } - p.SetState(state.Config) - }() + if err := p.BufferPacket(new(cfgpacket.StartUpdate)); err != nil { + p.log.Error(err, "error writing config packet") + } + + p.MinecraftConn.Writer().SetState(state.Config) + // Make sure we don't send any play packets to the player after update start + p.MinecraftConn.EnablePlayPacketQueue() + + _ = p.Flush() // Trigger switch finally } func (p *connectedPlayer) ClientBrand() string { diff --git a/pkg/edition/java/proxy/player/settings.go b/pkg/edition/java/proxy/player/settings.go index e73ce3f6..5cbebda5 100644 --- a/pkg/edition/java/proxy/player/settings.go +++ b/pkg/edition/java/proxy/player/settings.go @@ -75,7 +75,7 @@ type clientSettings struct { s *packet.ClientSettings } -func (s *clientSettings) ClientListing() bool { return s.s.ClientListing } +func (s *clientSettings) ClientListing() bool { return s.s.ClientListingAllowed } func (s *clientSettings) SkinParts() SkinParts { return SkinParts(s.s.SkinParts) diff --git a/pkg/edition/java/proxy/session_backend_config.go b/pkg/edition/java/proxy/session_backend_config.go index 264e8723..8e62be50 100644 --- a/pkg/edition/java/proxy/session_backend_config.go +++ b/pkg/edition/java/proxy/session_backend_config.go @@ -2,6 +2,7 @@ package proxy import ( "errors" + "fmt" "github.com/go-logr/logr" "go.minekube.com/gate/pkg/edition/java/netmc" "go.minekube.com/gate/pkg/edition/java/proto/packet" @@ -17,12 +18,11 @@ import ( // This version is to accommodate 1.20.2+ switching. It handles the transition of a player between servers in a proxy setup. // This is a complex process that involves multiple stages and can be interrupted by various events, such as plugin messages or resource pack requests. type backendConfigSessionHandler struct { - serverConn *serverConnection - requestCtx *connRequestCxt - state backendConfigSessionState - resourcePackToApply *ResourcePackInfo - playerConfigSessionHandler *clientConfigSessionHandler - log logr.Logger + serverConn *serverConnection + requestCtx *connRequestCxt + state backendConfigSessionState + resourcePackToApply *ResourcePackInfo + log logr.Logger nopSessionHandler } @@ -32,17 +32,11 @@ func newBackendConfigSessionHandler( serverConn *serverConnection, requestCtx *connRequestCxt, ) (netmc.SessionHandler, error) { - clientSession, ok := serverConn.player.ActiveSessionHandler().(*clientConfigSessionHandler) - if !ok { - return nil, errors.New("initializing backend config session handler with non-client config session handler") - } - return &backendConfigSessionHandler{ - serverConn: serverConn, - state: backendConfigSessionStateStart, - requestCtx: requestCtx, - playerConfigSessionHandler: clientSession, - log: serverConn.log.WithName("backendConfigSessionHandler"), + serverConn: serverConn, + state: backendConfigSessionStateStart, + requestCtx: requestCtx, + log: serverConn.log.WithName("backendConfigSessionHandler"), }, nil } @@ -127,18 +121,29 @@ func (b *backendConfigSessionHandler) handleFinishedUpdate(p *config.FinishedUpd return } player := b.serverConn.player - configHandler := b.playerConfigSessionHandler + + activehandler := player.ActiveSessionHandler() + configHandler, ok := activehandler.(*clientConfigSessionHandler) + if !ok { + err := fmt.Errorf("expected client config session handler, got %T", activehandler) + b.log.Error(err, "error handling finished update packet") + b.serverConn.disconnect() + b.requestCtx.result(nil, err) + return + } smc.SetAutoReading(false) - // Even when not auto reading messages are still decoded. Decode them with the correct state smc.Reader().SetState(state.Play) - configHandler.handleBackendFinishUpdate(b.serverConn, p, func() { + configHandler.handleBackendFinishUpdate(b.serverConn, p).ThenAccept(func(any) { defer smc.SetAutoReading(true) if b.serverConn == player.connectedServer() { if !smc.SwitchSessionHandler(state.Play) { err := errors.New("failed to switch session handler") b.log.Error(err, "expected to switch session handler to play state") + b.serverConn.disconnect() + b.requestCtx.result(nil, err) + return } header, footer := player.tabList.HeaderFooter() @@ -150,12 +155,15 @@ func (b *backendConfigSessionHandler) handleFinishedUpdate(p *config.FinishedUpd // The client cleared the tab list. // TODO: Restore changes done via TabList API - player.tabList.DeleteEntries() + err = player.tabList.RemoveAll() + if err != nil { + b.log.Error(err, "error removing all tab list entries") + return + } } else { smc.SetActiveSessionHandler(state.Play, newBackendTransitionSessionHandler( - b.serverConn, b.requestCtx, - b.proxy().Event(), b.proxy(), + b.serverConn, b.requestCtx, b.proxy(), ), ) } diff --git a/pkg/edition/java/proxy/session_backend_login.go b/pkg/edition/java/proxy/session_backend_login.go index 2723cc06..d7e8939b 100644 --- a/pkg/edition/java/proxy/session_backend_login.go +++ b/pkg/edition/java/proxy/session_backend_login.go @@ -331,23 +331,38 @@ func (b *backendLoginSessionHandler) handleServerLoginSuccess() { if serverMc.Protocol().Lower(version.Minecraft_1_20_2) { serverMc.SetActiveSessionHandler(state.Play, - newBackendTransitionSessionHandler(b.serverConn, b.requestCtx, b.eventMgr, b.proxy)) + newBackendTransitionSessionHandler(b.serverConn, b.requestCtx, b.proxy)) } else { - _ = serverMc.WritePacket(&packet.LoginAcknowledged{}) - sh, err := newBackendConfigSessionHandler(b.serverConn, b.requestCtx) - if err != nil { + fail := func(err error) { + b.log.V(1).Error(err, "error transitioning to backend config state") b.requestCtx.result(nil, err) b.serverConn.disconnect() + } + err := serverMc.WritePacket(&packet.LoginAcknowledged{}) + if err != nil { + fail(err) + return + } + sh, err := newBackendConfigSessionHandler(b.serverConn, b.requestCtx) + if err != nil { + fail(err) return } serverMc.SetActiveSessionHandler(state.Config, sh) player := b.serverConn.player if pkt := player.ClientSettingsPacket(); pkt != nil { - _ = serverMc.WritePacket(pkt) + err = serverMc.WritePacket(pkt) + if err != nil { + fail(err) + return + } } - if csh, ok := player.MinecraftConn.ActiveSessionHandler().(*clientPlaySessionHandler); ok { + + ash := player.ActiveSessionHandler() + csh, ok := ash.(*clientPlaySessionHandler) + if ok { serverMc.SetAutoReading(false) - csh.doSwitch().DoWhenTrue(func() { + csh.doSwitch().ThenAccept(func(any) { serverMc.SetAutoReading(true) }) } diff --git a/pkg/edition/java/proxy/session_backend_play.go b/pkg/edition/java/proxy/session_backend_play.go index a0585747..b6ed92c3 100644 --- a/pkg/edition/java/proxy/session_backend_play.go +++ b/pkg/edition/java/proxy/session_backend_play.go @@ -3,7 +3,6 @@ package proxy import ( "context" "encoding/hex" - "errors" "fmt" "go.minekube.com/gate/pkg/edition/java/proto/packet/chat" "go.minekube.com/gate/pkg/edition/java/proto/packet/config" @@ -36,17 +35,18 @@ type backendPlaySessionHandler struct { } func newBackendPlaySessionHandler(serverConn *serverConnection) (netmc.SessionHandler, error) { - cpsh, ok := serverConn.player.ActiveSessionHandler().(*clientPlaySessionHandler) + activeHandler := serverConn.player.ActiveSessionHandler() + psh, ok := activeHandler.(*clientPlaySessionHandler) if !ok { - return nil, errors.New("initializing backendPlaySessionHandler with no backing client play session handler") + return nil, fmt.Errorf("initializing backendPlaySessionHandler with no backing client play session handler, got %T", activeHandler) } return &backendPlaySessionHandler{ serverConn: serverConn, - bungeeCordMessageResponder: bungeeCordMessageResponder( + bungeeCordMessageResponder: newBungeeCordMessageResponder( serverConn.config().BungeePluginChannelEnabled, serverConn.player, serverConn.player.proxy, ), - playerSessionHandler: cpsh, + playerSessionHandler: psh, }, nil } @@ -154,12 +154,15 @@ func (b *backendPlaySessionHandler) handleClientSettings(p *packet.ClientSetting } func (b *backendPlaySessionHandler) handleBossBar(p *bossbar.BossBar, pc *proto.PacketContext) { + b.playerSessionHandler.mu.Lock() switch p.Action { case bossbar.AddAction: - b.playerSessionHandler.serverBossBars[p.ID] = struct{}{} + b.playerSessionHandler.mu.serverBossBars[p.ID] = struct{}{} case bossbar.RemoveAction: - delete(b.playerSessionHandler.serverBossBars, p.ID) + delete(b.playerSessionHandler.mu.serverBossBars, p.ID) + default: } + b.playerSessionHandler.mu.Unlock() b.forwardToPlayer(pc, nil) // forward on } diff --git a/pkg/edition/java/proxy/session_backend_transition.go b/pkg/edition/java/proxy/session_backend_transition.go index 9a17421d..08858add 100644 --- a/pkg/edition/java/proxy/session_backend_transition.go +++ b/pkg/edition/java/proxy/session_backend_transition.go @@ -3,19 +3,19 @@ package proxy import ( "errors" "fmt" - "go.minekube.com/gate/pkg/edition/java/proto/state" - "go.minekube.com/gate/pkg/edition/java/proto/version" - "reflect" - "github.com/go-logr/logr" "github.com/robinbraemer/event" "go.minekube.com/gate/pkg/edition/java/netmc" "go.minekube.com/gate/pkg/edition/java/proto/packet" "go.minekube.com/gate/pkg/edition/java/proto/packet/plugin" + "go.minekube.com/gate/pkg/edition/java/proto/state" + "go.minekube.com/gate/pkg/edition/java/proto/version" "go.minekube.com/gate/pkg/edition/java/proxy/bungeecord" - phase "go.minekube.com/gate/pkg/edition/java/proxy/phase" + "go.minekube.com/gate/pkg/edition/java/proxy/phase" "go.minekube.com/gate/pkg/edition/java/proxy/tablist" "go.minekube.com/gate/pkg/gate/proto" + "reflect" + "time" ) type backendTransitionSessionHandler struct { @@ -31,14 +31,13 @@ type backendTransitionSessionHandler struct { func newBackendTransitionSessionHandler( serverConn *serverConnection, requestCtx *connRequestCxt, - eventMgr event.Manager, proxy *Proxy, ) netmc.SessionHandler { return &backendTransitionSessionHandler{ - eventMgr: eventMgr, + eventMgr: proxy.Event(), serverConn: serverConn, requestCtx: requestCtx, - bungeeCordMessageRecorder: bungeeCordMessageResponder( + bungeeCordMessageRecorder: newBungeeCordMessageResponder( serverConn.config().BungeePluginChannelEnabled, serverConn.player, proxy, ), @@ -224,20 +223,43 @@ func (b *backendTransitionSessionHandler) handleJoinGame(pc *proto.PacketContext "previousAddr", previousServer.ServerInfo().Addr()) } - // Change client to use ClientPlaySessionHandler if required. - playHandler, ok := b.serverConn.player.MinecraftConn.ActiveSessionHandler().(*clientPlaySessionHandler) - if !ok { - playHandler = newClientPlaySessionHandler(b.serverConn.player) - b.serverConn.player.MinecraftConn.SetActiveSessionHandler(state.Play, playHandler) + var playHandler *clientPlaySessionHandler + const ( + maxWait = time.Second * 3 + tick = time.Millisecond * 100 + ) + for waited := time.Duration(0); ; { + if !b.serverConn.active() { + failResult("server connection is no longer active") + return + } + if waited >= maxWait { + failResult("timed out waiting for client play session handler to be set") + return + } + playHandler, ok = b.serverConn.player.MinecraftConn.ActiveSessionHandler().(*clientPlaySessionHandler) + if ok { + break + } + b.log.V(1).Info("waiting for client play session handler to be set") + time.Sleep(tick) + waited += tick } + // Change client to use ClientPlaySessionHandler if required. + //playHandler, ok := b.serverConn.player.MinecraftConn.ActiveSessionHandler().(*clientPlaySessionHandler) + //if !ok { + // playHandler = newClientPlaySessionHandler(b.serverConn.player) + // b.serverConn.player.MinecraftConn.SetActiveSessionHandler(state.Play, playHandler) + //} + if err := playHandler.handleBackendJoinGame(pc, p, b.serverConn); err != nil { failResult("JoinGame packet could not be handled, client-side switching server failed: %w", err) return // not handled } - // Strap on the correct session handler for the server. - // We will have nothing more to do with this connection once this task finishes up. + // Set the new play session handler for the server. We will have nothing more to do + // with this connection once this task finishes up. backendPlay, err := newBackendPlaySessionHandler(b.serverConn) if err != nil { failResult("error creating backend play session handler: %w", err) @@ -254,7 +276,7 @@ func (b *backendTransitionSessionHandler) handleJoinGame(pc *proto.PacketContext // Send client settings. In 1.20.2+ this is done in the config state. if smc.Protocol().Lower(version.Minecraft_1_20_2) { if csp := b.serverConn.player.ClientSettingsPacket(); csp != nil { - smc, ok := b.serverConn.ensureConnected() + smc, ok = b.serverConn.ensureConnected() if ok { _ = smc.WritePacket(csp) } diff --git a/pkg/edition/java/proxy/session_client_auth.go b/pkg/edition/java/proxy/session_client_auth.go index eacad17e..1505931c 100644 --- a/pkg/edition/java/proxy/session_client_auth.go +++ b/pkg/edition/java/proxy/session_client_auth.go @@ -13,6 +13,7 @@ import ( "go.minekube.com/gate/pkg/edition/java/proxy/crypto" "go.minekube.com/gate/pkg/gate/proto" "go.minekube.com/gate/pkg/util/uuid" + "sync/atomic" ) type authSessionHandler struct { @@ -23,17 +24,17 @@ type authSessionHandler struct { profile *profile.GameProfile onlineMode bool - loginState authLoginState // 1.20.2+ + loginState *atomic.Pointer[authLoginState] // 1.20.2+ connectedPlayer *connectedPlayer } type authLoginState int -const ( - startAuthLoginState authLoginState = iota - successSentAuthLoginState - acknowledgedAuthLoginState +var ( + startAuthLoginState authLoginState = 0 + successSentAuthLoginState authLoginState = 1 + acknowledgedAuthLoginState authLoginState = 2 ) type playerRegistrar interface { @@ -48,8 +49,10 @@ func newAuthSessionHandler( onlineMode bool, sessionHandlerDeps *sessionHandlerDeps, ) netmc.SessionHandler { + var defaultState atomic.Pointer[authLoginState] + defaultState.Store(&startAuthLoginState) return &authSessionHandler{ - loginState: startAuthLoginState, + loginState: &defaultState, sessionHandlerDeps: sessionHandlerDeps, log: logr.FromContextOrDiscard(inbound.Context()).WithName("authSession"), inbound: inbound, @@ -185,10 +188,10 @@ func (a *authSessionHandler) completeLoginProtocolPhaseAndInitialize(player *con return } - a.loginState = successSentAuthLoginState + a.loginState.Store(&successSentAuthLoginState) if a.inbound.Protocol().Lower(version.Minecraft_1_20_2) { - a.loginState = acknowledgedAuthLoginState + a.loginState.Store(&acknowledgedAuthLoginState) a.connectedPlayer.MinecraftConn.SetActiveSessionHandler(state.Play, newInitialConnectSessionHandler(a.connectedPlayer)) @@ -244,12 +247,12 @@ func (a *authSessionHandler) config() *config.Config { } func (a *authSessionHandler) handleLoginAcknowledged() bool { - if a.loginState != successSentAuthLoginState { + if *a.loginState.Load() != successSentAuthLoginState { _ = a.inbound.disconnect(&component.Translation{ Key: "multiplayer.disconnect.invalid_player_data", }) } else { - a.loginState = acknowledgedAuthLoginState + a.loginState.Store(&acknowledgedAuthLoginState) a.connectedPlayer.MinecraftConn.SetActiveSessionHandler(state.Config, newClientConfigSessionHandler(a.connectedPlayer)) diff --git a/pkg/edition/java/proxy/session_client_config.go b/pkg/edition/java/proxy/session_client_config.go index 35e8a63d..2e45da10 100644 --- a/pkg/edition/java/proxy/session_client_config.go +++ b/pkg/edition/java/proxy/session_client_config.go @@ -9,7 +9,7 @@ import ( "go.minekube.com/gate/pkg/edition/java/proto/state" "go.minekube.com/gate/pkg/edition/java/proto/util" "go.minekube.com/gate/pkg/gate/proto" - "go.minekube.com/gate/pkg/internal/oncetrue" + "go.minekube.com/gate/pkg/internal/future" ) type clientConfigSessionHandler struct { @@ -17,7 +17,7 @@ type clientConfigSessionHandler struct { brandChannel string - configSwitchDone oncetrue.OnceWhenTrue + configSwitchDone future.Future[any] nopSessionHandler } @@ -49,7 +49,7 @@ func (h *clientConfigSessionHandler) HandlePacket(pc *proto.PacketContext) { h.handleResourcePackResponse(p) case *config.FinishedUpdate: h.player.SetActiveSessionHandler(state.Play, newClientPlaySessionHandler(h.player)) - h.configSwitchDone.SetTrue() + h.configSwitchDone.Complete(nil) case *plugin.Message: h.handlePluginMessage(p) case *packet.PingIdentify: @@ -65,34 +65,33 @@ func (h *clientConfigSessionHandler) HandlePacket(pc *proto.PacketContext) { } // handleBackendFinishUpdate handles the backend finishing the config stage. -func (h *clientConfigSessionHandler) handleBackendFinishUpdate( - serverConn *serverConnection, - p *config.FinishedUpdate, - onConfigSwitch func(), -) { +func (h *clientConfigSessionHandler) handleBackendFinishUpdate(serverConn *serverConnection, p *config.FinishedUpdate) *future.Future[any] { smc, ok := serverConn.ensureConnected() - if ok { - brand := serverConn.player.ClientBrand() - if brand == "" && h.brandChannel != "" { - buf := new(bytes.Buffer) - _ = util.WriteString(buf, brand) - - brandPacket := &plugin.Message{ - Channel: h.brandChannel, - Data: buf.Bytes(), - } - _ = smc.WritePacket(brandPacket) - } - err := smc.WritePacket(p) - if err != nil { - return + if !ok { + return nil + } + brand := serverConn.player.ClientBrand() + if brand == "" && h.brandChannel != "" { + buf := new(bytes.Buffer) + _ = util.WriteString(buf, brand) + + brandPacket := &plugin.Message{ + Channel: h.brandChannel, + Data: buf.Bytes(), } + _ = smc.WritePacket(brandPacket) } + if err := h.player.WritePacket(p); err != nil { - return + return nil + } + + if err := smc.WritePacket(p); err != nil { + return nil } + smc.Writer().SetState(state.Play) - h.configSwitchDone.DoWhenTrue(onConfigSwitch) + return &h.configSwitchDone } func (h *clientConfigSessionHandler) handleResourcePackResponse(p *packet.ResourcePackResponse) { diff --git a/pkg/edition/java/proxy/session_client_play.go b/pkg/edition/java/proxy/session_client_play.go index 5323b91b..89249ef4 100644 --- a/pkg/edition/java/proxy/session_client_play.go +++ b/pkg/edition/java/proxy/session_client_play.go @@ -5,9 +5,11 @@ import ( "errors" "fmt" "go.minekube.com/gate/pkg/edition/java/proto/packet/config" - "go.minekube.com/gate/pkg/internal/oncetrue" + "go.minekube.com/gate/pkg/edition/java/proxy/tablist" + "go.minekube.com/gate/pkg/internal/future" "sort" "strings" + "sync" "time" "github.com/gammazero/deque" @@ -36,26 +38,28 @@ import ( // Handles communication with the connected Minecraft client. // This is effectively the primary nerve center that joins backend servers with players. type clientPlaySessionHandler struct { - log, log1 logr.Logger - player *connectedPlayer - spawned atomic.Bool - loginPluginMessages deque.Deque[*plugin.Message] - chatHandler *chatHandler - chatTimeKeeper chatTimeKeeper + log, log1 logr.Logger + player *connectedPlayer + spawned atomic.Bool + chatHandler *chatHandler + chatTimeKeeper chatTimeKeeper + outstandingTabComplete *packet.TabCompleteRequest - configSwitchDone oncetrue.OnceWhenTrue + configSwitchFuture future.Future[any] - serverBossBars map[uuid.UUID]struct{} - outstandingTabComplete *packet.TabCompleteRequest + mu struct { + sync.RWMutex + loginPluginMessages deque.Deque[*plugin.Message] + serverBossBars map[uuid.UUID]struct{} + } } func newClientPlaySessionHandler(player *connectedPlayer) *clientPlaySessionHandler { log := player.log.WithName("clientPlaySession") - return &clientPlaySessionHandler{ - player: player, - log: log, - log1: log.V(1), - serverBossBars: map[uuid.UUID]struct{}{}, + h := &clientPlaySessionHandler{ + player: player, + log: log, + log1: log.V(1), chatHandler: &chatHandler{ log: log, eventMgr: player.eventMgr, @@ -64,6 +68,8 @@ func newClientPlaySessionHandler(player *connectedPlayer) *clientPlaySessionHand configProvider: player.proxy, }, } + h.mu.serverBossBars = map[uuid.UUID]struct{}{} + return h } var _ netmc.SessionHandler = (*clientPlaySessionHandler)(nil) @@ -98,14 +104,16 @@ func (c *clientPlaySessionHandler) HandlePacket(pc *proto.PacketContext) { c.player.setClientSettings(p) c.forwardToServer(pc) // forward to server case *config.FinishedUpdate: - c.handleFinishUpdate(p) + c.handleFinishedUpdate(p) default: c.forwardToServer(pc) } } func (c *clientPlaySessionHandler) Deactivated() { - c.loginPluginMessages.Clear() + c.mu.Lock() + c.mu.loginPluginMessages.Clear() + c.mu.Unlock() } func (c *clientPlaySessionHandler) Activated() { @@ -154,8 +162,10 @@ func (c *clientPlaySessionHandler) FlushQueuedPluginMessages() { if !ok { return } - for c.loginPluginMessages.Len() != 0 { - pm := c.loginPluginMessages.PopFront() + c.mu.Lock() + defer c.mu.Unlock() + for c.mu.loginPluginMessages.Len() != 0 { + pm := c.mu.loginPluginMessages.PopFront() _ = serverMc.BufferPacket(pm) } _ = serverMc.Flush() @@ -297,7 +307,9 @@ func (c *clientPlaySessionHandler) handlePluginMessage(packet *plugin.Message) { // // We also need to make sure to retain these packets, so they can be flushed // appropriately. - c.loginPluginMessages.PushBack(packet) + c.mu.Lock() + c.mu.loginPluginMessages.PushBack(packet) + c.mu.Unlock() } } @@ -361,7 +373,9 @@ func (c *clientPlaySessionHandler) handleBackendJoinGame(pc *proto.PacketContext // Remove previous boss bars. These don't get cleared when sending JoinGame, thus the need to // track them. - for barID := range c.serverBossBars { + c.mu.Lock() + defer c.mu.Unlock() + for barID := range c.mu.serverBossBars { deletePacket := &bossbar.BossBar{ ID: barID, Action: bossbar.RemoveAction, @@ -370,7 +384,7 @@ func (c *clientPlaySessionHandler) handleBackendJoinGame(pc *proto.PacketContext return fmt.Errorf("error buffering boss bar remove packet for player: %w", err) } } - c.serverBossBars = make(map[uuid.UUID]struct{}) // clear + c.mu.serverBossBars = make(map[uuid.UUID]struct{}) // clear // Tell the server about the proxy's plugin message channels. serverVersion := serverMc.Protocol() @@ -383,8 +397,8 @@ func (c *clientPlaySessionHandler) handleBackendJoinGame(pc *proto.PacketContext } // If we had plugin messages queued during login/FML handshake, send them now. - for c.loginPluginMessages.Len() != 0 { - pm := c.loginPluginMessages.PopFront() + for c.mu.loginPluginMessages.Len() != 0 { + pm := c.mu.loginPluginMessages.PopFront() if err = serverMc.BufferPacket(pm); err != nil { return fmt.Errorf("error buffering %T for backend: %w", pm, err) } @@ -767,7 +781,7 @@ func (c *clientPlaySessionHandler) updateTimeKeeper(t time.Time) bool { return true } -func (c *clientPlaySessionHandler) handleFinishUpdate(p *config.FinishedUpdate) { +func (c *clientPlaySessionHandler) handleFinishedUpdate(p *config.FinishedUpdate) { // Complete client switch if !c.player.MinecraftConn.SwitchSessionHandler(state.Config) { panic("expected client to have config session handler") @@ -783,15 +797,18 @@ func (c *clientPlaySessionHandler) handleFinishUpdate(p *config.FinishedUpdate) if !smc.SwitchSessionHandler(state.Config) { err := errors.New("failed to switch session handler") c.log.Error(err, "expected to switch session handler to config state") + panic(err) } smc.SetAutoReading(true) }() } - c.configSwitchDone.SetTrue() + c.configSwitchFuture.Complete(nil) } // doSwitch handles switching stages for swapping between servers. -func (c *clientPlaySessionHandler) doSwitch() *oncetrue.OnceWhenTrue { +func (c *clientPlaySessionHandler) doSwitch() *future.Future[any] { + c.log.V(1).WithName("doSwitch").Info("switching servers") + existingConn := c.player.connectedServer() if existingConn != nil { // Shut down the existing server connection. @@ -800,17 +817,19 @@ func (c *clientPlaySessionHandler) doSwitch() *oncetrue.OnceWhenTrue { // Send keep alive to try to avoid timeouts if netmc.SendKeepAlive(c.player) != nil { - return new(oncetrue.OnceWhenTrue) + return future.New[any]().Complete(nil) } - // Reset TabList header and footer to prevent de-sync - if err := c.player.tabList.SetHeaderFooter(nil, nil); err != nil { - c.log.Error(err, "error resetting tablist header and footer") - return new(oncetrue.OnceWhenTrue) - } + // Config state clears everything in the client. No need to clear later. + c.spawned.Store(false) + c.mu.Lock() + c.mu.serverBossBars = make(map[uuid.UUID]struct{}) // clear + c.mu.Unlock() + _ = c.player.tabList.RemoveAll() + _ = tablist.ClearTabListHeaderFooter(c.player.tabList) } - c.spawned.Store(false) c.player.switchToConfigState() - return &c.configSwitchDone + + return &c.configSwitchFuture } diff --git a/pkg/internal/future/future.go b/pkg/internal/future/future.go new file mode 100644 index 00000000..ee3482d9 --- /dev/null +++ b/pkg/internal/future/future.go @@ -0,0 +1,60 @@ +package future + +import ( + "sync" +) + +// Future is a struct that holds a value of type T, a slice of callbacks to be called when the value is set, +// a boolean to track if the callbacks have been called, and a mutex for thread safety. +type Future[T any] struct { + value T // The value that completes the Future + callback []func(T) // The callbacks that get called when the Future is completed + completed bool // A flag to check if the Future is completed + mu sync.Mutex // Mutex for thread safety +} + +// New returns a new Future. +func New[T any]() *Future[T] { + return &Future[T]{} +} + +// ThenAccept registers a callback to be called when the Future is completed. +func (f *Future[T]) ThenAccept(callback func(T)) { + f.mu.Lock() + defer f.mu.Unlock() + + // If the Future is already completed, call the callback + if f.completed { + callback(f.value) + } else { + // Append the new callback to the slice of callbacks + f.callback = append(f.callback, callback) + } +} + +// ThenAcceptParallel registers a callback to be called in a separate goroutine when the Future is completed. +func (f *Future[T]) ThenAcceptParallel(callback func(T)) { + f.ThenAccept(func(t T) { + go callback(t) + }) +} + +// Complete sets the value and calls the registered callbacks if they haven't been called yet. +func (f *Future[T]) Complete(value T) *Future[T] { + f.mu.Lock() + defer f.mu.Unlock() + + // Check if the Future is already completed + if f.completed { + return f + } + + // Set the value and call the callbacks + f.value = value + f.completed = true + for _, fn := range f.callback { + fn(value) + } + + return f +} diff --git a/pkg/internal/future/future_test.go b/pkg/internal/future/future_test.go new file mode 100644 index 00000000..6451e3dd --- /dev/null +++ b/pkg/internal/future/future_test.go @@ -0,0 +1,31 @@ +package future + +import ( + "github.com/stretchr/testify/assert" + "testing" +) + +func TestFuture(t *testing.T) { + t.Run("ThenAccept", func(t *testing.T) { + f := &Future[int]{} + var result int + f.ThenAccept(func(value int) { + result = value + }) + f.Complete(10) + assert.Equal(t, 10, result) + + // Test that the callback is immediately called when the Future is already completed + f.ThenAccept(func(value int) { + result = 20 + }) + assert.Equal(t, 20, result) + }) + + t.Run("Complete", func(t *testing.T) { + f := &Future[int]{} + f.Complete(30) + assert.Equal(t, 30, f.value) + assert.True(t, f.completed) + }) +} diff --git a/pkg/internal/oncetrue/oncetrue.go b/pkg/internal/oncetrue/oncetrue.go deleted file mode 100644 index 3ba25db9..00000000 --- a/pkg/internal/oncetrue/oncetrue.go +++ /dev/null @@ -1,42 +0,0 @@ -package oncetrue - -import ( - "sync" -) - -type OnceWhenTrue struct { - condition bool - onTrue func() - called bool - mu sync.Mutex -} - -func NewOnceWhenTrue() *OnceWhenTrue { - return &OnceWhenTrue{} -} - -func (o *OnceWhenTrue) DoWhenTrue(onTrue func()) { - o.mu.Lock() - defer o.mu.Unlock() - - o.onTrue = onTrue - - // If condition is true and onTrue hasn't been called, call it - if o.condition && !o.called { - o.onTrue() - o.called = true - } -} - -func (o *OnceWhenTrue) SetTrue() { - o.mu.Lock() - defer o.mu.Unlock() - - o.condition = true - - // If onTrue is set and hasn't been called, call it - if o.onTrue != nil && !o.called { - o.onTrue() - o.called = true - } -} diff --git a/pkg/internal/tablist/tablist.go b/pkg/internal/tablist/tablist.go index 70463c84..c952707f 100644 --- a/pkg/internal/tablist/tablist.go +++ b/pkg/internal/tablist/tablist.go @@ -58,9 +58,6 @@ type InternalTabList interface { EmitActionRaw(action playerinfo.UpsertAction, entry *playerinfo.Entry) error UpdateEntry(action legacytablist.PlayerListItemAction, entry tablist.Entry) error - // DeleteEntries deletes the entries with the given ids without sending a packet. - DeleteEntries(ids ...uuid.UUID) []uuid.UUID - Parent() InternalTabList // Used to resolve the parent root struct of an embedded tab list struct } @@ -145,7 +142,7 @@ func (t *TabList) UpdateEntry(action legacytablist.PlayerListItemAction, entry t } func (t *TabList) RemoveAll(ids ...uuid.UUID) error { - if toRemove := t.DeleteEntries(ids...); len(toRemove) != 0 { + if toRemove := t.deleteEntries(ids...); len(toRemove) != 0 { return t.Viewer.BufferPacket(&playerinfo.Remove{ PlayersToRemove: toRemove, }) @@ -153,7 +150,7 @@ func (t *TabList) RemoveAll(ids ...uuid.UUID) error { return nil } -func (t *TabList) DeleteEntries(ids ...uuid.UUID) []uuid.UUID { +func (t *TabList) deleteEntries(ids ...uuid.UUID) []uuid.UUID { t.Lock() defer t.Unlock() if len(ids) == 0 { // Delete all diff --git a/pkg/internal/tablist/tablist_keyed.go b/pkg/internal/tablist/tablist_keyed.go index 4018ed0a..15c90773 100644 --- a/pkg/internal/tablist/tablist_keyed.go +++ b/pkg/internal/tablist/tablist_keyed.go @@ -63,7 +63,7 @@ func (k *KeyedTabList) Add(entries ...tablist.Entry) error { } func (k *KeyedTabList) RemoveAll(ids ...uuid.UUID) error { - toRemove := k.TabList.DeleteEntries(ids...) + toRemove := k.TabList.deleteEntries(ids...) items := make([]legacytablist.PlayerListItemEntry, 0, len(toRemove)) for _, id := range toRemove { items = append(items, legacytablist.PlayerListItemEntry{