Skip to content

Commit

Permalink
refactor(server): add resolve for common query code
Browse files Browse the repository at this point in the history
Ensure all queries go through that common code path so we always enable
compression, truncate if required, etc.
  • Loading branch information
ThinkChaos committed Jan 18, 2024
1 parent e9a1e89 commit f0ad412
Show file tree
Hide file tree
Showing 2 changed files with 54 additions and 37 deletions.
83 changes: 50 additions & 33 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -417,7 +417,9 @@ func createQueryResolver(

func (s *Server) registerDNSHandlers(ctx context.Context) {
wrappedOnRequest := func(w dns.ResponseWriter, request *dns.Msg) {
s.OnRequest(ctx, w, request)
ip, proto := resolveClientIPAndProtocol(w.RemoteAddr())

s.OnRequest(ctx, w, ip, proto, request)
}

for _, server := range s.dnsServers {
Expand Down Expand Up @@ -550,25 +552,6 @@ func (s *Server) Stop(ctx context.Context) error {
return nil
}

func createResolverRequest(rw dns.ResponseWriter, request *dns.Msg) *model.Request {
var hostName string

var remoteAddr net.Addr

if rw != nil {
remoteAddr = rw.RemoteAddr()
}

clientIP, protocol := resolveClientIPAndProtocol(remoteAddr)
con, ok := rw.(dns.ConnectionStater)

if ok && con.ConnectionState() != nil {
hostName = con.ConnectionState().ServerName
}

return newRequest(clientIP, protocol, extractClientIDFromHost(hostName), request)
}

func extractClientIDFromHost(hostName string) string {
const clientIDPrefix = "id-"
if strings.HasPrefix(hostName, clientIDPrefix) && strings.Contains(hostName, ".") {
Expand All @@ -578,12 +561,13 @@ func extractClientIDFromHost(hostName string) string {
return ""
}

func newRequest(clientIP net.IP, protocol model.RequestProtocol,
requestClientID string, request *dns.Msg,
func newRequest(
clientIP net.IP, clientID string,
protocol model.RequestProtocol, request *dns.Msg,
) *model.Request {
return &model.Request{
ClientIP: clientIP,
RequestClientID: requestClientID,
RequestClientID: clientID,
Protocol: protocol,
Req: request,
Log: log.Log().WithFields(logrus.Fields{
Expand All @@ -595,13 +579,23 @@ func newRequest(clientIP net.IP, protocol model.RequestProtocol,
}

// OnRequest will be executed if a new DNS request is received
func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) {
func (s *Server) OnRequest(
ctx context.Context, w dns.ResponseWriter,
clientIP net.IP, protocol model.RequestProtocol,
request *dns.Msg,
) {
logger().Debug("new request")

r := createResolverRequest(w, request)
var hostName string

con, ok := w.(dns.ConnectionStater)
if ok && con.ConnectionState() != nil {
hostName = con.ConnectionState().ServerName
}

response, err := s.queryResolver.Resolve(ctx, r)
req := newRequest(clientIP, extractClientIDFromHost(hostName), protocol, request)

response, err := s.resolve(ctx, req)
if err != nil {
logger().Error("error on processing request:", err)

Expand All @@ -610,17 +604,40 @@ func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *d
err := w.WriteMsg(m)
util.LogOnError("can't write message: ", err)
} else {
response.Res.MsgHdr.RecursionAvailable = request.MsgHdr.RecursionDesired
err := w.WriteMsg(response.Res)
util.LogOnError("can't write message: ", err)
}
}

// truncate if necessary
response.Res.Truncate(getMaxResponseSize(r))
func (s *Server) resolve(ctx context.Context, request *model.Request) (*model.Response, error) {
var response *model.Response

// enable compression
response.Res.Compress = true
switch {
case len(request.Req.Question) == 0:
m := new(dns.Msg)
m.SetRcode(request.Req, dns.RcodeFormatError)

err := w.WriteMsg(response.Res)
util.LogOnError("can't write message: ", err)
request.Log.Error("query has no questions")

response = &model.Response{Res: m, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"}
default:
var err error

response, err = s.queryResolver.Resolve(ctx, request)
if err != nil {
return nil, err
}
}

response.Res.MsgHdr.RecursionAvailable = request.Req.MsgHdr.RecursionDesired

// truncate if necessary
response.Res.Truncate(getMaxResponseSize(request))

// enable compression
response.Res.Compress = true

return response, nil
}

// returns EDNS UDP size or if not present, 512 for UDP and 64K for TCP
Expand Down
8 changes: 4 additions & 4 deletions server/server_endpoints.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,9 +147,9 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h
clientID = extractClientIDFromHost(req.Host)
}

r := newRequest(util.HTTPClientIP(req), model.RequestProtocolTCP, clientID, msg)
r := newRequest(util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg)

resResponse, err := s.queryResolver.Resolve(req.Context(), r)
resResponse, err := s.resolve(req.Context(), r)
if err != nil {
logAndResponseWithError(err, "unable to process query: ", rw)

Expand All @@ -173,9 +173,9 @@ func (s *Server) Query(
ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type,
) (*model.Response, error) {
dnsRequest := util.NewMsgWithQuestion(question, qType)
r := newRequest(clientIP, model.RequestProtocolTCP, extractClientIDFromHost(serverHost), dnsRequest)
r := newRequest(clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest)

return s.queryResolver.Resolve(ctx, r)
return s.resolve(ctx, r)
}

func createHTTPSRouter(cfg *config.Config) *chi.Mux {
Expand Down

0 comments on commit f0ad412

Please sign in to comment.