-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cache.go
140 lines (120 loc) · 2.7 KB
/
cache.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
package main
import (
"context"
"fmt"
"net"
"strings"
"sync"
"time"
"github.com/miekg/dns"
"go.devnw.com/ttl"
)
// Cache is the caching stage of the request pipeline. It answers requests
// from an in-memory TTL cache of *dns.Msg responses keyed by Request.Key().
// On a miss it wraps the request's writer so that the upstream response is
// cached on its way back to the client.
//
// ctx is used for every cache Get/SetTTL call made by this stage and by the
// interceptors it installs. NOTE(review): storing a context in a struct is
// discouraged in Go; consider threading the per-call context instead.
type Cache struct {
	ctx    context.Context
	logger Logger
	cache  *ttl.Cache[string, *dns.Msg]
}
// Intercept is the cache stage of the request pipeline.
//
// Requests without a question are blocked and not forwarded. Otherwise the
// cache is consulted under req.Key():
//
//   - on a hit, a copy of the cached response is adapted to this request
//     with SetReply and sent via req.Answer; the pipeline stops (false).
//   - on a miss, req.w is wrapped in an interceptor that caches the
//     eventual response on its way back to the client, and the request
//     continues down the pipeline (true).
//
// The ctx parameter is currently unused; cache access uses c.ctx.
func (c *Cache) Intercept(
	ctx context.Context,
	req *Request,
) (*Request, bool) {
	if len(req.r.Question) == 0 {
		// There is nothing to answer or cache for a question-less request.
		// Stop here rather than forwarding it. Otherwise later stages, and
		// the interceptor's logging of Question[0], would index an empty
		// slice.
		if err := req.Block(); err != nil {
			c.logger.Errorw(
				"invalid question",
				"category", CACHE,
				"request", req.String(),
				"error", err,
			)
		}
		return req, false
	}

	cached, ok := c.cache.Get(c.ctx, req.Key())
	if !ok || cached == nil {
		// Add hook for final response to cache
		req.w = &interceptor{
			ctx:    c.ctx,
			cache:  c.cache,
			logger: c.logger,
			req:    req,
			next:   req.w.WriteMsg, // TODO: Determine if this is the correct pattern
		}
		return req, true
	}

	// SetReply mutates its receiver (ID, question, response flags). Reply
	// from a copy so the shared cache entry is never modified; otherwise
	// concurrent requests hitting the same entry would race on it.
	if err := req.Answer(cached.Copy().SetReply(req.r)); err != nil {
		c.logger.Errorw(
			"failed to set reply",
			"category", CACHE,
			"request", req.String(),
			"error", err,
		)
	}
	return req, false
}
// interceptor wraps a request's writer on a cache miss (see
// Cache.Intercept) so that the eventual response is stored in the cache
// before it is passed on to the original writer. Later queries for the
// same key can then be answered without going upstream again.
//
// Fields:
//   - ctx, cache: the context and TTL cache used when storing the response.
//   - logger: receives cache-write diagnostics.
//   - req: the intercepted request; it supplies the cache key.
//   - next: the original writer's WriteMsg, called to deliver the response.
//   - once: guards the caching step so it runs at most once.
//
// NOTE(review): this was previously described as a dns.ResponseWriter, but
// only WriteMsg is defined on it here; confirm which interface Request.w
// requires. Because it contains a sync.Once, it must be used by pointer and
// never copied.
type interceptor struct {
	ctx    context.Context
	cache  *ttl.Cache[string, *dns.Msg]
	logger Logger
	req    *Request
	next   func(*dns.Msg) error
	once   sync.Once
}
// WriteMsg stores res in the cache under the intercepted request's key and
// then forwards it to the original writer (next).
//
// Caching runs at most once per interceptor; later writes are only
// forwarded. Caching is best-effort: a failed store is logged and the
// response is still delivered, so a cache problem never leaves a client
// without a reply. The returned error is the one reported by next.
//
// The entry's lifetime is the smallest TTL among the answer records, so a
// cached message never outlives any record in it. When there are no
// answers, DEFAULTTTL seconds is used. Responses whose smallest TTL is zero
// are not cached, because RFC 1035 §3.2.1 says zero-TTL records must not
// be cached.
func (i *interceptor) WriteMsg(res *dns.Msg) error {
	if res != nil {
		i.once.Do(func() {
			lifetime := time.Second * DEFAULTTTL
			found := false
			for _, rr := range res.Answer {
				if rr == nil || rr.Header() == nil {
					continue
				}
				recTTL := time.Second * time.Duration(rr.Header().Ttl)
				if !found || recTTL < lifetime {
					lifetime, found = recTTL, true
				}
			}

			if found && lifetime == 0 {
				// Zero TTL: deliver the response but do not cache it.
				return
			}

			// Store a copy so later changes to res by downstream writers
			// cannot alter the cached entry.
			err := i.cache.SetTTL(i.ctx, i.req.Key(), res.Copy(), lifetime)
			if err != nil {
				i.logger.Errorw(
					"failed to cache response",
					"category", CACHE,
					"request", i.req.String(),
					"error", err,
				)
				return
			}

			record := ""
			if len(i.req.r.Question) > 0 {
				record = i.req.r.Question[0].Name
			}
			i.logger.Debugw(
				"cache",
				"method", WRITE,
				"record", record,
				"ttl", lifetime,
			)
		})
	}

	return i.next(res)
}
// CacheAction identifies the kind of cache operation an event describes.
type CacheAction string

// Cache actions reported in CacheEvent.Method.
const (
	READ  CacheAction = "read"
	WRITE CacheAction = "write"
)

// CacheEvent describes a single cache operation on a DNS record.
type CacheEvent struct {
	Method   CacheAction   // operation performed
	Record   string        // record (query) name
	Location net.IP        // address associated with the record, if any
	TTL      time.Duration // lifetime associated with the entry
}

// String renders the event as "CACHE <METHOD> <record> <ttl> <ip>". The
// method is upper-cased and the IP is reported as "<missing>" when
// Location is empty.
func (e *CacheEvent) String() string {
	ip := "<missing>"
	if len(e.Location) > 0 {
		// Only the address goes here; the record name already has its own
		// slot in the format below.
		ip = e.Location.String()
	}
	return fmt.Sprintf(
		"CACHE %s %s %s %s",
		strings.ToUpper(string(e.Method)),
		e.Record,
		e.TTL,
		ip,
	)
}

// Event returns the same text as String.
func (e *CacheEvent) Event() string {
	return e.String()
}