diff --git a/server/server.go b/server/server.go index 3ff0ece08..194d5bb4d 100644 --- a/server/server.go +++ b/server/server.go @@ -417,7 +417,9 @@ func createQueryResolver( func (s *Server) registerDNSHandlers(ctx context.Context) { wrappedOnRequest := func(w dns.ResponseWriter, request *dns.Msg) { - s.OnRequest(ctx, w, request) + ip, proto := resolveClientIPAndProtocol(w.RemoteAddr()) + + s.OnRequest(ctx, w, ip, proto, request) } for _, server := range s.dnsServers { @@ -550,25 +552,6 @@ func (s *Server) Stop(ctx context.Context) error { return nil } -func createResolverRequest(rw dns.ResponseWriter, request *dns.Msg) *model.Request { - var hostName string - - var remoteAddr net.Addr - - if rw != nil { - remoteAddr = rw.RemoteAddr() - } - - clientIP, protocol := resolveClientIPAndProtocol(remoteAddr) - con, ok := rw.(dns.ConnectionStater) - - if ok && con.ConnectionState() != nil { - hostName = con.ConnectionState().ServerName - } - - return newRequest(clientIP, protocol, extractClientIDFromHost(hostName), request) -} - func extractClientIDFromHost(hostName string) string { const clientIDPrefix = "id-" if strings.HasPrefix(hostName, clientIDPrefix) && strings.Contains(hostName, ".") { @@ -578,12 +561,13 @@ func extractClientIDFromHost(hostName string) string { return "" } -func newRequest(clientIP net.IP, protocol model.RequestProtocol, - requestClientID string, request *dns.Msg, +func newRequest( + clientIP net.IP, clientID string, + protocol model.RequestProtocol, request *dns.Msg, ) *model.Request { return &model.Request{ ClientIP: clientIP, - RequestClientID: requestClientID, + RequestClientID: clientID, Protocol: protocol, Req: request, Log: log.Log().WithFields(logrus.Fields{ @@ -595,13 +579,23 @@ func newRequest(clientIP net.IP, protocol model.RequestProtocol, } // OnRequest will be executed if a new DNS request is received -func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *dns.Msg) { +func (s *Server) OnRequest( + ctx context.Context, w dns.ResponseWriter, + clientIP net.IP, protocol model.RequestProtocol, + request *dns.Msg, +) { logger().Debug("new request") - r := createResolverRequest(w, request) + var hostName string + + con, ok := w.(dns.ConnectionStater) + if ok && con.ConnectionState() != nil { + hostName = con.ConnectionState().ServerName + } - response, err := s.queryResolver.Resolve(ctx, r) + req := newRequest(clientIP, extractClientIDFromHost(hostName), protocol, request) + response, err := s.resolve(ctx, req) if err != nil { logger().Error("error on processing request:", err) @@ -610,17 +604,40 @@ func (s *Server) OnRequest(ctx context.Context, w dns.ResponseWriter, request *d err := w.WriteMsg(m) util.LogOnError("can't write message: ", err) } else { - response.Res.MsgHdr.RecursionAvailable = request.MsgHdr.RecursionDesired + err := w.WriteMsg(response.Res) + util.LogOnError("can't write message: ", err) + } +} - // truncate if necessary - response.Res.Truncate(getMaxResponseSize(r)) +func (s *Server) resolve(ctx context.Context, request *model.Request) (*model.Response, error) { + var response *model.Response - // enable compression - response.Res.Compress = true + switch { + case len(request.Req.Question) == 0: + m := new(dns.Msg) + m.SetRcode(request.Req, dns.RcodeFormatError) - err := w.WriteMsg(response.Res) - util.LogOnError("can't write message: ", err) + request.Log.Error("query has no questions") + + response = &model.Response{Res: m, RType: model.ResponseTypeCUSTOMDNS, Reason: "CUSTOM DNS"} + default: + var err error + + response, err = s.queryResolver.Resolve(ctx, request) + if err != nil { + return nil, err + } } + + response.Res.MsgHdr.RecursionAvailable = request.Req.MsgHdr.RecursionDesired + + // truncate if necessary + response.Res.Truncate(getMaxResponseSize(request)) + + // enable compression + response.Res.Compress = true + + return response, nil } // returns EDNS UDP size or if not present, 512 for UDP and 64K for TCP diff --git a/server/server_endpoints.go b/server/server_endpoints.go index cca9ba60c..55cc06f95 100644 --- a/server/server_endpoints.go +++ b/server/server_endpoints.go @@ -147,9 +147,9 @@ func (s *Server) processDohMessage(rawMsg []byte, rw http.ResponseWriter, req *h clientID = extractClientIDFromHost(req.Host) } - r := newRequest(util.HTTPClientIP(req), model.RequestProtocolTCP, clientID, msg) + r := newRequest(util.HTTPClientIP(req), clientID, model.RequestProtocolTCP, msg) - resResponse, err := s.queryResolver.Resolve(req.Context(), r) + resResponse, err := s.resolve(req.Context(), r) if err != nil { logAndResponseWithError(err, "unable to process query: ", rw) @@ -173,9 +173,9 @@ func (s *Server) Query( ctx context.Context, serverHost string, clientIP net.IP, question string, qType dns.Type, ) (*model.Response, error) { dnsRequest := util.NewMsgWithQuestion(question, qType) - r := newRequest(clientIP, model.RequestProtocolTCP, extractClientIDFromHost(serverHost), dnsRequest) + r := newRequest(clientIP, extractClientIDFromHost(serverHost), model.RequestProtocolTCP, dnsRequest) - return s.queryResolver.Resolve(ctx, r) + return s.resolve(ctx, r) } func createHTTPSRouter(cfg *config.Config) *chi.Mux {