Skip to content

Commit

Permalink
fix(settings): validation for DoH and DoT providers
Browse files Browse the repository at this point in the history
  • Loading branch information
qdm12 committed Jun 24, 2022
1 parent 94e5502 commit 3a2fbd7
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 35 deletions.
18 changes: 3 additions & 15 deletions internal/config/settings/doh.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package settings

import (
"errors"
"fmt"
"time"

Expand Down Expand Up @@ -30,21 +29,10 @@ func (d *DoH) setDefaults() {
d.Self.setDefaults()
}

var (
ErrDoHProviderNotValid = errors.New("DoH provider is not valid")
)

func (d *DoH) validate() (err error) {
allProviders := provider.All()
allProvidersSet := make(map[string]struct{}, len(allProviders))
for _, provider := range allProviders {
allProvidersSet[provider.Name] = struct{}{}
}

for _, provider := range d.DoHProviders {
if _, ok := allProvidersSet[provider]; !ok {
return fmt.Errorf("%w: %s", ErrDoHProviderNotValid, provider)
}
err = checkProviderNames(d.DoHProviders)
if err != nil {
return fmt.Errorf("DoH provider: %w", err)
}

const minTimeout = time.Millisecond
Expand Down
20 changes: 7 additions & 13 deletions internal/config/settings/dot.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,24 +43,18 @@ func (d *DoT) setDefaults() {
}

var (
ErrTimeoutTooSmall = errors.New("timeout is too small")
ErrDoTProviderNotValid = errors.New("DoT provider is not valid")
ErrDNSProviderNotValid = errors.New("plaintext DNS provider is not valid")
ErrTimeoutTooSmall = errors.New("timeout is too small")
)

func (d *DoT) validate() (err error) {
allProvidersSet := allProvidersStringSet()

for _, provider := range d.DoTProviders {
if _, ok := allProvidersSet[provider]; !ok {
return fmt.Errorf("%w: %s", ErrDoTProviderNotValid, provider)
}
err = checkProviderNames(d.DoTProviders)
if err != nil {
return fmt.Errorf("DoT provider: %w", err)
}

for _, provider := range d.DNSProviders {
if _, ok := allProvidersSet[provider]; !ok {
return fmt.Errorf("%w: %s", ErrDNSProviderNotValid, provider)
}
err = checkProviderNames(d.DNSProviders)
if err != nil {
return fmt.Errorf("fallback DNS plaintext provider: %w", err)
}

const minTimeout = time.Millisecond
Expand Down
10 changes: 9 additions & 1 deletion internal/config/settings/helpers.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
package settings

func andStrings(strings []string) (result string) {
return joinStrings(strings, "and")
}

func orStrings(strings []string) (result string) {
return joinStrings(strings, "or")
}

func joinStrings(strings []string, lastJoin string) (result string) {
if len(strings) == 0 {
return ""
}
Expand All @@ -10,7 +18,7 @@ func andStrings(strings []string) (result string) {
if i < len(strings)-1 {
result += strings[i] + ", "
} else {
result += " and " + strings[i]
result += " " + lastJoin + " " + strings[i]
}
}

Expand Down
27 changes: 21 additions & 6 deletions internal/config/settings/validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,26 @@ func checkListeningAddress(address string) (err error) {
return err
}

func allProvidersStringSet() (set map[string]struct{}) {
providers := provider.All()
set = make(map[string]struct{}, len(providers))
for _, provider := range providers {
set[provider.Name] = struct{}{}
func checkProviderNames(providerNames []string) (err error) {
allProviders := provider.All()
allProviderNames := make([]string, len(allProviders))
for i, provider := range allProviders {
allProviderNames[i] = provider.Name
}
return set

for _, providerName := range providerNames {
valid := false
for _, acceptedProviderName := range allProviderNames {
if strings.EqualFold(providerName, acceptedProviderName) {
valid = true
break
}
}
if !valid {
return fmt.Errorf("%w: %q must be one of: %s",
ErrValueNotOneOf, providerName, orStrings(allProviderNames))
}
}

return nil
}

0 comments on commit 3a2fbd7

Please sign in to comment.