From 260a581c95704286bc96d20fda24cedf91054a04 Mon Sep 17 00:00:00 2001 From: Jonathan Turner Date: Sat, 25 Jul 2020 14:37:20 +0100 Subject: [PATCH] continue with next KDC on communication failure (#399) refactor send to KDC with trying subsequent KDC on failure --- v8/client/client_integration_test.go | 29 ++++++ v8/client/network.go | 143 ++++++++++++++------------- v8/client/passwd.go | 46 ++------- 3 files changed, 115 insertions(+), 103 deletions(-) diff --git a/v8/client/client_integration_test.go b/v8/client/client_integration_test.go index c241c833..e369bba5 100644 --- a/v8/client/client_integration_test.go +++ b/v8/client/client_integration_test.go @@ -253,6 +253,35 @@ func TestClient_NetworkTimeout(t *testing.T) { } } +func TestClient_NetworkTryNextKDC(t *testing.T) { + test.Integration(t) + + b, _ := hex.DecodeString(testdata.KEYTAB_TESTUSER1_TEST_GOKRB5) + kt := keytab.New() + kt.Unmarshal(b) + c, _ := config.NewFromString(testdata.KRB5_CONF) + addr := os.Getenv("TEST_KDC_ADDR") + if addr == "" { + addr = testdata.KDC_IP_TEST_GOKRB5 + } + // Two out fo three times this should fail the first time. + // So will run login twice to expect at least once the first time it will be to a bad KDC + c.Realms[0].KDC = []string{testdata.KDC_IP_TEST_GOKRB5_BADADDR + ":88", + testdata.KDC_IP_TEST_GOKRB5_BADADDR + ":88", + addr + ":" + testdata.KDC_PORT_TEST_GOKRB5, + } + cl := client.NewWithKeytab("testuser1", "TEST.GOKRB5", kt, c) + + err := cl.Login() + if err != nil { + t.Fatal("login failed") + } + err = cl.Login() + if err != nil { + t.Fatal("login failed") + } +} + func TestClient_GetServiceTicket(t *testing.T) { test.Integration(t) diff --git a/v8/client/network.go b/v8/client/network.go index 009f6f15..634f015c 100644 --- a/v8/client/network.go +++ b/v8/client/network.go @@ -6,6 +6,7 @@ import ( "fmt" "io" "net" + "strings" "time" "github.com/jcmturner/gokrb5/v8/iana/errorcode" @@ -67,88 +68,52 @@ func (cl *Client) sendToKDC(b []byte, realm string) ([]byte, error) { return rb, nil } -// dialKDCTCP establishes a UDP connection to a KDC. -func dialKDCUDP(count int, kdcs map[int]string) (*net.UDPConn, error) { - i := 1 - for i <= count { - udpAddr, err := net.ResolveUDPAddr("udp", kdcs[i]) - if err != nil { - return nil, fmt.Errorf("error resolving KDC address: %v", err) - } - - conn, err := net.DialTimeout("udp", udpAddr.String(), 5*time.Second) - if err == nil { - if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { - return nil, err - } - // conn is guaranteed to be a UDPConn - return conn.(*net.UDPConn), nil - } - i++ - } - return nil, errors.New("error in getting a UDP connection to any of the KDCs") -} - -// dialKDCTCP establishes a TCP connection to a KDC. -func dialKDCTCP(count int, kdcs map[int]string) (*net.TCPConn, error) { - i := 1 - for i <= count { - tcpAddr, err := net.ResolveTCPAddr("tcp", kdcs[i]) - if err != nil { - return nil, fmt.Errorf("error resolving KDC address: %v", err) - } - - conn, err := net.DialTimeout("tcp", tcpAddr.String(), 5*time.Second) - if err == nil { - if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { - return nil, err - } - // conn is guaranteed to be a TCPConn - return conn.(*net.TCPConn), nil - } - i++ - } - return nil, errors.New("error in getting a TCP connection to any of the KDCs") -} - // sendKDCUDP sends bytes to the KDC via UDP. func (cl *Client) sendKDCUDP(realm string, b []byte) ([]byte, error) { var r []byte - count, kdcs, err := cl.Config.GetKDCs(realm, false) - if err != nil { - return r, err - } - conn, err := dialKDCUDP(count, kdcs) + _, kdcs, err := cl.Config.GetKDCs(realm, false) if err != nil { return r, err } - r, err = cl.sendUDP(conn, b) + r, err = dialSendUDP(kdcs, b) if err != nil { return r, err } return checkForKRBError(r) } -// sendKDCTCP sends bytes to the KDC via TCP. -func (cl *Client) sendKDCTCP(realm string, b []byte) ([]byte, error) { - var r []byte - count, kdcs, err := cl.Config.GetKDCs(realm, true) - if err != nil { - return r, err - } - conn, err := dialKDCTCP(count, kdcs) - if err != nil { - return r, err - } - rb, err := cl.sendTCP(conn, b) - if err != nil { - return r, err +// dialSendUDP establishes a UDP connection to a KDC. +func dialSendUDP(kdcs map[int]string, b []byte) ([]byte, error) { + var errs []string + for i := 1; i <= len(kdcs); i++ { + udpAddr, err := net.ResolveUDPAddr("udp", kdcs[i]) + if err != nil { + errs = append(errs, fmt.Sprintf("error resolving KDC address: %v", err)) + continue + } + + conn, err := net.DialTimeout("udp", udpAddr.String(), 5*time.Second) + if err != nil { + errs = append(errs, fmt.Sprintf("error setting dial timeout on connection to %s: %v", kdcs[i], err)) + continue + } + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + errs = append(errs, fmt.Sprintf("error setting deadline on connection to %s: %v", kdcs[i], err)) + continue + } + // conn is guaranteed to be a UDPConn + rb, err := sendUDP(conn.(*net.UDPConn), b) + if err != nil { + errs = append(errs, fmt.Sprintf("error sneding to %s: %v", kdcs[i], err)) + continue + } + return rb, nil } - return checkForKRBError(rb) + return nil, fmt.Errorf("error sending to a KDC: %s", strings.Join(errs, "; ")) } // sendUDP sends bytes to connection over UDP. -func (cl *Client) sendUDP(conn *net.UDPConn, b []byte) ([]byte, error) { +func sendUDP(conn *net.UDPConn, b []byte) ([]byte, error) { var r []byte defer conn.Close() _, err := conn.Write(b) @@ -167,8 +132,52 @@ func (cl *Client) sendUDP(conn *net.UDPConn, b []byte) ([]byte, error) { return r, nil } +// sendKDCTCP sends bytes to the KDC via TCP. +func (cl *Client) sendKDCTCP(realm string, b []byte) ([]byte, error) { + var r []byte + _, kdcs, err := cl.Config.GetKDCs(realm, true) + if err != nil { + return r, err + } + r, err = dialSendTCP(kdcs, b) + if err != nil { + return r, err + } + return checkForKRBError(r) +} + +// dialKDCTCP establishes a TCP connection to a KDC. +func dialSendTCP(kdcs map[int]string, b []byte) ([]byte, error) { + var errs []string + for i := 1; i <= len(kdcs); i++ { + tcpAddr, err := net.ResolveTCPAddr("tcp", kdcs[i]) + if err != nil { + errs = append(errs, fmt.Sprintf("error resolving KDC address: %v", err)) + continue + } + + conn, err := net.DialTimeout("tcp", tcpAddr.String(), 5*time.Second) + if err != nil { + errs = append(errs, fmt.Sprintf("error setting dial timeout on connection to %s: %v", kdcs[i], err)) + continue + } + if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil { + errs = append(errs, fmt.Sprintf("error setting deadline on connection to %s: %v", kdcs[i], err)) + continue + } + // conn is guaranteed to be a TCPConn + rb, err := sendTCP(conn.(*net.TCPConn), b) + if err != nil { + errs = append(errs, fmt.Sprintf("error sneding to %s: %v", kdcs[i], err)) + continue + } + return rb, nil + } + return nil, errors.New("error in getting a TCP connection to any of the KDCs") +} + // sendTCP sends bytes to connection over TCP. -func (cl *Client) sendTCP(conn *net.TCPConn, b []byte) ([]byte, error) { +func sendTCP(conn *net.TCPConn, b []byte) ([]byte, error) { defer conn.Close() var r []byte // RFC 4120 7.2.2 specifies the first 4 bytes indicate the length of the message in big endian order. diff --git a/v8/client/passwd.go b/v8/client/passwd.go index efb1593b..fe20559c 100644 --- a/v8/client/passwd.go +++ b/v8/client/passwd.go @@ -2,7 +2,6 @@ package client import ( "fmt" - "net" "github.com/jcmturner/gokrb5/v8/kadmin" "github.com/jcmturner/gokrb5/v8/messages" @@ -55,46 +54,21 @@ func (cl *Client) sendToKPasswd(msg kadmin.Request) (r kadmin.Reply, err error) if err != nil { return } - addr := kps[1] b, err := msg.Marshal() if err != nil { return } + var rb []byte if len(b) <= cl.Config.LibDefaults.UDPPreferenceLimit { - return cl.sendKPasswdUDP(b, addr) - } - return cl.sendKPasswdTCP(b, addr) -} - -func (cl *Client) sendKPasswdTCP(b []byte, kadmindAddr string) (r kadmin.Reply, err error) { - tcpAddr, err := net.ResolveTCPAddr("tcp", kadmindAddr) - if err != nil { - return - } - conn, err := net.DialTCP("tcp", nil, tcpAddr) - if err != nil { - return - } - rb, err := cl.sendTCP(conn, b) - if err != nil { - return - } - err = r.Unmarshal(rb) - return -} - -func (cl *Client) sendKPasswdUDP(b []byte, kadmindAddr string) (r kadmin.Reply, err error) { - udpAddr, err := net.ResolveUDPAddr("udp", kadmindAddr) - if err != nil { - return - } - conn, err := net.DialUDP("udp", nil, udpAddr) - if err != nil { - return - } - rb, err := cl.sendUDP(conn, b) - if err != nil { - return + rb, err = dialSendUDP(kps, b) + if err != nil { + return + } + } else { + rb, err = dialSendTCP(kps, b) + if err != nil { + return + } } err = r.Unmarshal(rb) return