From e9a1e8974d4cbc43d16bbda5ca71db764aca4655 Mon Sep 17 00:00:00 2001 From: ThinkChaos Date: Tue, 19 Dec 2023 19:50:17 -0500 Subject: [PATCH] feat(api): support client name lookup when querying via the API --- api/api_interface_impl.go | 33 +++++++++++++++-- api/api_interface_impl_test.go | 66 +++++++++++++++++++++++++++++++--- server/server_endpoints.go | 32 +++-------------- util/http.go | 20 +++++++++++ util/http_test.go | 45 +++++++++++++++++++++++ 5 files changed, 162 insertions(+), 34 deletions(-) create mode 100644 util/http.go create mode 100644 util/http_test.go diff --git a/api/api_interface_impl.go b/api/api_interface_impl.go index 109c9c5d9..fb1a3fba0 100644 --- a/api/api_interface_impl.go +++ b/api/api_interface_impl.go @@ -7,6 +7,8 @@ package api import ( "context" "fmt" + "net" + "net/http" "strings" "time" @@ -17,6 +19,8 @@ import ( "github.com/miekg/dns" ) +type httpReqCtxKey struct{} + // BlockingStatus represents the current blocking status type BlockingStatus struct { // True if blocking is enabled @@ -40,7 +44,9 @@ type ListRefresher interface { } type Querier interface { - Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) + Query( + ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type, + ) (*model.Response, error) } type CacheControl interface { @@ -48,7 +54,17 @@ type CacheControl interface { } func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) { - HandlerFromMuxWithBaseURL(NewStrictHandler(impl, nil), router, "/api") + middleware := []StrictMiddlewareFunc{ctxWithHTTPRequestMiddleware} + + HandlerFromMuxWithBaseURL(NewStrictHandler(impl, middleware), router, "/api") +} + +func ctxWithHTTPRequestMiddleware(handler StrictHandlerFunc, operationID string) StrictHandlerFunc { + return func(ctx context.Context, w http.ResponseWriter, r *http.Request, request any) (response any, err error) { + ctx = context.WithValue(ctx, httpReqCtxKey{}, r) + + return handler(ctx, w, r, request) + } } type OpenAPIInterfaceImpl struct { @@ -143,7 +159,18 @@ func (i *OpenAPIInterfaceImpl) Query(ctx context.Context, request QueryRequestOb return Query400TextResponse(fmt.Sprintf("unknown query type '%s'", request.Body.Type)), nil } - resp, err := i.querier.Query(ctx, dns.Fqdn(request.Body.Query), qType) + var ( + serverHost string + clientIP net.IP + ) + + httpReq, ok := ctx.Value(httpReqCtxKey{}).(*http.Request) + if ok { + serverHost = httpReq.Host + clientIP = util.HTTPClientIP(httpReq) + } + + resp, err := i.querier.Query(ctx, serverHost, clientIP, dns.Fqdn(request.Body.Query), qType) if err != nil { return nil, err } diff --git a/api/api_interface_impl_test.go b/api/api_interface_impl_test.go index 9115bc698..a1b141382 100644 --- a/api/api_interface_impl_test.go +++ b/api/api_interface_impl_test.go @@ -3,10 +3,13 @@ package api import ( "context" "errors" + "net" + "net/http" "time" "github.com/0xERR0R/blocky/model" "github.com/0xERR0R/blocky/util" + "github.com/go-chi/chi/v5" "github.com/miekg/dns" "github.com/stretchr/testify/mock" @@ -53,10 +56,17 @@ func (m *BlockingControlMock) BlockingStatus() BlockingStatus { return args.Get(0).(BlockingStatus) } -func (m *QuerierMock) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) { - args := m.Called(ctx, question, qType) +func (m *QuerierMock) Query( + ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type, +) (*model.Response, error) { + args := m.Called(ctx, serverHost, clientIP, question, qType) - return args.Get(0).(*model.Response), args.Error(1) + err := args.Error(1) + if err != nil { + return nil, err + } + + return args.Get(0).(*model.Response), nil } func (m *CacheControlMock) FlushCaches(ctx context.Context) { @@ -92,6 +102,34 @@ var _ = Describe("API implementation tests", func() { listRefreshMock.AssertExpectations(GinkgoT()) }) + Describe("RegisterOpenAPIEndpoints", func() { + It("adds routes", func() { + rtr := chi.NewRouter() + RegisterOpenAPIEndpoints(rtr, sut) + + Expect(rtr.Routes()).ShouldNot(BeEmpty()) + }) + }) + + Describe("ctxWithHTTPRequestMiddleware", func() { + It("adds the request to the context", func() { + handler := func(ctx context.Context, _ http.ResponseWriter, r *http.Request, _ any) (any, error) { + Expect(ctx.Value(httpReqCtxKey{})).Should(BeIdenticalTo(r)) + + return nil, nil //nolint:nilnil + } + + handler = ctxWithHTTPRequestMiddleware(handler, "operation-id") + + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://example.com", nil) + Expect(err).Should(Succeed()) + + resp, err := handler(ctx, nil, req, nil) + Expect(err).Should(Succeed()) + Expect(resp).Should(BeNil()) + }) + }) + Describe("Query API", func() { When("Query is called", func() { It("should return 200 on success", func() { @@ -100,7 +138,7 @@ var _ = Describe("API implementation tests", func() { ) Expect(err).Should(Succeed()) - querierMock.On("Query", ctx, "google.com.", A).Return(&model.Response{ + querierMock.On("Query", ctx, "", net.IP(nil), "google.com.", A).Return(&model.Response{ Res: queryResponse, Reason: "reason", }, nil) @@ -120,6 +158,26 @@ var _ = Describe("API implementation tests", func() { Expect(resp200.ReturnCode).Should(Equal("NOERROR")) }) + It("extracts metadata from the HTTP request", func() { + r, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://blocky.localhost", nil) + Expect(err).Should(Succeed()) + + clientIP := net.IPv4allrouter + r.RemoteAddr = net.JoinHostPort(clientIP.String(), "89685") + + ctx = context.WithValue(ctx, httpReqCtxKey{}, r) + + expectedErr := errors.New("test") + querierMock.On("Query", ctx, "blocky.localhost", clientIP, "example.com.", A).Return(nil, expectedErr) + + _, err = sut.Query(ctx, QueryRequestObject{ + Body: &ApiQueryRequest{ + Query: "example.com", Type: "A", + }, + }) + Expect(err).Should(MatchError(expectedErr)) + }) + It("should return 400 on wrong parameter", func() { resp, err := sut.Query(ctx, QueryRequestObject{ Body: &ApiQueryRequest{ diff --git a/server/server_endpoints.go b/server/server_endpoints.go index 589f22b55..cca9ba60c 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -8,7 +8,6 @@ import ( "io" "net" "net/http" - "strings" "time" "github.com/0xERR0R/blocky/resolver" @@ -148,7 +147,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h clientID = extractClientIDFromHost(req.Host) } - r := newRequest(net.ParseIP(extractIP(req)), model.RequestProtocolTCP, clientID, msg) + r := newRequest(util.HTTPClientIP(req), model.RequestProtocolTCP, clientID, msg) resResponse, err := s.queryResolver.Resolve(req.Context(), r) if err != nil { @@ -157,11 +156,6 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h return } - response := new(dns.Msg) - response.SetReply(msg) - // enable compression - resResponse.Res.Compress = true - b, err := resResponse.Res.Pack() if err != nil { logAndResponseWithError(err, "can't serialize message: ", rw) @@ -175,27 +169,11 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h logAndResponseWithError(err, "can't write response: ", rw) } -func extractIP(r *http.Request) string { - hostPort := r.Header.Get("X-FORWARDED-FOR") - - if hostPort == "" { - hostPort = r.RemoteAddr - } - - hostPort = strings.ReplaceAll(hostPort, "[", "") - hostPort = strings.ReplaceAll(hostPort, "]", "") - index := strings.LastIndex(hostPort, ":") - - if index >= 0 { - return hostPort[:index] - } - - return hostPort -} - -func (s *Server) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) { +func (s *Server) Query( + ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type, +) (*model.Response, error) { dnsRequest := util.NewMsgWithQuestion(question, qType) - r := createResolverRequest(nil, dnsRequest) + r := newRequest(clientIP, model.RequestProtocolTCP, extractClientIDFromHost(serverHost), dnsRequest) return s.queryResolver.Resolve(ctx, r) } diff --git a/util/http.go b/util/http.go new file mode 100644 index 000000000..736177239 --- /dev/null +++ b/util/http.go @@ -0,0 +1,20 @@ +package util + +import ( + "net" + "net/http" +) + +func HTTPClientIP(r *http.Request) net.IP { + addr := r.Header.Get("X-FORWARDED-FOR") + if addr == "" { + addr = r.RemoteAddr + } + + ip, _, err := net.SplitHostPort(addr) + if err != nil { + return net.ParseIP(addr) + } + + return net.ParseIP(ip) +} diff --git a/util/http_test.go b/util/http_test.go new file mode 100644 index 000000000..b1bf60842 --- /dev/null +++ b/util/http_test.go @@ -0,0 +1,45 @@ +package util + +import ( + "net" + "net/http" + + . "github.com/onsi/ginkgo/v2" + . "github.com/onsi/gomega" +) + +var _ = Describe("HTTP Util", func() { + Describe("HTTPClientIP", func() { + It("extracts the IP from RemoteAddr", func() { + r, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + Expect(err).Should(Succeed()) + + ip := net.IPv4allrouter + r.RemoteAddr = net.JoinHostPort(ip.String(), "78954") + + Expect(HTTPClientIP(r)).Should(Equal(ip)) + }) + + It("extracts the IP from RemoteAddr without a port", func() { + r, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + Expect(err).Should(Succeed()) + + ip := net.IPv4allrouter + r.RemoteAddr = ip.String() + + Expect(HTTPClientIP(r)).Should(Equal(ip)) + }) + + It("extracts the IP from the X-Forwarded-For header", func() { + r, err := http.NewRequest(http.MethodGet, "http://example.com", nil) + Expect(err).Should(Succeed()) + + ip := net.IPv4bcast + r.RemoteAddr = ip.String() + + r.Header.Set("X-Forwarded-For", ip.String()) + + Expect(HTTPClientIP(r)).Should(Equal(ip)) + }) + }) +})