Skip to content

Commit

Permalink
chore(all): add missing settings validation in New* constructors
Browse files Browse the repository at this point in the history
- Noop constructors are excluded
- pkg/blockbuilder
- pkg/doh/metrics/prometheus
- pkg/dot/metrics/prometheus
- pkg/middleware/metrics
- pkg/middlewares/cache/lru
- pkg/middlewares/cache/metrics/prometheus
- pkg/middlewares/filter/mapfilter
- pkg/middlewares/filter/metrics/prometheus
- pkg/middlewares/log
- pkg/middlewares/log/logger/console
- pkg/middlewares/metrics/prometheus
  • Loading branch information
qdm12 committed Nov 15, 2023
1 parent 366dd9b commit e07c215
Show file tree
Hide file tree
Showing 25 changed files with 122 additions and 28 deletions.
12 changes: 10 additions & 2 deletions cmd/dns/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,13 +134,21 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, //nolint:cycl
dnsLogger := logger.New(log.SetComponent("DNS server loop"))
const clientTimeout = 15 * time.Second
client := &http.Client{Timeout: clientTimeout}
blockBuilder := setup.BuildBlockBuilder(settings.Block, client)
blockBuilder, err := setup.BuildBlockBuilder(settings.Block, client)
if err != nil {
return fmt.Errorf("block builder: %w", err)
}

prometheusRegistry := prometheus.NewRegistry()
cacheMetrics, err := setup.BuildCacheMetrics(settings.Metrics, prometheusRegistry)
if err != nil {
return fmt.Errorf("cache metrics: %w", err)
}
cache := setup.BuildCache(settings.Cache, cacheMetrics) // share the same cache across DNS server restarts
cache, err := setup.BuildCache(settings.Cache, cacheMetrics) // share the same cache across DNS server restarts
if err != nil {
return fmt.Errorf("cache: %w", err)
}

dnsLoop, err := dns.New(settings, dnsLogger, blockBuilder, cache, prometheusRegistry)
if err != nil {
return fmt.Errorf("creating DNS loop: %w", err)
Expand Down
6 changes: 5 additions & 1 deletion examples/doh-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ func main() {

logger := new(Logger)

cacheMiddleware := cachemiddleware.New(lru.New(lru.Settings{}))
cache, err := lru.New(lru.Settings{})
if err != nil {
log.Fatal(err)
}
cacheMiddleware := cachemiddleware.New(cache)

server, err := doh.NewServer(doh.ServerSettings{
Middlewares: []doh.Middleware{cacheMiddleware},
Expand Down
6 changes: 5 additions & 1 deletion examples/dot-server/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,11 @@ func main() {

logger := new(Logger)

cacheMiddleware := cachemiddleware.New(lru.New(lru.Settings{}))
cache, err := lru.New(lru.Settings{})
if err != nil {
log.Fatal(err)
}
cacheMiddleware := cachemiddleware.New(cache)

server, err := dot.NewServer(dot.ServerSettings{
Middlewares: []dot.Middleware{cacheMiddleware},
Expand Down
2 changes: 1 addition & 1 deletion internal/setup/block.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ import (
)

func BuildBlockBuilder(userSettings settings.Block,
client *http.Client) (blockBuilder *blockbuilder.Builder) {
client *http.Client) (blockBuilder *blockbuilder.Builder, err error) {
settings := blockbuilder.Settings{
Client: client,
BlockMalicious: userSettings.BlockMalicious,
Expand Down
4 changes: 2 additions & 2 deletions internal/setup/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,10 +18,10 @@ type Cache interface {
}

func BuildCache(userSettings settings.Cache, //nolint:ireturn
metrics CacheMetrics) (cache Cache) {
metrics CacheMetrics) (cache Cache, err error) {
switch userSettings.Type {
case noop.CacheType:
return noop.New(noop.Settings{Metrics: metrics})
return noop.New(noop.Settings{Metrics: metrics}), nil
case lru.CacheType:
return lru.New(lru.Settings{
MaxEntries: userSettings.LRU.MaxEntries,
Expand Down
10 changes: 7 additions & 3 deletions internal/setup/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ func logMiddleware(userSettings settings.MiddlewareLog) (middleware *log.Middlew
settings := log.Settings{
Logger: noop.New(),
}
return log.New(settings), nil
return log.New(settings)
}

const dirPerm = os.FileMode(0744)
Expand All @@ -38,12 +38,16 @@ func logMiddleware(userSettings settings.MiddlewareLog) (middleware *log.Middlew
LogRequests: boolPtr(*userSettings.LogRequests),
LogResponses: boolPtr(*userSettings.LogResponses),
}
middlewareLogger := console.New(middlewareLoggerSettings)
middlewareLogger, err := console.New(middlewareLoggerSettings)
if err != nil {
return nil, fmt.Errorf("creating logger: %w", err)
}

settings := log.Settings{
Logger: middlewareLogger,
}

return log.New(settings), nil
return log.New(settings)
}

func boolPtr(b bool) *bool { return &b }
2 changes: 1 addition & 1 deletion internal/setup/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,5 +39,5 @@ func middlewareMetrics(metricsType string,
settings := metricsmiddleware.Settings{
Metrics: metrics,
}
return metricsmiddleware.New(settings), nil
return metricsmiddleware.New(settings)
}
4 changes: 3 additions & 1 deletion pkg/blockbuilder/all_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_Builder_BuildAll(t *testing.T) { //nolint:cyclop,maintidx
Expand Down Expand Up @@ -253,7 +254,8 @@ func Test_Builder_BuildAll(t *testing.T) { //nolint:cyclop,maintidx
}),
}

builder := New(tc.settings)
builder, err := New(tc.settings)
require.NoError(t, err)

result := builder.BuildAll(ctx)

Expand Down
10 changes: 8 additions & 2 deletions pkg/blockbuilder/builder.go
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
package blockbuilder

import (
"fmt"
"net/http"
"net/netip"
)

func New(settings Settings) *Builder {
func New(settings Settings) (builder *Builder, err error) {
settings.SetDefaults()

err = settings.Validate()
if err != nil {
return nil, fmt.Errorf("settings validation: %w", err)
}

return &Builder{
client: settings.Client,
blockMalicious: *settings.BlockMalicious,
Expand All @@ -20,7 +26,7 @@ func New(settings Settings) *Builder {
addBlockedIPs: settings.AddBlockedIPs,
addBlockedIPPrefixes: settings.AddBlockedIPPrefixes,
// TODO cache blocked IPs and hostnames after first request?
}
}, nil
}

type Builder struct {
Expand Down
4 changes: 3 additions & 1 deletion pkg/blockbuilder/hostnames_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_Builder_Hostnames(t *testing.T) { //nolint:cyclop
Expand Down Expand Up @@ -156,7 +157,8 @@ func Test_Builder_Hostnames(t *testing.T) { //nolint:cyclop
}

settings := Settings{Client: client}
builder := New(settings)
builder, err := New(settings)
require.NoError(t, err)

blockedHostnames, errs := builder.buildHostnames(ctx,
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
Expand Down
4 changes: 3 additions & 1 deletion pkg/blockbuilder/ips_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_Builder_IPs(t *testing.T) { //nolint:cyclop
Expand Down Expand Up @@ -154,7 +155,8 @@ func Test_Builder_IPs(t *testing.T) { //nolint:cyclop
}

settings := Settings{Client: client}
builder := New(settings)
builder, err := New(settings)
require.NoError(t, err)

blockedIPs, blockedIPPrefixes, errs := builder.buildIPs(ctx,
tc.malicious.blocked, tc.ads.blocked, tc.surveillance.blocked,
Expand Down
3 changes: 2 additions & 1 deletion pkg/doh/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -258,11 +258,12 @@ func Test_Server_Mocks(t *testing.T) {
middlewareMetrics.EXPECT().AnswersInc("IN", "A")
middlewareMetrics.EXPECT().AnswersInc("IN", "AAAA")

metricsMiddleware := metricsmiddleware.New(
metricsMiddleware, err := metricsmiddleware.New(
metricsmiddleware.Settings{
Metrics: middlewareMetrics,
},
)
require.NoError(t, err)

server, err := NewServer(ServerSettings{
Logger: logger,
Expand Down
5 changes: 5 additions & 0 deletions pkg/doh/metrics/prometheus/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,11 @@ type Metrics struct {
func New(settings Settings) (metrics *Metrics, err error) {
settings.SetDefaults()

err = settings.Validate()
if err != nil {
return nil, fmt.Errorf("settings validation: %w", err)
}

metrics = new(Metrics)

metrics.counters, err = newCounters(settings.Prometheus)
Expand Down
3 changes: 2 additions & 1 deletion pkg/dot/integration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -253,11 +253,12 @@ func Test_Server_Mocks(t *testing.T) {
metrics.EXPECT().AnswersInc("IN", "A")
metrics.EXPECT().AnswersInc("IN", "AAAA")

metricsMiddleware := metricsmiddleware.New(
metricsMiddleware, err := metricsmiddleware.New(
metricsmiddleware.Settings{
Metrics: metrics,
},
)
require.NoError(t, err)

server, err := NewServer(ServerSettings{
Logger: logger,
Expand Down
5 changes: 5 additions & 0 deletions pkg/dot/metrics/prometheus/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ type Metrics struct {
func New(settings Settings) (metrics *Metrics, err error) {
settings.SetDefaults()

err = settings.Validate()
if err != nil {
return nil, fmt.Errorf("settings validation: %w", err)
}

metrics = new(Metrics)

metrics.counters, err = newCounters(settings.Prometheus)
Expand Down
10 changes: 8 additions & 2 deletions pkg/middlewares/cache/lru/lru.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package lru

import (
"container/list"
"fmt"
"sync"
"time"

Expand All @@ -24,9 +25,14 @@ type LRU struct {
timeNow func() time.Time
}

func New(settings Settings) *LRU {
func New(settings Settings) (cache *LRU, err error) {
settings.SetDefaults()

err = settings.Validate()
if err != nil {
return nil, fmt.Errorf("settings validation: %w", err)
}

settings.Metrics.SetCacheType(CacheType)
settings.Metrics.CacheMaxEntriesSet(settings.MaxEntries)

Expand All @@ -36,7 +42,7 @@ func New(settings Settings) *LRU {
linkedList: list.New(),
metrics: settings.Metrics,
timeNow: time.Now,
}
}, nil
}

func (l *LRU) Add(request, response *dns.Msg) {
Expand Down
4 changes: 3 additions & 1 deletion pkg/middlewares/cache/lru/lru_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func newTestMsgs(name string, expUnix uint32) (request, response *dns.Msg) {
Expand Down Expand Up @@ -43,7 +44,8 @@ func Test_lru_e2e(t *testing.T) {

metrics.EXPECT().SetCacheType(CacheType)
metrics.EXPECT().CacheMaxEntriesSet(settings.MaxEntries)
lru := New(settings)
lru, err := New(settings)
require.NoError(t, err)

metrics.EXPECT().CacheInsertInc()
lru.Add(requestA, responseA)
Expand Down
5 changes: 5 additions & 0 deletions pkg/middlewares/cache/metrics/prometheus/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,11 @@ type Metrics struct {
func New(settings Settings) (metrics *Metrics, err error) {
settings.SetDefaults()

err = settings.Validate()
if err != nil {
return nil, fmt.Errorf("settings validation: %w", err)
}

metrics = new(Metrics)

metrics.counters, err = newCounters(settings.Prometheus)
Expand Down
6 changes: 6 additions & 0 deletions pkg/middlewares/filter/mapfilter/filter.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package mapfilter

import (
"fmt"
"net/netip"
"sync"
)
Expand All @@ -16,6 +17,11 @@ type Filter struct {
func New(settings Settings) (filter *Filter, err error) {
settings.SetDefaults()

err = settings.Validate()
if err != nil {
return nil, fmt.Errorf("settings validation: %w", err)
}

filter = &Filter{
metrics: settings.Metrics,
}
Expand Down
5 changes: 5 additions & 0 deletions pkg/middlewares/filter/metrics/prometheus/metrics.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,11 @@ type Metrics struct {
func New(settings Settings) (metrics *Metrics, err error) {
settings.SetDefaults()

err = settings.Validate()
if err != nil {
return nil, fmt.Errorf("settings validation: %w", err)
}

metrics = new(Metrics)

metrics.counters, err = newCounters(settings.Prometheus)
Expand Down
11 changes: 9 additions & 2 deletions pkg/middlewares/log/log.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
package log

import (
"fmt"
"net"

"github.com/miekg/dns"
Expand All @@ -13,11 +14,17 @@ type Middleware struct {
logger Logger
}

func New(settings Settings) *Middleware {
func New(settings Settings) (middleware *Middleware, err error) {
settings.SetDefaults()

err = settings.Validate()
if err != nil {
return nil, fmt.Errorf("settings validation: %w", err)
}

return &Middleware{
logger: settings.Logger,
}
}, nil
}

// Wrap wraps the DNS handler with the middleware.
Expand Down
4 changes: 3 additions & 1 deletion pkg/middlewares/log/log_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/golang/mock/gomock"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)

func Test_New(t *testing.T) {
Expand All @@ -29,7 +30,8 @@ func Test_New(t *testing.T) {
Logger: logger,
}

middleware := New(settings)
middleware, err := New(settings)
require.NoError(t, err)

next := dns.HandlerFunc(func(rw dns.ResponseWriter, m *dns.Msg) {})
handler := middleware.Wrap(next)
Expand Down
Loading

0 comments on commit e07c215

Please sign in to comment.