Skip to content

Commit 3afa0e9

Browse files
committed
Fix DNS cache
1 parent c22747d commit 3afa0e9

File tree

3 files changed

+49
-26
lines changed

3 files changed

+49
-26
lines changed

internal/dns/cache.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,9 @@ func NewDNSCache(forwardTTL int, reverseTTL int, backend store.Interface) *DnsCa
3232
}
3333

3434
func (d *Dns) getCachedForward(host string) ([]string, bool) {
35+
if d.cache == nil {
36+
return nil, false
37+
}
3538
if cached, err := d.cache.forward.Get(d.ctx, host); err == nil {
3639
slog.Debug("DNS: forward cache hit", "name", host, "ips", cached)
3740
return cached, true
@@ -41,6 +44,9 @@ func (d *Dns) getCachedForward(host string) ([]string, bool) {
4144
}
4245

4346
func (d *Dns) getCachedReverse(addr string) ([]string, bool) {
47+
if d.cache == nil {
48+
return nil, false
49+
}
4450
if cached, err := d.cache.reverse.Get(d.ctx, addr); err == nil {
4551
slog.Debug("DNS: reverse cache hit", "addr", addr, "names", cached)
4652
return cached, true
@@ -50,9 +56,15 @@ func (d *Dns) getCachedReverse(addr string) ([]string, bool) {
5056
}
5157

5258
func (d *Dns) forwardCachePut(host string, entries []string) {
59+
if d.cache == nil {
60+
return
61+
}
5362
d.cache.forward.Set(d.ctx, host, entries, d.cache.forwardTTL)
5463
}
5564

5665
func (d *Dns) reverseCachePut(addr string, entries []string) {
66+
if d.cache == nil {
67+
return
68+
}
5769
d.cache.reverse.Set(d.ctx, addr, entries, d.cache.reverseTTL)
5870
}

lib/policy/expressions/environment_test.go

Lines changed: 20 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,18 +1,28 @@
11
package expressions
22

33
import (
4+
"context"
45
"errors"
56
"net"
67
"strings"
78
"testing"
89

9-
"github.com/TecharoHQ/anubis/internal"
10+
"github.com/TecharoHQ/anubis/internal/dns"
11+
"github.com/TecharoHQ/anubis/lib/store/memory"
1012
"github.com/google/cel-go/common/types"
1113
"github.com/google/cel-go/common/types/ref"
1214
)
1315

16+
// newTestDNS is a helper function to create a new Dns object with an in-memory cache for testing.
17+
func newTestDNS(forwardTTL int, reverseTTL int) *dns.Dns {
18+
ctx := context.Background()
19+
memStore := memory.New(ctx)
20+
cache := dns.NewDNSCache(forwardTTL, reverseTTL, memStore)
21+
return dns.New(ctx, cache)
22+
}
23+
1424
func TestBotEnvironment(t *testing.T) {
15-
dnsObj := internal.NewDNS(300, 300)
25+
dnsObj := newTestDNS(300, 300)
1626
env, err := BotEnvironment(dnsObj)
1727
if err != nil {
1828
t.Fatalf("failed to create bot environment: %v", err)
@@ -291,11 +301,11 @@ func TestBotEnvironment(t *testing.T) {
291301
})
292302

293303
t.Run("dnsFunctions", func(t *testing.T) {
294-
originalDNSLookupAddr := internal.DNSLookupAddr
295-
originalDNSLookupHost := internal.DNSLookupHost
304+
originalDNSLookupAddr := dns.DNSLookupAddr
305+
originalDNSLookupHost := dns.DNSLookupHost
296306
defer func() {
297-
internal.DNSLookupAddr = originalDNSLookupAddr
298-
internal.DNSLookupHost = originalDNSLookupHost
307+
dns.DNSLookupAddr = originalDNSLookupAddr
308+
dns.DNSLookupHost = originalDNSLookupHost
299309
}()
300310

301311
t.Run("reverseDNS", func(t *testing.T) {
@@ -337,7 +347,7 @@ func TestBotEnvironment(t *testing.T) {
337347

338348
for _, tt := range tests {
339349
t.Run(tt.name, func(t *testing.T) {
340-
internal.DNSLookupAddr = func(addr string) ([]string, error) {
350+
dns.DNSLookupAddr = func(addr string) ([]string, error) {
341351
if addr == tt.addr {
342352
return tt.mockReturn, tt.mockError
343353
}
@@ -399,7 +409,7 @@ func TestBotEnvironment(t *testing.T) {
399409

400410
for _, tt := range tests {
401411
t.Run(tt.name, func(t *testing.T) {
402-
internal.DNSLookupHost = func(host string) ([]string, error) {
412+
dns.DNSLookupHost = func(host string) ([]string, error) {
403413
if host == tt.host {
404414
return tt.mockReturn, tt.mockError
405415
}
@@ -482,13 +492,13 @@ func TestBotEnvironment(t *testing.T) {
482492

483493
for _, tt := range tests {
484494
t.Run(tt.name, func(t *testing.T) {
485-
internal.DNSLookupAddr = func(addr string) ([]string, error) {
495+
dns.DNSLookupAddr = func(addr string) ([]string, error) {
486496
if addr == tt.addr {
487497
return tt.reverseMockReturn, tt.reverseMockError
488498
}
489499
return nil, errors.New("unexpected address for reverse lookup")
490500
}
491-
internal.DNSLookupHost = func(host string) ([]string, error) {
501+
dns.DNSLookupHost = func(host string) ([]string, error) {
492502
host = strings.TrimSuffix(host, ".")
493503
if ips, ok := tt.forwardMockReturn[host]; ok {
494504
return ips, nil

lib/policy/policy.go

Lines changed: 17 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ type ParsedConfig struct {
4343
StatusCodes config.StatusCodes
4444
DefaultDifficulty int
4545
DNSBL bool
46+
DnsCache *dns.DnsCache
4647
Dns *dns.Dns
4748
Logger *slog.Logger
4849
}
@@ -68,6 +69,22 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
6869
result := newParsedConfig(c)
6970
result.DefaultDifficulty = defaultDifficulty
7071

72+
stFac, ok := store.Get(c.Store.Backend)
73+
switch ok {
74+
case true:
75+
store, err := stFac.Build(ctx, c.Store.Parameters)
76+
if err != nil {
77+
validationErrs = append(validationErrs, err)
78+
} else {
79+
result.Store = store
80+
}
81+
case false:
82+
validationErrs = append(validationErrs, config.ErrUnknownStoreBackend)
83+
}
84+
85+
result.DnsCache = dns.NewDNSCache(result.orig.DNSTTL.Forward, result.orig.DNSTTL.Reverse, result.Store)
86+
result.Dns = dns.New(ctx, result.DnsCache)
87+
7188
for _, b := range c.Bots {
7289
if berr := b.Valid(); berr != nil {
7390
validationErrs = append(validationErrs, berr)
@@ -196,22 +213,6 @@ func ParseConfig(ctx context.Context, fin io.Reader, fname string, defaultDiffic
196213
result.Thresholds = append(result.Thresholds, threshold)
197214
}
198215

199-
stFac, ok := store.Get(c.Store.Backend)
200-
switch ok {
201-
case true:
202-
store, err := stFac.Build(ctx, c.Store.Parameters)
203-
if err != nil {
204-
validationErrs = append(validationErrs, err)
205-
} else {
206-
result.Store = store
207-
}
208-
case false:
209-
validationErrs = append(validationErrs, config.ErrUnknownStoreBackend)
210-
}
211-
212-
dnsCache := dns.NewDNSCache(result.orig.DNSTTL.Forward, result.orig.DNSTTL.Reverse, result.Store)
213-
result.Dns = dns.New(ctx, dnsCache)
214-
215216
if c.Logging.Level != nil {
216217
logLevel = c.Logging.Level.String()
217218
}

0 commit comments

Comments
 (0)