Skip to content

Commit

Permalink
refactor: transport interface
Browse files Browse the repository at this point in the history
  • Loading branch information
natesales committed Aug 3, 2023
1 parent 83c169f commit d1bf0f8
Show file tree
Hide file tree
Showing 15 changed files with 332 additions and 245 deletions.
16 changes: 4 additions & 12 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,10 +47,8 @@ Application Options:
--pad Set EDNS0 padding
--http3 Use HTTP/3 for DoH
--recaxfr Perform recursive AXFR
-f, --format= Output format (pretty, json, yaml, raw)
(default: pretty)
--pretty-ttls Format TTLs in human readable format (default:
true)
-f, --format= Output format (pretty, json, yaml, raw) (default: pretty)
--pretty-ttls Format TTLs in human readable format (default: true)
--color Enable color output
--question Show question section
--answer Show answer section (default: true)
Expand All @@ -63,8 +61,7 @@ Application Options:
--aa Set AA (Authoritative Answer) flag in query
--ad Set AD (Authentic Data) flag in query
--cd Set CD (Checking Disabled) flag in query
--rd Set RD (Recursion Desired) flag in query
(default: true)
--rd Set RD (Recursion Desired) flag in query (default: true)
--ra Set RA (Recursion Available) flag in query
--z Set Z (Zero) flag in query
--t Set TC (Truncated) flag in query
Expand All @@ -78,13 +75,8 @@ Application Options:
--http-method= HTTP method (default: GET)
--quic-alpn-tokens= QUIC ALPN tokens (default: doq, doq-i11)
--quic-no-pmtud Disable QUIC PMTU discovery
--quic-dial-timeout= QUIC dial timeout (default: 10s)
--quic-idle-timeout= QUIC stream open timeout (default: 10s)
--quic-no-length-prefix Don't add RFC 9250 compliant length prefix
--handshake-timeout= Handshake timeout (default: 10s)
--tcp-dial-timeout= TCP dial timeout (default: 5s)
--default-rr-types= Default record types (default: A, AAAA, NS, MX,
TXT, CNAME)
--default-rr-types= Default record types (default: A, AAAA, NS, MX, TXT, CNAME)
--udp-buffer= Set EDNS0 UDP size in query (default: 1232)
-v, --verbose Show verbose log messages
--trace Show trace log messages
Expand Down
11 changes: 3 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,14 +72,9 @@ type optsTemplate struct {
HTTPMethod string `long:"http-method" description:"HTTP method" default:"GET"`

// QUIC
QUICALPNTokens []string `long:"quic-alpn-tokens" description:"QUIC ALPN tokens" default:"doq" default:"doq-i11"`
QUICNoPMTUD bool `long:"quic-no-pmtud" description:"Disable QUIC PMTU discovery"`
QUICDialTimeout time.Duration `long:"quic-dial-timeout" description:"QUIC dial timeout" default:"10s"`
QUICOpenStreamTimeout time.Duration `long:"quic-idle-timeout" description:"QUIC stream open timeout" default:"10s"`
QUICNoLengthPrefix bool `long:"quic-no-length-prefix" description:"Don't add RFC 9250 compliant length prefix"`

HandshakeTimeout time.Duration `long:"handshake-timeout" description:"Handshake timeout" default:"10s"`
TCPDialTimeout time.Duration `long:"tcp-dial-timeout" description:"TCP dial timeout" default:"5s"`
QUICALPNTokens []string `long:"quic-alpn-tokens" description:"QUIC ALPN tokens" default:"doq" default:"doq-i11"`
QUICNoPMTUD bool `long:"quic-no-pmtud" description:"Disable QUIC PMTU discovery"`
QUICNoLengthPrefix bool `long:"quic-no-length-prefix" description:"Don't add RFC 9250 compliant length prefix"`

DefaultRRTypes []string `long:"default-rr-types" description:"Default record types" default:"A" default:"AAAA" default:"NS" default:"MX" default:"TXT" default:"CNAME"`

Expand Down
52 changes: 41 additions & 11 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,8 @@ import (
"net"

"github.com/miekg/dns"
log "github.com/sirupsen/logrus"

"github.com/natesales/q/transport"
log "github.com/sirupsen/logrus"
)

// createQuery creates a slice of DNS queries
Expand Down Expand Up @@ -111,32 +110,63 @@ func createQuery(

// query performs a DNS query and returns the reply
func query(msg dns.Msg, server, protocol string, tlsConfig *tls.Config) (*dns.Msg, error) {
var reply *dns.Msg
var err error
var ts transport.Transport

switch protocol {
case "https", "http":
if opts.ODoHProxy != "" {
log.Debugf("Using ODoH transport with target %s proxy %s", server, opts.ODoHProxy)
reply, err = transport.ODoH(msg, server, opts.ODoHProxy)
ts = &transport.ODoH{
Target: server,
Proxy: opts.ODoHProxy,
TLSConfig: tlsConfig,
}
} else {
log.Debug("Using HTTP(s) transport")
reply, err = transport.HTTP(&msg, tlsConfig, server, opts.HTTPUserAgent, opts.HTTPMethod, opts.Timeout, opts.HandshakeTimeout, opts.HTTP3, opts.QUICNoPMTUD)
ts = &transport.HTTP{
Server: server,
TLSConfig: tlsConfig,
UserAgent: opts.HTTPUserAgent,
Method: opts.HTTPMethod,
Timeout: opts.Timeout,
HTTP3: opts.HTTP3,
NoPMTUd: opts.QUICNoPMTUD,
}
}
case "quic":
log.Debug("Using QUIC transport")
reply, err = transport.QUIC(&msg, server, tlsConfig, opts.QUICDialTimeout, opts.HandshakeTimeout, opts.QUICOpenStreamTimeout, opts.QUICNoPMTUD, !opts.QUICNoLengthPrefix)
ts = &transport.QUIC{
Server: server,
TLSConfig: tlsConfig,
NoPMTUD: opts.QUICNoPMTUD,
AddLengthPrefix: !opts.QUICNoLengthPrefix,
}
case "tls":
log.Debug("Using TLS transport")
reply, err = transport.TLS(&msg, server, tlsConfig, opts.TCPDialTimeout)
ts = &transport.TLS{
Server: server,
TLSConfig: tlsConfig,
Timeout: opts.Timeout,
}
case "tcp":
log.Debug("Using TCP transport")
reply, err = transport.Plain(&msg, server, true, opts.Timeout, opts.UDPBuffer)
ts = &transport.Plain{
Server: server,
PreferTCP: true,
Timeout: opts.Timeout,
UDPBuffer: opts.UDPBuffer,
}
case "plain":
log.Debug("Using UDP with TCP fallback")
reply, err = transport.Plain(&msg, server, false, opts.Timeout, opts.UDPBuffer)
ts = &transport.Plain{
Server: server,
PreferTCP: false,
Timeout: opts.Timeout,
UDPBuffer: opts.UDPBuffer,
}
default:
return nil, fmt.Errorf("unknown transport protocol %s", protocol)
}
return reply, err

return ts.(transport.Transport).Exchange(&msg)
}
55 changes: 27 additions & 28 deletions transport/http.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,32 @@ import (
)

// HTTP makes a DNS query over HTTP(s)
func HTTP(
m *dns.Msg, tlsConfig *tls.Config,
server, userAgent, method string,
timeout, handshakeTimeout time.Duration,
h3, noPMTUD bool) (*dns.Msg, error) {
type HTTP struct {
Server string
TLSConfig *tls.Config
UserAgent string
Method string
Timeout time.Duration
HTTP3 bool
NoPMTUd bool
}

func (h *HTTP) Exchange(m *dns.Msg) (*dns.Msg, error) {
httpClient := &http.Client{
Timeout: h.Timeout,
Transport: &http.Transport{
TLSClientConfig: tlsConfig,
MaxConnsPerHost: 1,
MaxIdleConns: 1,
TLSHandshakeTimeout: handshakeTimeout,
Proxy: http.ProxyFromEnvironment,
TLSClientConfig: h.TLSConfig,
MaxConnsPerHost: 1,
MaxIdleConns: 1,
Proxy: http.ProxyFromEnvironment,
},
Timeout: timeout,
}
if h3 {
if h.HTTP3 {
log.Debug("Using HTTP/3")
httpClient.Transport = &http3.RoundTripper{
TLSClientConfig: tlsConfig,
TLSClientConfig: h.TLSConfig,
QuicConfig: &quic.Config{
HandshakeIdleTimeout: handshakeTimeout,
DisablePathMTUDiscovery: noPMTUD,
DisablePathMTUDiscovery: h.NoPMTUd,
},
}
}
Expand All @@ -46,19 +50,19 @@ func HTTP(
return nil, fmt.Errorf("packing message: %w", err)
}

queryURL := server + "?dns=" + base64.RawURLEncoding.EncodeToString(buf)
req, err := http.NewRequest(method, queryURL, nil)
queryURL := h.Server + "?dns=" + base64.RawURLEncoding.EncodeToString(buf)
req, err := http.NewRequest(h.Method, queryURL, nil)
if err != nil {
return nil, fmt.Errorf("creating http request to %s: %w", queryURL, err)
}

req.Header.Set("Accept", "application/dns-message")
if userAgent != "" {
log.Debugf("Setting User-Agent to %s", userAgent)
req.Header.Set("User-Agent", userAgent)
if h.UserAgent != "" {
log.Debugf("Setting User-Agent to %s", h.UserAgent)
req.Header.Set("User-Agent", h.UserAgent)
}

log.Debugf("[http] sending %s request to %s", method, queryURL)
log.Debugf("[http] sending %s request to %s", h.Method, queryURL)
resp, err := httpClient.Do(req)
if resp != nil && resp.Body != nil {
defer resp.Body.Close()
Expand All @@ -77,14 +81,9 @@ func HTTP(
}

response := dns.Msg{}
err = response.Unpack(body)
if err != nil {
if err := response.Unpack(body); err != nil {
return nil, fmt.Errorf("unpacking DNS response from %s: %w", queryURL, err)
}

if response.Id != m.Id {
err = dns.ErrId
}

return &response, err
return &response, nil
}
50 changes: 27 additions & 23 deletions transport/http_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,39 +2,38 @@ package transport

import (
"crypto/tls"
"github.com/miekg/dns"
"net/http"
"testing"
"time"

"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
)

func testQuery() *dns.Msg {
msg := dns.Msg{}
msg.RecursionDesired = true
msg.Question = []dns.Question{{
Name: "example.com.",
Qtype: dns.StringToType["A"],
Qclass: dns.ClassINET,
}}
return &msg
}

func TestTransportHTTP(t *testing.T) {
reply, err := HTTP(testQuery(), &tls.Config{}, "https://cloudflare-dns.com/dns-query", "", "GET", 2*time.Second, 2*time.Second, false, false)
assert.Nil(t, err)
assert.Greater(t, len(reply.Answer), 0)
func httpTransport() *HTTP {
return &HTTP{
Server: "https://cloudflare-dns.com/dns-query",
TLSConfig: &tls.Config{},
UserAgent: "",
Method: http.MethodGet,
Timeout: 2 * time.Second,
HTTP3: false,
NoPMTUd: false,
}
}

func TestTransportHTTP3(t *testing.T) {
reply, err := HTTP(testQuery(), &tls.Config{}, "https://cloudflare-dns.com/dns-query", "", "GET", 2*time.Second, 2*time.Second, true, false)
tp := httpTransport()
tp.HTTP3 = true
reply, err := tp.Exchange(validQuery())
assert.Nil(t, err)
assert.Greater(t, len(reply.Answer), 0)
}

func TestTransportHTTPInvalidResolver(t *testing.T) {
_, err := HTTP(&dns.Msg{}, &tls.Config{}, "https://example.com", "", "GET", 2*time.Second, 2*time.Second, false, false)
tp := httpTransport()
tp.Server = "https://example.com"
_, err := tp.Exchange(validQuery())
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "unpacking DNS response")
}
Expand All @@ -46,7 +45,9 @@ func TestTransportHTTPServerError(t *testing.T) {
}))
}()

_, err := HTTP(&dns.Msg{}, &tls.Config{}, "http://localhost:8080", "", "GET", 2*time.Second, 2*time.Second, false, false)
tp := httpTransport()
tp.Server = "http://localhost:8080"
_, err := tp.Exchange(validQuery())
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "got status code 500")
}
Expand All @@ -64,8 +65,11 @@ func TestTransportHTTPIDMismatch(t *testing.T) {
w.Write(buf)
}))
}()
time.Sleep(50 * time.Millisecond)
_, err := HTTP(&dns.Msg{}, &tls.Config{}, "http://localhost:8085", "", "GET", 2*time.Second, 2*time.Second, false, false)
assert.NotNil(t, err)
assert.Contains(t, err.Error(), "id mismatch")
tp := httpTransport()
tp.Server = "http://localhost:8085"
query := validQuery()
reply, err := tp.Exchange(query)
assert.Nil(t, err)
assert.Equal(t, uint16(1), reply.Id)
assert.NotEqual(t, 1, query.Id)
}
27 changes: 21 additions & 6 deletions transport/odoh.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ package transport

import (
"bytes"
"crypto/tls"
"errors"
"fmt"
"io"
Expand Down Expand Up @@ -59,14 +60,28 @@ func buildURL(s, defaultPath string) *url.URL {
}

// ODoH makes a DNS query over ODoH
func ODoH(query dns.Msg, target, proxy string) (*dns.Msg, error) {
type ODoH struct {
Target string
Proxy string
TLSConfig *tls.Config
}

func (o *ODoH) Exchange(m *dns.Msg) (*dns.Msg, error) {
// Query ODoH configs on target
req, err := http.NewRequest(http.MethodGet, buildURL(strings.TrimSuffix(target, "/dns-query"), "/.well-known/odohconfigs").String(), nil)
req, err := http.NewRequest(
http.MethodGet,
buildURL(strings.TrimSuffix(o.Target, "/dns-query"), "/.well-known/odohconfigs").String(),
nil,
)
if err != nil {
return nil, fmt.Errorf("new target configs request: %s", err)
}

client := http.Client{}
client := http.Client{
Transport: &http.Transport{
TLSClientConfig: o.TLSConfig,
},
}
resp, err := client.Do(req)
if err != nil {
return nil, fmt.Errorf("do target configs request: %s", err)
Expand All @@ -86,7 +101,7 @@ func ODoH(query dns.Msg, target, proxy string) (*dns.Msg, error) {
}
log.Debugf("[odoh] retreived %d ODoH configs", len(odohConfigs.Configs))

packedDnsQuery, err := query.Pack()
packedDnsQuery, err := m.Pack()
if err != nil {
return nil, err
}
Expand All @@ -98,8 +113,8 @@ func ODoH(query dns.Msg, target, proxy string) (*dns.Msg, error) {
return nil, fmt.Errorf("encrypt query: %s", err)
}

t := buildURL(target, "/dns-query")
p := buildURL(proxy, "/proxy")
t := buildURL(o.Target, "/dns-query")
p := buildURL(o.Proxy, "/proxy")
qry := p.Query()
if qry.Get("targethost") == "" {
qry.Set("targethost", t.Host)
Expand Down
Loading

0 comments on commit d1bf0f8

Please sign in to comment.