Skip to content
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
140 changes: 117 additions & 23 deletions dns/resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,40 @@ func (f FuncResolver) Query(ctx context.Context, q dnsmessage.Question) (*dnsmes
return f(ctx, q)
}

// RawResolver can query DNS and return the raw wire-format response bytes as defined in RFC 1035.
// Using plain name and qtype avoids a dependency on any specific DNS parsing library,
// allowing callers to parse the response with any library — including those that support
// record types not yet recognized by golang.org/x/net/dns/dnsmessage.
type RawResolver interface {
QueryRaw(ctx context.Context, name string, qtype uint16) ([]byte, error)
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We need to decide on the qname format. Could be the native []byte format with length-prefixed labels, or [][]byte (slice of labels), or a escaped string ("foo.bar.com").

While this doesn't affect regular usage, it can help use cases like DNS tunneling.

Comment thread
fortuna marked this conversation as resolved.
Outdated
}

// FuncRawResolver is a [RawResolver] that uses the given function to query DNS.
type FuncRawResolver func(ctx context.Context, name string, qtype uint16) ([]byte, error)

// QueryRaw implements the [RawResolver] interface.
func (f FuncRawResolver) QueryRaw(ctx context.Context, name string, qtype uint16) ([]byte, error) {
return f(ctx, name, qtype)
}

// RawToResolver wraps a [RawResolver] in a [Resolver] that parses the wire-format
// response bytes using golang.org/x/net/dns/dnsmessage.
// The underlying [RawResolver] is responsible for ID matching and returning valid bytes;
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We can remove this line.

// this adapter only unpacks the result.
func RawToResolver(r RawResolver) Resolver {
return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
raw, err := r.QueryRaw(ctx, q.Name.String(), uint16(q.Type))
if err != nil {
return nil, err
}
var msg dnsmessage.Message
if err := msg.Unpack(raw); err != nil {
return nil, &nestedError{ErrBadResponse, fmt.Errorf("failed to unpack DNS response: %w", err)}
}
return &msg, nil
})
}

// NewQuestion is a convenience function to create a [dnsmessage.Question].
// The input domain is interpreted as fully-qualified. If the end "." is missing, it's added.
func NewQuestion(domain string, qtype dnsmessage.Type) (*dnsmessage.Question, error) {
Expand All @@ -93,6 +127,15 @@ func NewQuestion(domain string, qtype dnsmessage.Type) (*dnsmessage.Question, er
// for the IPv6 and UDP headers".
const maxUDPMessageSize = 1232

// makeQuestion constructs a dnsmessage.Question from a plain name and record type.
func makeQuestion(name string, qtype uint16) (dnsmessage.Question, error) {
q, err := NewQuestion(name, dnsmessage.Type(qtype))
Comment thread
fortuna marked this conversation as resolved.
Outdated
if err != nil {
return dnsmessage.Question{}, err
}
return *q, nil
}

// appendRequest appends the bytes a DNS request using the id and question to buf.
func appendRequest(id uint16, q dnsmessage.Question, buf []byte) ([]byte, error) {
b := dnsmessage.NewBuilder(buf, dnsmessage.Header{ID: id, RecursionDesired: true})
Expand Down Expand Up @@ -167,8 +210,13 @@ func checkResponse(reqID uint16, reqQues dnsmessage.Question, respHdr dnsmessage
}

// queryDatagram implements a DNS query over a datagram protocol.
func queryDatagram(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Message, error) {
// It validates the response ID and question echo before returning raw wire-format bytes.
func queryDatagram(conn io.ReadWriter, name string, qtype uint16) ([]byte, error) {
// Reference: https://cs.opensource.google/go/go/+/master:src/net/dnsclient_unix.go?q=func:dnsPacketRoundTrip&ss=go%2Fgo
q, err := makeQuestion(name, qtype)
if err != nil {
return nil, &nestedError{ErrBadRequest, fmt.Errorf("invalid question: %w", err)}
}
id := uint16(rand.Uint32())
buf, err := appendRequest(id, q, make([]byte, 0, maxUDPMessageSize))
if err != nil {
Expand Down Expand Up @@ -198,17 +246,24 @@ func queryDatagram(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Messa
returnErr = errors.Join(returnErr, err)
continue
}
return &msg, nil
result := make([]byte, n)
copy(result, buf[:n])
return result, nil
}
}

// queryStream implements a DNS query over a stream protocol. It frames the messages by prepending them with a 2-byte length prefix.
func queryStream(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Message, error) {
// It validates the response ID and question echo before returning raw wire-format bytes.
func queryStream(conn io.ReadWriter, name string, qtype uint16) ([]byte, error) {
// Reference: https://cs.opensource.google/go/go/+/master:src/net/dnsclient_unix.go?q=func:dnsStreamRoundTrip&ss=go%2Fgo
q, err := makeQuestion(name, qtype)
if err != nil {
return nil, &nestedError{ErrBadRequest, fmt.Errorf("invalid question: %w", err)}
}
id := uint16(rand.Uint32())
buf, err := appendRequest(id, q, make([]byte, 2, 514))
if err != nil {
return nil, &nestedError{ErrBadRequest, fmt.Errorf("append request failed: %w", err)}
return nil, &nestedError{ErrBadRequest, err}
}
// Buffer length must fit in a uint16.
if len(buf) > 1<<16-1 {
Expand Down Expand Up @@ -241,7 +296,7 @@ func queryStream(conn io.ReadWriter, q dnsmessage.Question) (*dnsmessage.Message
if err := checkResponse(id, q, msg.Header, msg.Questions); err != nil {
return nil, &nestedError{ErrBadResponse, err}
}
return &msg, nil
return buf, nil
}

func ensurePort(address string, defaultPort string) string {
Expand All @@ -256,13 +311,13 @@ func ensurePort(address string, defaultPort string) string {
return address
}

// NewUDPResolver creates a [Resolver] that implements the DNS-over-UDP protocol, using a [transport.PacketDialer] for transport.
// NewUDPRawResolver creates a [RawResolver] that implements the DNS-over-UDP protocol, using a [transport.PacketDialer] for transport.
// It uses a different port for every request.
//
// [DNS-over-UDP]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1
func NewUDPResolver(pd transport.PacketDialer, resolverAddr string) Resolver {
func NewUDPRawResolver(pd transport.PacketDialer, resolverAddr string) RawResolver {
Comment on lines 402 to +403
Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think these factory functions work well. We should consider a way to specify configs, like the max packet length for UDP, and layer other functionality that are common across implementations.

resolverAddr = ensurePort(resolverAddr, "53")
return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
return FuncRawResolver(func(ctx context.Context, name string, qtype uint16) ([]byte, error) {
conn, err := pd.DialPacket(ctx, resolverAddr)
if err != nil {
return nil, &nestedError{ErrDial, err}
Expand All @@ -271,15 +326,23 @@ func NewUDPResolver(pd transport.PacketDialer, resolverAddr string) Resolver {
if deadline, ok := ctx.Deadline(); ok {
conn.SetDeadline(deadline)
}
return queryDatagram(conn, q)
return queryDatagram(conn, name, qtype)
})
}

type streamResolver struct {
// NewUDPResolver creates a [Resolver] that implements the DNS-over-UDP protocol, using a [transport.PacketDialer] for transport.
// It uses a different port for every request.
//
// [DNS-over-UDP]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.1
func NewUDPResolver(pd transport.PacketDialer, resolverAddr string) Resolver {
return RawToResolver(NewUDPRawResolver(pd, resolverAddr))
}

type streamRawResolver struct {
NewConn func(context.Context) (transport.StreamConn, error)
}

func (r *streamResolver) Query(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
func (r *streamRawResolver) QueryRaw(ctx context.Context, name string, qtype uint16) ([]byte, error) {
conn, err := r.NewConn(ctx)
if err != nil {
return nil, &nestedError{ErrDial, err}
Expand All @@ -289,31 +352,39 @@ func (r *streamResolver) Query(ctx context.Context, q dnsmessage.Question) (*dns
if deadline, ok := ctx.Deadline(); ok {
conn.SetDeadline(deadline)
}
return queryStream(conn, q)
return queryStream(conn, name, qtype)
}

// NewTCPResolver creates a [Resolver] that implements the [DNS-over-TCP] protocol, using a [transport.StreamDialer] for transport.
// NewTCPRawResolver creates a [RawResolver] that implements the [DNS-over-TCP] protocol, using a [transport.StreamDialer] for transport.
// It creates a new connection to the resolver for every request.
//
// [DNS-over-TCP]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.2
func NewTCPResolver(sd transport.StreamDialer, resolverAddr string) Resolver {
func NewTCPRawResolver(sd transport.StreamDialer, resolverAddr string) RawResolver {
// TODO: Consider handling Authenticated Data.
resolverAddr = ensurePort(resolverAddr, "53")
return &streamResolver{
return &streamRawResolver{
NewConn: func(ctx context.Context) (transport.StreamConn, error) {
return sd.DialStream(ctx, resolverAddr)
},
}
}

// NewTLSResolver creates a [Resolver] that implements the [DNS-over-TLS] protocol, using a [transport.StreamDialer]
// NewTCPResolver creates a [Resolver] that implements the [DNS-over-TCP] protocol, using a [transport.StreamDialer] for transport.
// It creates a new connection to the resolver for every request.
//
// [DNS-over-TCP]: https://datatracker.ietf.org/doc/html/rfc1035#section-4.2.2
func NewTCPResolver(sd transport.StreamDialer, resolverAddr string) Resolver {
return RawToResolver(NewTCPRawResolver(sd, resolverAddr))
}

// NewTLSRawResolver creates a [RawResolver] that implements the [DNS-over-TLS] protocol, using a [transport.StreamDialer]
// to connect to the resolverAddr, and the resolverName as the TLS server name.
// It creates a new connection to the resolver for every request.
//
// [DNS-over-TLS]: https://datatracker.ietf.org/doc/html/rfc7858
func NewTLSResolver(sd transport.StreamDialer, resolverAddr string, resolverName string) Resolver {
func NewTLSRawResolver(sd transport.StreamDialer, resolverAddr string, resolverName string) RawResolver {
resolverAddr = ensurePort(resolverAddr, "853")
return &streamResolver{
return &streamRawResolver{
NewConn: func(ctx context.Context) (transport.StreamConn, error) {
baseConn, err := sd.DialStream(ctx, resolverAddr)
if err != nil {
Expand All @@ -324,12 +395,21 @@ func NewTLSResolver(sd transport.StreamDialer, resolverAddr string, resolverName
}
}

// NewHTTPSResolver creates a [Resolver] that implements the [DNS-over-HTTPS] protocol, using a [transport.StreamDialer]
// NewTLSResolver creates a [Resolver] that implements the [DNS-over-TLS] protocol, using a [transport.StreamDialer]
// to connect to the resolverAddr, and the resolverName as the TLS server name.
// It creates a new connection to the resolver for every request.
//
// [DNS-over-TLS]: https://datatracker.ietf.org/doc/html/rfc7858
func NewTLSResolver(sd transport.StreamDialer, resolverAddr string, resolverName string) Resolver {
return RawToResolver(NewTLSRawResolver(sd, resolverAddr, resolverName))
}

// NewHTTPSRawResolver creates a [RawResolver] that implements the [DNS-over-HTTPS] protocol, using a [transport.StreamDialer]
// to connect to the resolverAddr, and the url as the DoH template URI.
// It uses an internal HTTP client that reuses connections when possible.
//
// [DNS-over-HTTPS]: https://datatracker.ietf.org/doc/html/rfc8484
func NewHTTPSResolver(sd transport.StreamDialer, resolverAddr string, url string) Resolver {
func NewHTTPSRawResolver(sd transport.StreamDialer, resolverAddr string, url string) RawResolver {
resolverAddr = ensurePort(resolverAddr, "443")
dialContext := func(ctx context.Context, network, addr string) (net.Conn, error) {
if !strings.HasPrefix(network, "tcp") {
Expand All @@ -352,11 +432,16 @@ func NewHTTPSResolver(sd transport.StreamDialer, resolverAddr string, url string
ResponseHeaderTimeout: 20 * time.Second, // Same value as Android DNS-over-TLS
},
}
return FuncResolver(func(ctx context.Context, q dnsmessage.Question) (*dnsmessage.Message, error) {
return FuncRawResolver(func(ctx context.Context, name string, qtype uint16) ([]byte, error) {
// Prepare request.
// DoH uses ID=0 per RFC 8484.
q, err := makeQuestion(name, qtype)
if err != nil {
return nil, &nestedError{ErrBadRequest, fmt.Errorf("invalid question: %w", err)}
}
buf, err := appendRequest(0, q, make([]byte, 0, 512))
if err != nil {
return nil, &nestedError{ErrBadRequest, fmt.Errorf("append request failed: %w", err)}
return nil, &nestedError{ErrBadRequest, err}
}
httpReq, err := http.NewRequestWithContext(ctx, "POST", url, bytes.NewBuffer(buf))
if err != nil {
Expand Down Expand Up @@ -388,6 +473,15 @@ func NewHTTPSResolver(sd transport.StreamDialer, resolverAddr string, url string
if err := checkResponse(0, q, msg.Header, msg.Questions); err != nil {
return nil, &nestedError{ErrBadResponse, err}
}
return &msg, nil
return response, nil
})
}

// NewHTTPSResolver creates a [Resolver] that implements the [DNS-over-HTTPS] protocol, using a [transport.StreamDialer]
// to connect to the resolverAddr, and the url as the DoH template URI.
// It uses an internal HTTP client that reuses connections when possible.
//
// [DNS-over-HTTPS]: https://datatracker.ietf.org/doc/html/rfc8484
func NewHTTPSResolver(sd transport.StreamDialer, resolverAddr string, url string) Resolver {
return RawToResolver(NewHTTPSRawResolver(sd, resolverAddr, url))
}
Loading
Loading