diff --git a/README.md b/README.md index 588873f..9acf3e9 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,7 @@ Usage: Flags: -h, --help help for query + -k, --insecure-skip-verify allow insecure server connections (e.g. self-signed TLS certificates) --resolver-addr string address of a DNS resolver to use for resolving DoH server names (e.g. 8.8.8.8:53) --resolver-network string protocol to use for resolving DoH server names (e.g. udp, tcp) (default "udp") --retry-max int maximum number of retries for each query (default 10) diff --git a/internal/cli/command_query.go b/internal/cli/command_query.go index 49b6b2b..fb7a3f8 100644 --- a/internal/cli/command_query.go +++ b/internal/cli/command_query.go @@ -2,6 +2,7 @@ package cli import ( "context" + "crypto/tls" "encoding/json" "fmt" "net" @@ -22,7 +23,8 @@ type result struct { Resp *dj.Response `json:"resp"` } -func newClient(retryMax int) (*http.Client, error) { +// newHTTPClient returns a new HTTP client, or an error if one occurs. +func newHTTPClient(retryMax int, insecureSkipVerify bool) (*http.Client, error) { retryClient := retryablehttp.NewClient() retryClient.RetryMax = retryMax @@ -31,6 +33,16 @@ func newClient(retryMax int) (*http.Client, error) { retryClient.Logger = nil // TODO: consider logger + if insecureSkipVerify { + transport := cleanhttp.DefaultTransport() + + transport.TLSClientConfig = &tls.Config{ + InsecureSkipVerify: true, + } + + retryClient.HTTPClient.Transport = transport + } + retryClient.CheckRetry = func(ctx context.Context, resp *http.Response, err error) (bool, error) { return retryablehttp.DefaultRetryPolicy(ctx, resp, err) } @@ -76,7 +88,12 @@ which can be piped to other commands (e.g. jq) or redirected to a file.`, return fmt.Errorf("invalid retry max: %w", err) } - httpClient, err := newClient(retryMax) + insecureSkipVerify, err := cmd.Flags().GetBool("insecure-skip-verify") + if err != nil { + return fmt.Errorf("invalid insecure skip verify: %w", err) + } + + httpClient, err := newHTTPClient(retryMax, insecureSkipVerify) if err != nil { return fmt.Errorf("error creating http client: %w", err) } @@ -172,6 +189,7 @@ func init() { CommandQuery.Flags().String("resolver-addr", "", "address of a DNS resolver to use for resolving DoH server names (e.g. 8.8.8.8:53)") CommandQuery.Flags().String("resolver-network", "udp", "protocol to use for resolving DoH server names (e.g. udp, tcp)") CommandQuery.Flags().Int("retry-max", 10, "maximum number of retries for each query") + CommandQuery.Flags().BoolP("insecure-skip-verify", "k", false, "allow insecure server connections (e.g. self-signed TLS certificates)") CommandRoot.AddCommand(CommandQuery) } diff --git a/internal/cli/command_test.go b/internal/cli/command_test.go index 5255e60..37e7097 100644 --- a/internal/cli/command_test.go +++ b/internal/cli/command_test.go @@ -3,9 +3,14 @@ package cli_test import ( "bytes" "io" + "net" + "net/http" + "net/http/httptest" "testing" + "github.com/miekg/dns" "github.com/picatz/doh/internal/cli" + "github.com/picatz/doh/pkg/doh" ) func testCommand(t *testing.T, args ...string) io.Reader { @@ -46,7 +51,7 @@ func TestCommand(t *testing.T) { }, }, { - name: "google.com", + name: "query google.com", args: []string{"query", "google.com"}, check: func(t *testing.T, output io.Reader) { b, err := io.ReadAll(output) @@ -62,7 +67,7 @@ func TestCommand(t *testing.T) { }, }, { - name: "cloudflare.com", + name: "query cloudflare.com", args: []string{"query", "cloudflare.com"}, check: func(t *testing.T, output io.Reader) { b, err := io.ReadAll(output) @@ -87,3 +92,39 @@ func TestCommand(t *testing.T) { }) } } + +func TestCommand_Query_InsecureSkipVerify(t *testing.T) { + mux := doh.NewServerMux(func(w http.ResponseWriter, httpReq *http.Request, dnsReq *dns.Msg) (*dns.Msg, error) { + dnsResp := new(dns.Msg) + dnsResp.SetReply(dnsReq) + dnsResp.Answer = append(dnsResp.Answer, &dns.A{ + Hdr: dns.RR_Header{ + Name: dnsReq.Question[0].Name, + Rrtype: dns.TypeA, + Class: dns.ClassINET, + Ttl: 300, + }, + A: net.ParseIP("8.8.8.8"), + }) + + return dnsResp, nil + }) + + server := httptest.NewTLSServer(mux) + t.Cleanup(server.Close) + + dohServerURL := server.URL + "/dns-query" + + output := testCommand(t, "query", "google.com", "--insecure-skip-verify", "--servers", dohServerURL) + + b, err := io.ReadAll(output) + if err != nil { + t.Fatal(err) + } + + if len(b) == 0 { + t.Fatal("got no output for known domain") + } + + t.Log(string(b)) +}