Skip to content

Commit

Permalink
continue with next KDC on communication failure (#399)
Browse files Browse the repository at this point in the history
refactor send to KDC with trying subsequent KDC on failure
  • Loading branch information
jcmturner authored Jul 25, 2020
1 parent 265fb9b commit 260a581
Show file tree
Hide file tree
Showing 3 changed files with 115 additions and 103 deletions.
29 changes: 29 additions & 0 deletions v8/client/client_integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,35 @@ func TestClient_NetworkTimeout(t *testing.T) {
}
}

func TestClient_NetworkTryNextKDC(t *testing.T) {
test.Integration(t)

b, _ := hex.DecodeString(testdata.KEYTAB_TESTUSER1_TEST_GOKRB5)
kt := keytab.New()
kt.Unmarshal(b)
c, _ := config.NewFromString(testdata.KRB5_CONF)
addr := os.Getenv("TEST_KDC_ADDR")
if addr == "" {
addr = testdata.KDC_IP_TEST_GOKRB5
}
// Two out fo three times this should fail the first time.
// So will run login twice to expect at least once the first time it will be to a bad KDC
c.Realms[0].KDC = []string{testdata.KDC_IP_TEST_GOKRB5_BADADDR + ":88",
testdata.KDC_IP_TEST_GOKRB5_BADADDR + ":88",
addr + ":" + testdata.KDC_PORT_TEST_GOKRB5,
}
cl := client.NewWithKeytab("testuser1", "TEST.GOKRB5", kt, c)

err := cl.Login()
if err != nil {
t.Fatal("login failed")
}
err = cl.Login()
if err != nil {
t.Fatal("login failed")
}
}

func TestClient_GetServiceTicket(t *testing.T) {
test.Integration(t)

Expand Down
143 changes: 76 additions & 67 deletions v8/client/network.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"fmt"
"io"
"net"
"strings"
"time"

"github.com/jcmturner/gokrb5/v8/iana/errorcode"
Expand Down Expand Up @@ -67,88 +68,52 @@ func (cl *Client) sendToKDC(b []byte, realm string) ([]byte, error) {
return rb, nil
}

// dialKDCTCP establishes a UDP connection to a KDC.
func dialKDCUDP(count int, kdcs map[int]string) (*net.UDPConn, error) {
i := 1
for i <= count {
udpAddr, err := net.ResolveUDPAddr("udp", kdcs[i])
if err != nil {
return nil, fmt.Errorf("error resolving KDC address: %v", err)
}

conn, err := net.DialTimeout("udp", udpAddr.String(), 5*time.Second)
if err == nil {
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
return nil, err
}
// conn is guaranteed to be a UDPConn
return conn.(*net.UDPConn), nil
}
i++
}
return nil, errors.New("error in getting a UDP connection to any of the KDCs")
}

// dialKDCTCP establishes a TCP connection to a KDC.
func dialKDCTCP(count int, kdcs map[int]string) (*net.TCPConn, error) {
i := 1
for i <= count {
tcpAddr, err := net.ResolveTCPAddr("tcp", kdcs[i])
if err != nil {
return nil, fmt.Errorf("error resolving KDC address: %v", err)
}

conn, err := net.DialTimeout("tcp", tcpAddr.String(), 5*time.Second)
if err == nil {
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
return nil, err
}
// conn is guaranteed to be a TCPConn
return conn.(*net.TCPConn), nil
}
i++
}
return nil, errors.New("error in getting a TCP connection to any of the KDCs")
}

// sendKDCUDP sends bytes to the KDC via UDP.
func (cl *Client) sendKDCUDP(realm string, b []byte) ([]byte, error) {
var r []byte
count, kdcs, err := cl.Config.GetKDCs(realm, false)
if err != nil {
return r, err
}
conn, err := dialKDCUDP(count, kdcs)
_, kdcs, err := cl.Config.GetKDCs(realm, false)
if err != nil {
return r, err
}
r, err = cl.sendUDP(conn, b)
r, err = dialSendUDP(kdcs, b)
if err != nil {
return r, err
}
return checkForKRBError(r)
}

// sendKDCTCP sends bytes to the KDC via TCP.
func (cl *Client) sendKDCTCP(realm string, b []byte) ([]byte, error) {
var r []byte
count, kdcs, err := cl.Config.GetKDCs(realm, true)
if err != nil {
return r, err
}
conn, err := dialKDCTCP(count, kdcs)
if err != nil {
return r, err
}
rb, err := cl.sendTCP(conn, b)
if err != nil {
return r, err
// dialSendUDP establishes a UDP connection to a KDC.
func dialSendUDP(kdcs map[int]string, b []byte) ([]byte, error) {
var errs []string
for i := 1; i <= len(kdcs); i++ {
udpAddr, err := net.ResolveUDPAddr("udp", kdcs[i])
if err != nil {
errs = append(errs, fmt.Sprintf("error resolving KDC address: %v", err))
continue
}

conn, err := net.DialTimeout("udp", udpAddr.String(), 5*time.Second)
if err != nil {
errs = append(errs, fmt.Sprintf("error setting dial timeout on connection to %s: %v", kdcs[i], err))
continue
}
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
errs = append(errs, fmt.Sprintf("error setting deadline on connection to %s: %v", kdcs[i], err))
continue
}
// conn is guaranteed to be a UDPConn
rb, err := sendUDP(conn.(*net.UDPConn), b)
if err != nil {
errs = append(errs, fmt.Sprintf("error sneding to %s: %v", kdcs[i], err))
continue
}
return rb, nil
}
return checkForKRBError(rb)
return nil, fmt.Errorf("error sending to a KDC: %s", strings.Join(errs, "; "))
}

// sendUDP sends bytes to connection over UDP.
func (cl *Client) sendUDP(conn *net.UDPConn, b []byte) ([]byte, error) {
func sendUDP(conn *net.UDPConn, b []byte) ([]byte, error) {
var r []byte
defer conn.Close()
_, err := conn.Write(b)
Expand All @@ -167,8 +132,52 @@ func (cl *Client) sendUDP(conn *net.UDPConn, b []byte) ([]byte, error) {
return r, nil
}

// sendKDCTCP sends bytes to the KDC via TCP.
func (cl *Client) sendKDCTCP(realm string, b []byte) ([]byte, error) {
var r []byte
_, kdcs, err := cl.Config.GetKDCs(realm, true)
if err != nil {
return r, err
}
r, err = dialSendTCP(kdcs, b)
if err != nil {
return r, err
}
return checkForKRBError(r)
}

// dialKDCTCP establishes a TCP connection to a KDC.
func dialSendTCP(kdcs map[int]string, b []byte) ([]byte, error) {
var errs []string
for i := 1; i <= len(kdcs); i++ {
tcpAddr, err := net.ResolveTCPAddr("tcp", kdcs[i])
if err != nil {
errs = append(errs, fmt.Sprintf("error resolving KDC address: %v", err))
continue
}

conn, err := net.DialTimeout("tcp", tcpAddr.String(), 5*time.Second)
if err != nil {
errs = append(errs, fmt.Sprintf("error setting dial timeout on connection to %s: %v", kdcs[i], err))
continue
}
if err := conn.SetDeadline(time.Now().Add(5 * time.Second)); err != nil {
errs = append(errs, fmt.Sprintf("error setting deadline on connection to %s: %v", kdcs[i], err))
continue
}
// conn is guaranteed to be a TCPConn
rb, err := sendTCP(conn.(*net.TCPConn), b)
if err != nil {
errs = append(errs, fmt.Sprintf("error sneding to %s: %v", kdcs[i], err))
continue
}
return rb, nil
}
return nil, errors.New("error in getting a TCP connection to any of the KDCs")
}

// sendTCP sends bytes to connection over TCP.
func (cl *Client) sendTCP(conn *net.TCPConn, b []byte) ([]byte, error) {
func sendTCP(conn *net.TCPConn, b []byte) ([]byte, error) {
defer conn.Close()
var r []byte
// RFC 4120 7.2.2 specifies the first 4 bytes indicate the length of the message in big endian order.
Expand Down
46 changes: 10 additions & 36 deletions v8/client/passwd.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package client

import (
"fmt"
"net"

"github.com/jcmturner/gokrb5/v8/kadmin"
"github.com/jcmturner/gokrb5/v8/messages"
Expand Down Expand Up @@ -55,46 +54,21 @@ func (cl *Client) sendToKPasswd(msg kadmin.Request) (r kadmin.Reply, err error)
if err != nil {
return
}
addr := kps[1]
b, err := msg.Marshal()
if err != nil {
return
}
var rb []byte
if len(b) <= cl.Config.LibDefaults.UDPPreferenceLimit {
return cl.sendKPasswdUDP(b, addr)
}
return cl.sendKPasswdTCP(b, addr)
}

func (cl *Client) sendKPasswdTCP(b []byte, kadmindAddr string) (r kadmin.Reply, err error) {
tcpAddr, err := net.ResolveTCPAddr("tcp", kadmindAddr)
if err != nil {
return
}
conn, err := net.DialTCP("tcp", nil, tcpAddr)
if err != nil {
return
}
rb, err := cl.sendTCP(conn, b)
if err != nil {
return
}
err = r.Unmarshal(rb)
return
}

func (cl *Client) sendKPasswdUDP(b []byte, kadmindAddr string) (r kadmin.Reply, err error) {
udpAddr, err := net.ResolveUDPAddr("udp", kadmindAddr)
if err != nil {
return
}
conn, err := net.DialUDP("udp", nil, udpAddr)
if err != nil {
return
}
rb, err := cl.sendUDP(conn, b)
if err != nil {
return
rb, err = dialSendUDP(kps, b)
if err != nil {
return
}
} else {
rb, err = dialSendTCP(kps, b)
if err != nil {
return
}
}
err = r.Unmarshal(rb)
return
Expand Down

0 comments on commit 260a581

Please sign in to comment.