Skip to content

Commit

Permalink
feat: Support defining records by dns zone format (#1360)
Browse files Browse the repository at this point in the history
* feat: Support zonefile configuration for custom dns mapping

* docs: Update configuration.md

* Rename var to ok

* Linter fixes

* Remove hashes in test describe description

* Implement PR comments; zoneFileMapping -> zone, initialize with proper sizes

* Remove custom CNAME parsing

* Utilize TTL defined in zone file

* Link to wikipedia's example file

* Test to confirm that a relative zone entry without an $ORIGIN returns an error

* Write a test covering the $INCLUDE directive

* Write a test confirming that a dns zone can result in more than 1 RR

* Linting

* fix: Use proper matchers in CustomDNS Zone tests; Update configuration.md description

* Pull in config directory to support relative $INCLUDE

* Added tests to ensure the ability to use both bare filenames as well as relative filenames when using the $INCLUDE directive

* Shorten test description (Linting error)

* Move Assignment of z.RRs to the end of the UnmarshallYAML function

* Moved tests for relative $INCLUDE zones to config_test. Added test case when config param passed to blocky is a directory

* Corrected test case to _actually_ test againt bare file names
  • Loading branch information
BenMcH authored Feb 9, 2024
1 parent 178dbb7 commit 9f633f1
Show file tree
Hide file tree
Showing 8 changed files with 359 additions and 81 deletions.
11 changes: 10 additions & 1 deletion config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -455,21 +455,30 @@ func loadConfig(logger *logrus.Entry, path string, mandatory bool) (rCfg *Config
return nil, fmt.Errorf("can't read config file(s): %w", err)
}

var data []byte
var (
data []byte
prettyPath string
)

if fs.IsDir() {
prettyPath = filepath.Join(path, "*")

data, err = readFromDir(path, data)

if err != nil {
return nil, fmt.Errorf("can't read config files: %w", err)
}
} else {
prettyPath = path

data, err = os.ReadFile(path)
if err != nil {
return nil, fmt.Errorf("can't read config file: %w", err)
}
}

cfg.CustomDNS.Zone.configPath = prettyPath

err = unmarshalConfig(logger, data, &cfg)
if err != nil {
return nil, err
Expand Down
115 changes: 115 additions & 0 deletions config/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,94 @@ var _ = Describe("Config", func() {
defaultTestFileConfig(c)
})
})
When("Test config file contains a zone file with $INCLUDE", func() {
When("The config path is set to the config file", func() {
It("Should support the $INCLUDE directive with a bare filename", func() {
folder := helpertest.NewTmpFolder("zones")
folder.CreateStringFile("other.zone", "www 3600 A 1.2.3.4")
cfgFile := writeConfigYmlWithLocalZoneFile(folder, "other.zone")

c, err = LoadConfig(cfgFile.Path, true)

Expect(err).Should(Succeed())
Expect(c.CustomDNS.Zone.RRs).Should(HaveLen(1))

Expect(c.CustomDNS.Zone.RRs["www.example.com."]).
Should(SatisfyAll(
HaveLen(1),
ContainElements(
SatisfyAll(
helpertest.BeDNSRecord("www.example.com.", helpertest.A, "1.2.3.4"),
helpertest.HaveTTL(BeNumerically("==", 3600)),
)),
))
})
It("Should support the $INCLUDE directive with a relative filename", func() {
folder := helpertest.NewTmpFolder("zones")
folder.CreateStringFile("other.zone", "www 3600 A 1.2.3.4")
cfgFile := writeConfigYmlWithLocalZoneFile(folder, "./other.zone")

c, err = LoadConfig(cfgFile.Path, true)

Expect(err).Should(Succeed())
Expect(c.CustomDNS.Zone.RRs).Should(HaveLen(1))

Expect(c.CustomDNS.Zone.RRs["www.example.com."]).
Should(SatisfyAll(

HaveLen(1),
ContainElements(
SatisfyAll(
helpertest.BeDNSRecord("www.example.com.", helpertest.A, "1.2.3.4"),
helpertest.HaveTTL(BeNumerically("==", 3600)),
)),
))
})
})
When("The config path is set to a directory", func() {
It("Should support the $INCLUDE directive with a bare filename", func() {
folder := helpertest.NewTmpFolder("zones")
folder.CreateStringFile("other.zone", "www 3600 A 1.2.3.4")
writeConfigYmlWithLocalZoneFile(folder, "other.zone")

c, err = LoadConfig(folder.Path, true)

Expect(err).Should(Succeed())
Expect(c.CustomDNS.Zone.RRs).Should(HaveLen(1))

Expect(c.CustomDNS.Zone.RRs["www.example.com."]).
Should(SatisfyAll(
HaveLen(1),
ContainElements(
SatisfyAll(
helpertest.BeDNSRecord("www.example.com.", helpertest.A, "1.2.3.4"),
helpertest.HaveTTL(BeNumerically("==", 3600)),
)),
))
})
It("Should support the $INCLUDE directive with a relative filename", func() {
folder := helpertest.NewTmpFolder("zones")
folder.CreateStringFile("other.zone", "www 3600 A 1.2.3.4")
writeConfigYmlWithLocalZoneFile(folder, "./other.zone")

c, err = LoadConfig(folder.Path, true)

Expect(err).Should(Succeed())
Expect(c.CustomDNS.Zone.RRs).Should(HaveLen(1))

Expect(c.CustomDNS.Zone.RRs["www.example.com."]).
Should(SatisfyAll(

HaveLen(1),
ContainElements(
SatisfyAll(
helpertest.BeDNSRecord("www.example.com.", helpertest.A, "1.2.3.4"),
helpertest.HaveTTL(BeNumerically("==", 3600)),
)),
))
})
})
})
When("Test file does not exist", func() {
It("should fail", func() {
_, err := LoadConfig(tmpDir.JoinPath("config-does-not-exist.yaml"), true)
Expand Down Expand Up @@ -977,6 +1065,33 @@ func writeConfigYml(tmpDir *helpertest.TmpFolder) *helpertest.TmpFile {
)
}

func writeConfigYmlWithLocalZoneFile(tmpDir *helpertest.TmpFolder, includeStr string) *helpertest.TmpFile {
return tmpDir.CreateStringFile("config.yml",
"upstreams:",
" userAgent: testBlocky",
" init:",
" strategy: failOnError",
" groups:",
" default:",
" - tcp+udp:8.8.8.8",
" - tcp+udp:8.8.4.4",
" - 1.1.1.1",
"customDNS:",
" zone: |",
" $ORIGIN example.com.",
" $INCLUDE "+includeStr,
"filtering:",
" queryTypes:",
" - AAAA",
" - A",
"fqdnOnly:",
" enable: true",
"port: 55553,:55554,[::1]:55555",
"logLevel: debug",
"minTlsServeVersion: 1.3",
)
}

func writeConfigDir(tmpDir *helpertest.TmpFolder) {
tmpDir.CreateStringFile("config1.yaml",
"upstreams:",
Expand Down
95 changes: 52 additions & 43 deletions config/custom_dns.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,14 +14,57 @@ type CustomDNS struct {
RewriterConfig `yaml:",inline"`
CustomTTL Duration `yaml:"customTTL" default:"1h"`
Mapping CustomDNSMapping `yaml:"mapping"`
Zone ZoneFileDNS `yaml:"zone" default:""`
FilterUnmappedTypes bool `yaml:"filterUnmappedTypes" default:"true"`
}

type (
CustomDNSMapping map[string]CustomDNSEntries
CustomDNSEntries []dns.RR

ZoneFileDNS struct {
RRs CustomDNSMapping
configPath string
}
)

func (z *ZoneFileDNS) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input string
if err := unmarshal(&input); err != nil {
return err
}

result := make(CustomDNSMapping)

zoneParser := dns.NewZoneParser(strings.NewReader(input), "", z.configPath)
zoneParser.SetIncludeAllowed(true)

for {
zoneRR, ok := zoneParser.Next()

if !ok {
if zoneParser.Err() != nil {
return zoneParser.Err()
}

// Done
break
}

domain := zoneRR.Header().Name

if _, ok := result[domain]; !ok {
result[domain] = make(CustomDNSEntries, 0, 1)
}

result[domain] = append(result[domain], zoneRR)
}

z.RRs = result

return nil
}

func (c *CustomDNSEntries) UnmarshalYAML(unmarshal func(interface{}) error) error {
var input string
if err := unmarshal(&input); err != nil {
Expand All @@ -30,24 +73,16 @@ func (c *CustomDNSEntries) UnmarshalYAML(unmarshal func(interface{}) error) erro

parts := strings.Split(input, ",")
result := make(CustomDNSEntries, len(parts))
containsCNAME := false

for i, part := range parts {
rr, err := configToRR(part)
if err != nil {
return err
}

_, isCNAME := rr.(*dns.CNAME)
containsCNAME = containsCNAME || isCNAME

result[i] = rr
}

if containsCNAME && len(result) > 1 {
return fmt.Errorf("when a CNAME record is present, it must be the only record in the mapping")
}

*c = result

return nil
Expand All @@ -70,47 +105,21 @@ func (c *CustomDNS) LogConfig(logger *logrus.Entry) {
}
}

func removePrefixSuffix(in, prefix string) string {
in = strings.TrimPrefix(in, fmt.Sprintf("%s(", prefix))
in = strings.TrimSuffix(in, ")")

return strings.TrimSpace(in)
}

func configToRR(part string) (dns.RR, error) {
if strings.HasPrefix(part, "CNAME(") {
domain := removePrefixSuffix(part, "CNAME")
domain = dns.Fqdn(domain)
cname := &dns.CNAME{Target: domain}

return cname, nil
func configToRR(ipStr string) (dns.RR, error) {
ip := net.ParseIP(ipStr)
if ip == nil {
return nil, fmt.Errorf("invalid IP address '%s'", ipStr)
}

// Fall back to A/AAAA records to maintain backwards compatibility in config.yml
// We will still remove the A() or AAAA() if it exists
if strings.Contains(part, ".") { // IPV4 address
ipStr := removePrefixSuffix(part, "A")
ip := net.ParseIP(ipStr)

if ip == nil {
return nil, fmt.Errorf("invalid IP address '%s'", part)
}

if ip.To4() != nil {
a := new(dns.A)
a.A = ip

return a, nil
} else { // IPV6 address
ipStr := removePrefixSuffix(part, "AAAA")
ip := net.ParseIP(ipStr)

if ip == nil {
return nil, fmt.Errorf("invalid IP address '%s'", part)
}
}

aaaa := new(dns.AAAA)
aaaa.AAAA = ip
aaaa := new(dns.AAAA)
aaaa.AAAA = ip

return aaaa, nil
}
return aaaa, nil
}
Loading

0 comments on commit 9f633f1

Please sign in to comment.