|
1 | 1 | package expressions |
2 | 2 |
|
3 | 3 | import ( |
| 4 | + "context" |
4 | 5 | "errors" |
5 | 6 | "net" |
6 | 7 | "strings" |
7 | 8 | "testing" |
8 | 9 |
|
9 | | - "github.com/TecharoHQ/anubis/internal" |
| 10 | + "github.com/TecharoHQ/anubis/internal/dns" |
| 11 | + "github.com/TecharoHQ/anubis/lib/store/memory" |
10 | 12 | "github.com/google/cel-go/common/types" |
11 | 13 | "github.com/google/cel-go/common/types/ref" |
12 | 14 | ) |
13 | 15 |
|
| 16 | +// newTestDNS is a helper function to create a new Dns object with an in-memory cache for testing. |
| 17 | +func newTestDNS(forwardTTL int, reverseTTL int) *dns.Dns { |
| 18 | + ctx := context.Background() |
| 19 | + memStore := memory.New(ctx) |
| 20 | + cache := dns.NewDNSCache(forwardTTL, reverseTTL, memStore) |
| 21 | + return dns.New(ctx, cache) |
| 22 | +} |
| 23 | + |
14 | 24 | func TestBotEnvironment(t *testing.T) { |
15 | | - dnsObj := internal.NewDNS(300, 300) |
| 25 | + dnsObj := newTestDNS(300, 300) |
16 | 26 | env, err := BotEnvironment(dnsObj) |
17 | 27 | if err != nil { |
18 | 28 | t.Fatalf("failed to create bot environment: %v", err) |
@@ -291,11 +301,11 @@ func TestBotEnvironment(t *testing.T) { |
291 | 301 | }) |
292 | 302 |
|
293 | 303 | t.Run("dnsFunctions", func(t *testing.T) { |
294 | | - originalDNSLookupAddr := internal.DNSLookupAddr |
295 | | - originalDNSLookupHost := internal.DNSLookupHost |
| 304 | + originalDNSLookupAddr := dns.DNSLookupAddr |
| 305 | + originalDNSLookupHost := dns.DNSLookupHost |
296 | 306 | defer func() { |
297 | | - internal.DNSLookupAddr = originalDNSLookupAddr |
298 | | - internal.DNSLookupHost = originalDNSLookupHost |
| 307 | + dns.DNSLookupAddr = originalDNSLookupAddr |
| 308 | + dns.DNSLookupHost = originalDNSLookupHost |
299 | 309 | }() |
300 | 310 |
|
301 | 311 | t.Run("reverseDNS", func(t *testing.T) { |
@@ -337,7 +347,7 @@ func TestBotEnvironment(t *testing.T) { |
337 | 347 |
|
338 | 348 | for _, tt := range tests { |
339 | 349 | t.Run(tt.name, func(t *testing.T) { |
340 | | - internal.DNSLookupAddr = func(addr string) ([]string, error) { |
| 350 | + dns.DNSLookupAddr = func(addr string) ([]string, error) { |
341 | 351 | if addr == tt.addr { |
342 | 352 | return tt.mockReturn, tt.mockError |
343 | 353 | } |
@@ -399,7 +409,7 @@ func TestBotEnvironment(t *testing.T) { |
399 | 409 |
|
400 | 410 | for _, tt := range tests { |
401 | 411 | t.Run(tt.name, func(t *testing.T) { |
402 | | - internal.DNSLookupHost = func(host string) ([]string, error) { |
| 412 | + dns.DNSLookupHost = func(host string) ([]string, error) { |
403 | 413 | if host == tt.host { |
404 | 414 | return tt.mockReturn, tt.mockError |
405 | 415 | } |
@@ -482,13 +492,13 @@ func TestBotEnvironment(t *testing.T) { |
482 | 492 |
|
483 | 493 | for _, tt := range tests { |
484 | 494 | t.Run(tt.name, func(t *testing.T) { |
485 | | - internal.DNSLookupAddr = func(addr string) ([]string, error) { |
| 495 | + dns.DNSLookupAddr = func(addr string) ([]string, error) { |
486 | 496 | if addr == tt.addr { |
487 | 497 | return tt.reverseMockReturn, tt.reverseMockError |
488 | 498 | } |
489 | 499 | return nil, errors.New("unexpected address for reverse lookup") |
490 | 500 | } |
491 | | - internal.DNSLookupHost = func(host string) ([]string, error) { |
| 501 | + dns.DNSLookupHost = func(host string) ([]string, error) { |
492 | 502 | host = strings.TrimSuffix(host, ".") |
493 | 503 | if ips, ok := tt.forwardMockReturn[host]; ok { |
494 | 504 | return ips, nil |
|
0 commit comments