Skip to content

Commit

Permalink
Add -4 and -6 flags (redo)
Browse files Browse the repository at this point in the history
This PR is a redo of #1802. Since that PR has been idle so long,
the branches have diverged quite a bit and it was easier to start
anew.

The work in this PR includes the work originally done by @dmke in #1802.

This PR is to resolve #1801.
  • Loading branch information
jsumners committed Aug 3, 2023
1 parent 35c259e commit 63e3bcc
Show file tree
Hide file tree
Showing 11 changed files with 347 additions and 21 deletions.
2 changes: 2 additions & 0 deletions .golangci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -167,6 +167,8 @@ issues:
text: 'dnsTimeout is a global variable'
- path: challenge/dns01/nameserver_test.go
text: 'findXByFqdnTestCases is a global variable'
- path: challenge/dns01/network.go
text: 'currentNetworkStack is a global variable'
- path: challenge/http01/domain_matcher.go
text: 'string `Host` has \d occurrences, make it a constant'
- path: challenge/http01/domain_matcher.go
Expand Down
14 changes: 10 additions & 4 deletions challenge/dns01/nameserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,17 +253,23 @@ func createDNSMsg(fqdn string, rtype uint16, recursive bool) *dns.Msg {

func sendDNSQuery(m *dns.Msg, ns string) (*dns.Msg, error) {
if ok, _ := strconv.ParseBool(os.Getenv("LEGO_EXPERIMENTAL_DNS_TCP_ONLY")); ok {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
network := currentNetworkStack.Network("tcp")
tcp := &dns.Client{Net: network, Timeout: dnsTimeout}
in, _, err := tcp.Exchange(m, ns)

return in, err
}

udp := &dns.Client{Net: "udp", Timeout: dnsTimeout}
udpNetwork := currentNetworkStack.Network("udp")
udp := &dns.Client{Net: udpNetwork, Timeout: dnsTimeout}
in, _, err := udp.Exchange(m, ns)

if in != nil && in.Truncated {
tcp := &dns.Client{Net: "tcp", Timeout: dnsTimeout}
// We can encounter a net.OpError if the nameserver is not listening
// on UDP at all, i.e. net.Dial could not make a connection.
var opErr *net.OpError
if (in != nil && in.Truncated) || errors.As(err, &opErr) {
tcpNetwork := currentNetworkStack.Network("tcp")
tcp := &dns.Client{Net: tcpNetwork, Timeout: dnsTimeout}
// If the TCP request succeeds, the "err" will reset to nil
in, _, err = tcp.Exchange(m, ns)
}
Expand Down
143 changes: 133 additions & 10 deletions challenge/dns01/nameserver_test.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,133 @@
package dns01

import (
"net"
"sort"
"sync"
"testing"

"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func testDNSHandler(writer dns.ResponseWriter, reply *dns.Msg) {
msg := dns.Msg{}
msg.SetReply(reply)

if reply.Question[0].Qtype == dns.TypeA {
msg.Authoritative = true
domain := msg.Question[0].Name
msg.Answer = append(
msg.Answer,
&dns.A{
Hdr: dns.RR_Header{
Name: domain,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 60,
},
A: net.IPv4(127, 0, 0, 1),
},
)
}

_ = writer.WriteMsg(&msg)
}

// getTestNameserver constructs a new DNS server on a local address, or set
// of addresses, that responds to an `A` query for `example.com`.
func getTestNameserver(t *testing.T, network string) *dns.Server {
t.Helper()
server := &dns.Server{
Handler: dns.HandlerFunc(testDNSHandler),
Net: network,
}
switch network {
case "tcp", "udp":
server.Addr = "0.0.0.0:0"
case "tcp4", "udp4":
server.Addr = "127.0.0.1:0"
case "tcp6", "udp6":
server.Addr = "[::1]:0"
}

waitLock := sync.Mutex{}
waitLock.Lock()
server.NotifyStartedFunc = waitLock.Unlock

go func() { _ = server.ListenAndServe() }()

waitLock.Lock()
return server
}

func startTestNameserver(t *testing.T, stack networkStack, proto string) (shutdown func(), addr string) {
t.Helper()
currentNetworkStack = stack
srv := getTestNameserver(t, currentNetworkStack.Network(proto))

shutdown = func() { _ = srv.Shutdown() }
if proto == "tcp" {
addr = srv.Listener.Addr().String()
} else {
addr = srv.PacketConn.LocalAddr().String()
}
return
}

func TestSendDNSQuery(t *testing.T) {
currentNameservers := recursiveNameservers

t.Cleanup(func() {
recursiveNameservers = currentNameservers
currentNetworkStack = dualStack
})

t.Run("does udp4 only", func(t *testing.T) {
stop, addr := startTestNameserver(t, ipv4only, "udp")
defer stop()

recursiveNameservers = ParseNameservers([]string{addr})
msg := createDNSMsg("example.com.", dns.TypeA, true)
result, queryError := sendDNSQuery(msg, addr)
require.NoError(t, queryError)
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
})

t.Run("does udp6 only", func(t *testing.T) {
stop, addr := startTestNameserver(t, ipv6only, "udp")
defer stop()

recursiveNameservers = ParseNameservers([]string{addr})
msg := createDNSMsg("example.com.", dns.TypeA, true)
result, queryError := sendDNSQuery(msg, addr)
require.NoError(t, queryError)
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
})

t.Run("does tcp4 and tcp6", func(t *testing.T) {
stop, addr := startTestNameserver(t, dualStack, "tcp")
host, port, _ := net.SplitHostPort(addr)
defer stop()
t.Logf("### port: %s", port)

addr6 := net.JoinHostPort(host, port)
recursiveNameservers = ParseNameservers([]string{addr6})
msg := createDNSMsg("example.com.", dns.TypeA, true)
result, queryError := sendDNSQuery(msg, addr6)
require.NoError(t, queryError)
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")

addr4 := net.JoinHostPort("127.0.0.1", port)
recursiveNameservers = ParseNameservers([]string{addr4})
msg = createDNSMsg("example.com.", dns.TypeA, true)
result, queryError = sendDNSQuery(msg, addr4)
require.NoError(t, queryError)
assert.Equal(t, result.Answer[0].(*dns.A).A.String(), "127.0.0.1")
})
}

func TestLookupNameserversOK(t *testing.T) {
testCases := []struct {
fqdn string
Expand Down Expand Up @@ -74,7 +194,7 @@ var findXByFqdnTestCases = []struct {
zone string
primaryNs string
nameservers []string
expectedError string
expectedError string // regular expression
}{
{
desc: "domain is a CNAME",
Expand Down Expand Up @@ -109,7 +229,7 @@ var findXByFqdnTestCases = []struct {
fqdn: "test.lego.zz.",
zone: "lego.zz.",
nameservers: []string{"8.8.8.8:53"},
expectedError: "could not find the start of authority for test.lego.zz.: NXDOMAIN",
expectedError: `^could not find the start of authority for test\.lego\.zz.: NXDOMAIN`,
},
{
desc: "several non existent nameservers",
Expand All @@ -119,18 +239,21 @@ var findXByFqdnTestCases = []struct {
nameservers: []string{":7053", ":8053", "8.8.8.8:53"},
},
{
desc: "only non-existent nameservers",
fqdn: "mail.google.com.",
zone: "google.com.",
nameservers: []string{":7053", ":8053", ":9053"},
expectedError: "could not find the start of authority for mail.google.com.: read udp",
desc: "only non-existent nameservers",
fqdn: "mail.google.com.",
zone: "google.com.",
nameservers: []string{":7053", ":8053", ":9053"},
// NOTE: On Windows, net.DialContext finds a way down to the ContectEx syscall.
// There a fault is marked as "connectex", not "connect", see
// https://cs.opensource.google/go/go/+/refs/tags/go1.19.5:src/net/fd_windows.go;l=112
expectedError: `^could not find the start of authority for mail\.google\.com.: dial tcp :9053: connect(ex)?:`,
},
{
desc: "no nameservers",
fqdn: "test.ldez.com.",
zone: "ldez.com.",
nameservers: []string{},
expectedError: "could not find the start of authority for test.ldez.com.",
expectedError: `^could not find the start of authority for test\.ldez\.com\.`,
},
}

Expand All @@ -142,7 +265,7 @@ func TestFindZoneByFqdnCustom(t *testing.T) {
zone, err := FindZoneByFqdnCustom(test.fqdn, test.nameservers)
if test.expectedError != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), test.expectedError)
assert.Regexp(t, test.expectedError, err.Error())
} else {
require.NoError(t, err)
assert.Equal(t, test.zone, zone)
Expand All @@ -159,7 +282,7 @@ func TestFindPrimaryNsByFqdnCustom(t *testing.T) {
ns, err := FindPrimaryNsByFqdnCustom(test.fqdn, test.nameservers)
if test.expectedError != "" {
require.Error(t, err)
assert.Contains(t, err.Error(), test.expectedError)
assert.Regexp(t, test.expectedError, err.Error())
} else {
require.NoError(t, err)
assert.Equal(t, test.primaryNs, ns)
Expand Down
41 changes: 41 additions & 0 deletions challenge/dns01/network.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
package dns01

// networkStack is used to indicate which IP stack should be used for DNS queries.
type networkStack int

const (
dualStack networkStack = iota
ipv4only
ipv6only
)

// currentNetworkStack is used to define which IP stack will be used. The default is
// both IPv4 and IPv6. Set to IPv4Only or IPv6Only to select either version.
var currentNetworkStack = dualStack

// Network interprets the NetworkStack setting in relation to the desired
// protocol. The proto value should be either "udp" or "tcp".
func (s networkStack) Network(proto string) string {
// The DNS client passes whatever value is set in (*dns.Client).Net to
// the [net.Dialer](https://github.com/miekg/dns/blob/fe20d5d/client.go#L119-L141).
// And the net.Dialer accepts strings such as "udp4" or "tcp6"
// (https://cs.opensource.google/go/go/+/refs/tags/go1.18.9:src/net/dial.go;l=167-182).
switch s {
case ipv4only:
return proto + "4"
case ipv6only:
return proto + "6"
default:
return proto
}
}

// SetIPv4Only forces DNS queries to only happen over the IPv4 stack.
func SetIPv4Only() { currentNetworkStack = ipv4only }

// SetIPv6Only forces DNS queries to only happen over the IPv6 stack.
func SetIPv6Only() { currentNetworkStack = ipv6only }

// SetDualStack indicates that both IPv4 and IPv6 should be allowed.
// This setting lets the OS determine which IP stack to use.
func SetDualStack() { currentNetworkStack = dualStack }
22 changes: 22 additions & 0 deletions challenge/http01/http_challenge_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,28 @@ func NewUnixProviderServer(socketPath string, mode fs.FileMode) *ProviderServer
return &ProviderServer{network: "unix", address: socketPath, socketMode: mode, matcher: &hostMatcher{}}
}

// SetIPv4Only starts the challenge server on an IPv4 address.
//
// Calling this method has no effect if s was created with NewUnixProviderServer.
func (s *ProviderServer) SetIPv4Only() { s.setTCPStack("tcp4") }

// SetIPv6Only starts the challenge server on an IPv6 address.
//
// Calling this method has no effect if s was created with NewUnixProviderServer.
func (s *ProviderServer) SetIPv6Only() { s.setTCPStack("tcp6") }

// SetDualStack indicates that both IPv4 and IPv6 should be allowed.
// This setting lets the OS determine which IP stack to use for the challenge server.
//
// Calling this method has no effect if s was created with NewUnixProviderServer.
func (s *ProviderServer) SetDualStack() { s.setTCPStack("tcp") }

func (s *ProviderServer) setTCPStack(network string) {
if s.network != "unix" {
s.network = network
}
}

// Present starts a web server and makes the token available at `ChallengePath(token)` for web requests.
func (s *ProviderServer) Present(domain, token, keyAuth string) error {
var err error
Expand Down
17 changes: 17 additions & 0 deletions challenge/http01/http_challenge_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ func TestProviderServer_GetAddress(t *testing.T) {
testCases := []struct {
desc string
server *ProviderServer
network func(server *ProviderServer)
expected string
}{
{
Expand All @@ -49,6 +50,18 @@ func TestProviderServer_GetAddress(t *testing.T) {
server: NewProviderServer("localhost", "8080"),
expected: "localhost:8080",
},
{
desc: "TCP4 with host and port",
server: NewProviderServer("localhost", "8080"),
network: func(s *ProviderServer) { s.SetIPv4Only() },
expected: "localhost:8080",
},
{
desc: "TCP6 with host and port",
server: NewProviderServer("localhost", "8080"),
network: func(s *ProviderServer) { s.SetIPv6Only() },
expected: "localhost:8080",
},
{
desc: "UDS socket",
server: NewUnixProviderServer(sock, fs.ModeSocket|0o666),
Expand All @@ -61,6 +74,10 @@ func TestProviderServer_GetAddress(t *testing.T) {
t.Run(test.desc, func(t *testing.T) {
t.Parallel()

if test.network != nil {
test.network(test.server)
}

address := test.server.GetAddress()
assert.Equal(t, test.expected, address)
})
Expand Down
18 changes: 16 additions & 2 deletions challenge/tlsalpn01/tls_alpn_challenge_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,30 @@ const (
type ProviderServer struct {
iface string
port string
network string
listener net.Listener
}

// NewProviderServer creates a new ProviderServer on the selected interface and port.
// Setting iface and / or port to an empty string will make the server fall back to
// the "any" interface and port 443 respectively.
func NewProviderServer(iface, port string) *ProviderServer {
return &ProviderServer{iface: iface, port: port}
if port == "" {
port = defaultTLSPort
}
return &ProviderServer{iface: iface, port: port, network: "tcp"}
}

// SetIPv4Only starts the challenge server on an IPv4 address.
func (s *ProviderServer) SetIPv4Only() { s.network = "tcp4" }

// SetIPv6Only starts the challenge server on an IPv6 address.
func (s *ProviderServer) SetIPv6Only() { s.network = "tcp6" }

// SetDualStack indicates that both IPv4 and IPv6 should be allowed.
// This setting lets the OS determine which IP stack to use for the challenge server.
func (s *ProviderServer) SetDualStack() { s.network = "tcp" }

func (s *ProviderServer) GetAddress() string {
return net.JoinHostPort(s.iface, s.port)
}
Expand Down Expand Up @@ -65,7 +79,7 @@ func (s *ProviderServer) Present(domain, token, keyAuth string) error {
tlsConf.NextProtos = []string{ACMETLS1Protocol}

// Create the listener with the created tls.Config.
s.listener, err = tls.Listen("tcp", s.GetAddress(), tlsConf)
s.listener, err = tls.Listen(s.network, s.GetAddress(), tlsConf)
if err != nil {
return fmt.Errorf("could not start HTTPS server for challenge: %w", err)
}
Expand Down
Loading

0 comments on commit 63e3bcc

Please sign in to comment.