Skip to content

Commit

Permalink
feat(api): support client name lookup when querying via the API
Browse files Browse the repository at this point in the history
  • Loading branch information
ThinkChaos committed Jan 18, 2024
1 parent aaee562 commit e9a1e89
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 34 deletions.
33 changes: 30 additions & 3 deletions api/api_interface_impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ package api
import (
"context"
"fmt"
"net"
"net/http"
"strings"
"time"

Expand All @@ -17,6 +19,8 @@ import (
"github.com/miekg/dns"
)

type httpReqCtxKey struct{}

// BlockingStatus represents the current blocking status
type BlockingStatus struct {
// True if blocking is enabled
Expand All @@ -40,15 +44,27 @@ type ListRefresher interface {
}

type Querier interface {
Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error)
Query(
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
) (*model.Response, error)
}

type CacheControl interface {
FlushCaches(ctx context.Context)
}

func RegisterOpenAPIEndpoints(router chi.Router, impl StrictServerInterface) {
HandlerFromMuxWithBaseURL(NewStrictHandler(impl, nil), router, "/api")
middleware := []StrictMiddlewareFunc{ctxWithHTTPRequestMiddleware}

HandlerFromMuxWithBaseURL(NewStrictHandler(impl, middleware), router, "/api")
}

func ctxWithHTTPRequestMiddleware(handler StrictHandlerFunc, operationID string) StrictHandlerFunc {
return func(ctx context.Context, w http.ResponseWriter, r *http.Request, request any) (response any, err error) {
ctx = context.WithValue(ctx, httpReqCtxKey{}, r)

return handler(ctx, w, r, request)
}
}

type OpenAPIInterfaceImpl struct {
Expand Down Expand Up @@ -143,7 +159,18 @@ func (i *OpenAPIInterfaceImpl) Query(ctx context.Context, request QueryRequestOb
return Query400TextResponse(fmt.Sprintf("unknown query type '%s'", request.Body.Type)), nil
}

resp, err := i.querier.Query(ctx, dns.Fqdn(request.Body.Query), qType)
var (
serverHost string
clientIP net.IP
)

httpReq, ok := ctx.Value(httpReqCtxKey{}).(*http.Request)
if ok {
serverHost = httpReq.Host
clientIP = util.HTTPClientIP(httpReq)
}

resp, err := i.querier.Query(ctx, serverHost, clientIP, dns.Fqdn(request.Body.Query), qType)
if err != nil {
return nil, err
}
Expand Down
66 changes: 62 additions & 4 deletions api/api_interface_impl_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,13 @@ package api
import (
"context"
"errors"
"net"
"net/http"
"time"

"github.com/0xERR0R/blocky/model"
"github.com/0xERR0R/blocky/util"
"github.com/go-chi/chi/v5"
"github.com/miekg/dns"
"github.com/stretchr/testify/mock"

Expand Down Expand Up @@ -53,10 +56,17 @@ func (m *BlockingControlMock) BlockingStatus() BlockingStatus {
return args.Get(0).(BlockingStatus)
}

func (m *QuerierMock) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
args := m.Called(ctx, question, qType)
func (m *QuerierMock) Query(
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
) (*model.Response, error) {
args := m.Called(ctx, serverHost, clientIP, question, qType)

return args.Get(0).(*model.Response), args.Error(1)
err := args.Error(1)
if err != nil {
return nil, err
}

return args.Get(0).(*model.Response), nil
}

func (m *CacheControlMock) FlushCaches(ctx context.Context) {
Expand Down Expand Up @@ -92,6 +102,34 @@ var _ = Describe("API implementation tests", func() {
listRefreshMock.AssertExpectations(GinkgoT())
})

Describe("RegisterOpenAPIEndpoints", func() {
It("adds routes", func() {
rtr := chi.NewRouter()
RegisterOpenAPIEndpoints(rtr, sut)

Expect(rtr.Routes()).ShouldNot(BeEmpty())
})
})

Describe("ctxWithHTTPRequestMiddleware", func() {
It("adds the request to the context", func() {
handler := func(ctx context.Context, _ http.ResponseWriter, r *http.Request, _ any) (any, error) {
Expect(ctx.Value(httpReqCtxKey{})).Should(BeIdenticalTo(r))

return nil, nil //nolint:nilnil
}

handler = ctxWithHTTPRequestMiddleware(handler, "operation-id")

req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://example.com", nil)
Expect(err).Should(Succeed())

resp, err := handler(ctx, nil, req, nil)
Expect(err).Should(Succeed())
Expect(resp).Should(BeNil())
})
})

Describe("Query API", func() {
When("Query is called", func() {
It("should return 200 on success", func() {
Expand All @@ -100,7 +138,7 @@ var _ = Describe("API implementation tests", func() {
)
Expect(err).Should(Succeed())

querierMock.On("Query", ctx, "google.com.", A).Return(&model.Response{
querierMock.On("Query", ctx, "", net.IP(nil), "google.com.", A).Return(&model.Response{
Res: queryResponse,
Reason: "reason",
}, nil)
Expand All @@ -120,6 +158,26 @@ var _ = Describe("API implementation tests", func() {
Expect(resp200.ReturnCode).Should(Equal("NOERROR"))
})

It("extracts metadata from the HTTP request", func() {
r, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://blocky.localhost", nil)
Expect(err).Should(Succeed())

clientIP := net.IPv4allrouter
r.RemoteAddr = net.JoinHostPort(clientIP.String(), "89685")

ctx = context.WithValue(ctx, httpReqCtxKey{}, r)

expectedErr := errors.New("test")
querierMock.On("Query", ctx, "blocky.localhost", clientIP, "example.com.", A).Return(nil, expectedErr)

_, err = sut.Query(ctx, QueryRequestObject{
Body: &ApiQueryRequest{
Query: "example.com", Type: "A",
},
})
Expect(err).Should(MatchError(expectedErr))
})

It("should return 400 on wrong parameter", func() {
resp, err := sut.Query(ctx, QueryRequestObject{
Body: &ApiQueryRequest{
Expand Down
32 changes: 5 additions & 27 deletions server/server_endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"io"
"net"
"net/http"
"strings"
"time"

"github.com/0xERR0R/blocky/resolver"
Expand Down Expand Up @@ -148,7 +147,7 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
clientID = extractClientIDFromHost(req.Host)
}

r := newRequest(net.ParseIP(extractIP(req)), model.RequestProtocolTCP, clientID, msg)
r := newRequest(util.HTTPClientIP(req), model.RequestProtocolTCP, clientID, msg)

resResponse, err := s.queryResolver.Resolve(req.Context(), r)
if err != nil {
Expand All @@ -157,11 +156,6 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
return
}

response := new(dns.Msg)
response.SetReply(msg)
// enable compression
resResponse.Res.Compress = true

b, err := resResponse.Res.Pack()
if err != nil {
logAndResponseWithError(err, "can't serialize message: ", rw)
Expand All @@ -175,27 +169,11 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
logAndResponseWithError(err, "can't write response: ", rw)
}

func extractIP(r *http.Request) string {
hostPort := r.Header.Get("X-FORWARDED-FOR")

if hostPort == "" {
hostPort = r.RemoteAddr
}

hostPort = strings.ReplaceAll(hostPort, "[", "")
hostPort = strings.ReplaceAll(hostPort, "]", "")
index := strings.LastIndex(hostPort, ":")

if index >= 0 {
return hostPort[:index]
}

return hostPort
}

func (s *Server) Query(ctx context.Context, question string, qType dns.Type) (*model.Response, error) {
func (s *Server) Query(
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
) (*model.Response, error) {
dnsRequest := util.NewMsgWithQuestion(question, qType)
r := createResolverRequest(nil, dnsRequest)
r := newRequest(clientIP, model.RequestProtocolTCP, extractClientIDFromHost(serverHost), dnsRequest)

return s.queryResolver.Resolve(ctx, r)
}
Expand Down
20 changes: 20 additions & 0 deletions util/http.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
package util

import (
"net"
"net/http"
)

func HTTPClientIP(r *http.Request) net.IP {
addr := r.Header.Get("X-FORWARDED-FOR")
if addr == "" {
addr = r.RemoteAddr
}

ip, _, err := net.SplitHostPort(addr)
if err != nil {
return net.ParseIP(addr)
}

return net.ParseIP(ip)
}
45 changes: 45 additions & 0 deletions util/http_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
package util

import (
"net"
"net/http"

. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
)

var _ = Describe("HTTP Util", func() {
Describe("HTTPClientIP", func() {
It("extracts the IP from RemoteAddr", func() {
r, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
Expect(err).Should(Succeed())

ip := net.IPv4allrouter
r.RemoteAddr = net.JoinHostPort(ip.String(), "78954")

Expect(HTTPClientIP(r)).Should(Equal(ip))
})

It("extracts the IP from RemoteAddr without a port", func() {
r, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
Expect(err).Should(Succeed())

ip := net.IPv4allrouter
r.RemoteAddr = ip.String()

Expect(HTTPClientIP(r)).Should(Equal(ip))
})

It("extracts the IP from the X-Forwarded-For header", func() {
r, err := http.NewRequest(http.MethodGet, "http://example.com", nil)
Expect(err).Should(Succeed())

ip := net.IPv4bcast
r.RemoteAddr = ip.String()

r.Header.Set("X-Forwarded-For", ip.String())

Expect(HTTPClientIP(r)).Should(Equal(ip))
})
})
})

0 comments on commit e9a1e89

Please sign in to comment.