11package main
22
33import (
4+ "errors"
5+ "fmt"
46 "strings"
57 "time"
68
9+ "strconv"
10+
711 "github.com/levenlabs/go-llog"
812 "github.com/mediocregopher/lever"
913 "github.com/miekg/dns"
10- "strconv"
1114)
1215
1316var dnsServerGroups [][]string
@@ -50,10 +53,14 @@ func tryProxy(m *dns.Msg, addr string) *dns.Msg {
5053
5154func queryGroup (r * dns.Msg , servers []string ) * dns.Msg {
5255 chs := make ([]chan * dns.Msg , len (servers ))
56+ doneCh := make (chan struct {})
5357 for i := range servers {
5458 chs [i ] = make (chan * dns.Msg , 1 )
5559 go func (ch chan * dns.Msg , addr string ) {
56- ch <- tryProxy (r , addr )
60+ select {
61+ case ch <- tryProxy (r , addr ):
62+ case <- doneCh :
63+ }
5764 }(chs [i ], servers [i ])
5865 }
5966
@@ -63,6 +70,7 @@ func queryGroup(r *dns.Msg, servers []string) *dns.Msg {
6370 break
6471 }
6572 }
73+ close (doneCh )
6674 return m
6775}
6876
@@ -84,22 +92,33 @@ func sendFormatError(w dns.ResponseWriter, r *dns.Msg) {
8492 return
8593}
8694
87- func handleRequest (w dns.ResponseWriter , r * dns.Msg ) {
88- kv := llog.KV {"question" : "" , "type" : "" }
89- var ok bool
95+ func validateRequest (r * dns.Msg ) error {
9096 if len (r .Question ) == 0 {
91- llog .Warn ("received request with no questions" , kv )
92- sendFormatError (w , r )
93- return
97+ return errors .New ("empty question set" )
98+ }
99+ typ , ok := dns .TypeToString [r .Question [0 ].Qtype ]
100+ if ! ok || typ == "None" {
101+ return fmt .Errorf ("invalid question type: %q" , typ )
102+ }
103+ return nil
104+ }
105+
106+ func handleRequest (w dns.ResponseWriter , r * dns.Msg ) {
107+ kv := llog.KV {}
108+ // Can be nil during testing
109+ if raddr := w .RemoteAddr (); raddr != nil {
110+ kv ["srcAddr" ] = raddr .String ()
94111 }
95- kv [ "type" ], ok = dns . TypeToString [ r . Question [ 0 ]. Qtype ]
96- if ! ok || kv [ "type" ] == "None" {
97- kv ["qtype " ] = r . Question [ 0 ]. Qtype
98- llog .Warn ("invalid type received " , kv )
112+
113+ if err := validateRequest ( r ); err != nil {
114+ kv ["err " ] = err
115+ llog .Warn ("invalid request " , kv )
99116 sendFormatError (w , r )
100117 return
101118 }
119+
102120 kv ["question" ] = r .Question [0 ].Name
121+ kv ["questionType" ] = r .Question [0 ].Qtype
103122
104123 llog .Info ("handling request" , kv )
105124 rr := NewReq {r , make (chan * dns.Msg )}
@@ -137,6 +156,8 @@ func handleRequest(w dns.ResponseWriter, r *dns.Msg) {
137156 }
138157}
139158
159+ var version string
160+
140161func main () {
141162 l := lever .New ("struggledns" , nil )
142163 l .Add (lever.Param {
@@ -168,8 +189,21 @@ func main() {
168189 Description : "If we should allow truncated responses to be proxied" ,
169190 Flag : true ,
170191 })
192+ if version != "" {
193+ l .Add (lever.Param {
194+ Name : "--version" ,
195+ Aliases : []string {"-v" },
196+ Description : "Print version info" ,
197+ Flag : true ,
198+ })
199+ }
171200 l .Parse ()
172201
202+ if l .ParamFlag ("--version" ) {
203+ fmt .Println (version )
204+ return
205+ }
206+
173207 addr , _ := l .ParamStr ("--listen-addr" )
174208 dnsServers , _ := l .ParamStrs ("--fwd-to" )
175209 combineGroups := l .ParamFlag ("--parallel" )
@@ -201,7 +235,7 @@ func main() {
201235 DialTimeout : time .Millisecond * 100 ,
202236 WriteTimeout : time .Millisecond * 100 ,
203237 ReadTimeout : time .Millisecond * time .Duration (timeout ),
204- UDPSize : 4096 ,
238+ UDPSize : 4096 ,
205239 }
206240
207241 handler := dns .HandlerFunc (handleRequest )
0 commit comments