diff --git a/.gitignore b/.gitignore index 1ce0c4a..fbcaaea 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,11 @@ # binaries /client +/echo /proxy # testing certs /cert.pem /key.pem + +# Go sum +go.sum diff --git a/README.md b/README.md index 5813649..818b40a 100644 --- a/README.md +++ b/README.md @@ -85,17 +85,35 @@ ns1.com. 25 IN RRSIG A 13 2 26 20190325121645 20190323121645 44688 ns1.com. xJK5 ; EDNS: version 0; flags: do; udp: 512 ``` +## Echo server + +This codebase also includes an echo server, i.e. a server that, for each stream, +reads its whole contents, and reflects them back to the client. + +To build the echo server, use the following command: + +``` +go build ./cmd/echo +``` + +The echo server may be run the same way as the proxy, except that it does not +accept a `-backend` option, since it does not forward queries anywhere. + +``` +sudo ./echo +``` + ## Troubleshooting Note that this is an experimental code built on top of an experimental protocol. -The server and client in this repository use the same QUIC library +The servers and client in this repository use the same QUIC library and therefore they should be compatible. However, if a different client is used, the handshake may fail on the version negotiation. We suggest to check packet capture first when the client is unable to connect. -The proxy also logs information about accepted connections and streams which -can be used to inspect the sequence of events: +The proxy and the echo server also log information about accepted connections +and streams, which can be used to inspect the sequence of events: ``` $ sudo ./proxy -listen 127.0.0.1:853 -cert cert.pem -key key.pem -backend 8.8.4.4:53 diff --git a/cmd/client/main.go b/cmd/client/main.go index de7ccee..453f5ee 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -11,7 +11,7 @@ import ( "sync" "time" - "github.com/lucas-clemente/quic-go" + "github.com/quic-go/quic-go" "github.com/miekg/dns" ) diff --git a/cmd/echo/main.go b/cmd/echo/main.go new file mode 100644 index 0000000..95e2f12 --- /dev/null +++ b/cmd/echo/main.go @@ -0,0 +1,101 @@ +package main + +import ( + "encoding/binary" + "flag" + "fmt" + "io" + + "github.com/go-kit/log" + quic "github.com/quic-go/quic-go" + "github.com/miekg/dns" + + "github.com/ns1/doq-proxy/server" +) + +func main() { + server.Main(genFlags, handleStream) +} + +func genFlags(dns *bool) { + flag.BoolVar(dns, "dns", true, "If true, validates the traffic as DNS.") +} + +func handleDnsStream(l log.Logger, stream quic.Stream) error { + defer stream.Close() + + wireLength := make([]byte, 2) + _, err := io.ReadFull(stream, wireLength) + if err != nil { + return fmt.Errorf("read query length: %w", err) + } + + length := binary.BigEndian.Uint16(wireLength) + + wireQuery := make([]byte, length) + _, err = io.ReadFull(stream, wireQuery) + if err != nil { + return fmt.Errorf("read query payload: %w", err) + } + + msg := dns.Msg{} + err = msg.Unpack(wireQuery) + if err != nil { + return fmt.Errorf("could not decode query: %w", err) + } + + if msg.MsgHdr.Response { + l.Log("msg", "QR bit already set") + } + + msg.MsgHdr.Response = true + + bundle := make([]byte, 0) + responseWire, err := msg.Pack() + if err != nil { + return fmt.Errorf("could not encode response: %w", err) + } + + bundle = binary.BigEndian.AppendUint16(bundle, uint16(len(responseWire))) + bundle = append(bundle, responseWire...) + + _, err = stream.Write(bundle) + if err != nil { + return fmt.Errorf("send response: %w", err) + } + + return nil +} + +func handleDumbStream(l log.Logger, stream quic.Stream) error { + for { + end := false + data := make([]byte, 2048) + n, err := stream.Read(data) + if err == io.EOF { + end = true + } else if err != nil { + return fmt.Errorf("read query: %w", err) + } + + _, err = stream.Write(data[:n]) + if err != nil { + return fmt.Errorf("send response: %w", err) + } + + if end { + stream.Close() + break + } + } + + return nil +} + +func handleStream(l log.Logger, stream quic.Stream, dns bool) error { + if dns { + return handleDnsStream(l, stream) + } else { + return handleDumbStream(l, stream) + } +} diff --git a/cmd/proxy/main.go b/cmd/proxy/main.go index 73346f8..3af83ba 100644 --- a/cmd/proxy/main.go +++ b/cmd/proxy/main.go @@ -1,9 +1,7 @@ package main import ( - "context" "crypto/rand" - "crypto/tls" "encoding/binary" "errors" "flag" @@ -11,153 +9,24 @@ import ( "io" "net" "os" - "os/signal" - "sync" - "syscall" "time" - "github.com/go-kit/kit/log" - quic "github.com/lucas-clemente/quic-go" + "github.com/go-kit/log" + quic "github.com/quic-go/quic-go" "github.com/miekg/dns" - "github.com/oklog/run" + + "github.com/ns1/doq-proxy/server" ) func main() { - l := log.NewLogfmtLogger(log.NewSyncWriter(os.Stdout)) - l = log.WithPrefix(l, "ts", log.DefaultTimestampUTC) - - var g run.Group - - // proxy code loop - { - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - g.Add(func() error { - return loop(l, ctx) - }, func(error) { - cancel() - }) - } - - // signal termination - { - sigterm := make(chan os.Signal, 1) - g.Add(func() error { - signal.Notify(sigterm, syscall.SIGINT, syscall.SIGTERM) - if sig, ok := <-sigterm; ok { - l.Log("msg", "stopping the proxy", "signal", sig.String()) - } - return nil - }, func(error) { - signal.Stop(sigterm) - close(sigterm) - }) - } - - err := g.Run() - if err != nil { - l.Log("msg", "terminating after error", "err", err) - os.Exit(1) - } + server.Main(genFlags, handleStream) } -func loop(l log.Logger, ctx context.Context) error { - var ( - addr string - tlsCert string - tlsKey string - backend string - ) - - flag.StringVar(&addr, "listen", "127.0.0.1:853", "UDP address to listen on.") - flag.StringVar(&tlsCert, "cert", "cert.pem", "TLS certificate path.") - flag.StringVar(&tlsKey, "key", "key.pem", "TLS key path.") - flag.StringVar(&backend, "backend", "8.8.4.4:53", "IP of backend server.") - - flag.Parse() - - cert, err := tls.LoadX509KeyPair(tlsCert, tlsKey) - if err != nil { - return fmt.Errorf("load certificate: %w", err) - } - - tls := tls.Config{ - Certificates: []tls.Certificate{cert}, - NextProtos: []string{"doq"}, - } - - listener, err := quic.ListenAddr(addr, &tls, nil) - if err != nil { - return fmt.Errorf("listen: %w", err) - } - defer listener.Close() - - l.Log("msg", "listening for clients", "addr", addr) - - wg := sync.WaitGroup{} - - for { - session, err := listener.Accept(ctx) - if err != nil { - wg.Wait() - return fmt.Errorf("accept connection: %w", err) - } - - l := log.With(l, "client", session.RemoteAddr()) - wg.Add(1) - go func() { - handleClient(l, ctx, session, backend) - wg.Done() - }() - } - -} - -func handleClient(l log.Logger, ctx context.Context, session quic.Connection, backend string) { - l.Log("msg", "session accepted") - - var ( - err error - wg sync.WaitGroup = sync.WaitGroup{} - ) - - defer func() { - msg := "" - if err != nil { - msg = err.Error() - } - session.CloseWithError(0, msg) - - l.Log("msg", "session closed") - }() - - for { - stream, err := session.AcceptStream(ctx) - if err != nil { - break - } - - l := log.With(l, "stream_id", stream.StreamID()) - l.Log("msg", "stream accepted") - - wg.Add(1) - go func() { - defer func() { - wg.Done() - l.Log("msg", "stream closed") - }() - - if err := handleStream(stream, backend); err != nil { - l.Log("msg", "stream failure", "err", err) - } - }() - } - - wg.Wait() +func genFlags(backend *string) { + flag.StringVar(backend, "backend", "8.8.4.4:53", "IP of backend server.") } -func handleStream(stream quic.Stream, backend string) error { +func handleStream(l log.Logger, stream quic.Stream, backend string) error { defer stream.Close() wireLength := make([]byte, 2) diff --git a/go.mod b/go.mod index cfb1e2e..b2e5c20 100644 --- a/go.mod +++ b/go.mod @@ -1,32 +1,25 @@ module github.com/ns1/doq-proxy -go 1.17 +go 1.18 require ( - github.com/go-kit/kit v0.12.0 - github.com/lucas-clemente/quic-go v0.29.2 + github.com/go-kit/log v0.2.1 github.com/miekg/dns v1.1.51 github.com/oklog/run v1.1.0 - google.golang.org/protobuf v1.28.0 // indirect + github.com/quic-go/quic-go v0.40.0 ) require ( - github.com/fsnotify/fsnotify v1.6.0 // indirect - github.com/go-kit/log v0.2.1 // indirect - github.com/go-logfmt/logfmt v0.6.0 // indirect - github.com/go-task/slim-sprig v0.0.0-20210107165309-348f09dbbbc0 // indirect - github.com/golang/mock v1.6.0 // indirect - github.com/marten-seemann/qtls-go1-18 v0.1.4 // indirect - github.com/marten-seemann/qtls-go1-19 v0.1.2 // indirect - github.com/nxadm/tail v1.4.8 // indirect - github.com/onsi/ginkgo v1.16.5 // indirect - github.com/stretchr/testify v1.7.0 // indirect - golang.org/x/crypto v0.6.0 // indirect - golang.org/x/exp v0.0.0-20230224173230-c95f2b4c22f2 // indirect - golang.org/x/mod v0.8.0 // indirect - golang.org/x/net v0.7.0 // indirect - golang.org/x/sys v0.5.0 // indirect - golang.org/x/tools v0.6.0 // indirect - gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect - gopkg.in/yaml.v3 v3.0.0 // indirect + github.com/go-logfmt/logfmt v0.5.1 // indirect + github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 // indirect + github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 // indirect + github.com/onsi/ginkgo/v2 v2.9.5 // indirect + github.com/quic-go/qtls-go1-20 v0.4.1 // indirect + go.uber.org/mock v0.3.0 // indirect + golang.org/x/crypto v0.4.0 // indirect + golang.org/x/exp v0.0.0-20221205204356-47842c84f3db // indirect + golang.org/x/mod v0.11.0 // indirect + golang.org/x/net v0.10.0 // indirect + golang.org/x/sys v0.8.0 // indirect + golang.org/x/tools v0.9.1 // indirect ) diff --git a/server/server.go b/server/server.go new file mode 100644 index 0000000..a2d3db3 --- /dev/null +++ b/server/server.go @@ -0,0 +1,217 @@ +package server + +import ( + "context" + "crypto/tls" + "flag" + "fmt" + "math" + "math/rand" + "os" + "os/signal" + "sync" + "syscall" + "time" + + "github.com/go-kit/log" + quic "github.com/quic-go/quic-go" + "github.com/oklog/run" +) + +const DoqUnspecifiedError = 0x5 + +// Adds specific flags for the server type - e.g. proxy takes a string parameter +// containing the backend address. baton is the memory into which the parameters +// are to be stored - the result is then passed to the corresponding +// StreamHandler. +type FlagsGenerator[T any] func(baton *T) + +// Handles data for the QUIC stream. The baton parameter is of a server-specific +// type. +type StreamHandler[T any] func(l log.Logger, stream quic.Stream, baton T) error + +// Starts the DNS-over-QUIC server. T is the type of parameters for the specific +// server - e.g. proxy has a string parameter containing the backend address. +func Main[T any](flagsGenerator FlagsGenerator[T], sh StreamHandler[T]) { + l := log.NewLogfmtLogger(log.NewSyncWriter(os.Stdout)) + l = log.WithPrefix(l, "ts", log.DefaultTimestampUTC) + + var group run.Group + + // proxy code loop + { + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + group.Add(func() error { + var ( + addr string + tlsCert string + tlsKey string + keyLog string + randomReset float64 + baton T + ) + + flag.StringVar(&addr, "listen", "127.0.0.1:853", + "UDP address to listen on.") + flag.StringVar(&tlsCert, "cert", "cert.pem", + "TLS certificate path.") + flag.StringVar(&tlsKey, "key", "key.pem", + "TLS key path.") + flag.StringVar(&keyLog, "keylog", "", + "TLS key log file (e.g. for Wireshark analysis) - none if empty") + flag.Float64Var(&randomReset, "reset", 0.0, + "Float between 0 and 1 determining the chance that a stream will be randomly reset") + if flagsGenerator != nil { + flagsGenerator(&baton) + } + flag.Parse() + + if randomReset < 0.0 || randomReset > 1.0 { + return fmt.Errorf("random-reset value %v is not between 0 and 1", + randomReset) + } + + resetThreshold := uint32(randomReset * math.MaxUint32) + + return loop(l, ctx, sh, addr, tlsCert, tlsKey, keyLog, + resetThreshold, baton) + }, func(error) { + cancel() + }) + } + + // signal termination + { + sigterm := make(chan os.Signal, 1) + group.Add(func() error { + signal.Notify(sigterm, syscall.SIGINT, syscall.SIGTERM) + if sig, ok := <-sigterm; ok { + l.Log("msg", "stopping the proxy", "signal", sig.String()) + } + return nil + }, func(error) { + signal.Stop(sigterm) + close(sigterm) + }) + } + + err := group.Run() + if err != nil { + l.Log("msg", "terminating after error", "err", err) + os.Exit(1) + } +} + +func loop[T any](l log.Logger, ctx context.Context, sh StreamHandler[T], + addr string, tlsCert string, tlsKey string, keyLog string, + resetThreshold uint32, baton T) error { + + cert, err := tls.LoadX509KeyPair(tlsCert, tlsKey) + if err != nil { + return fmt.Errorf("load certificate: %w", err) + } + + tls := tls.Config{ + Certificates: []tls.Certificate{cert}, + NextProtos: []string{"doq"}, + MinVersion: tls.VersionTLS13, + } + + if keyLog != "" { + keyLogFile, err := os.OpenFile(keyLog, os.O_APPEND | os.O_CREATE | os.O_WRONLY, 0755) + if err != nil { + return fmt.Errorf("open keylog file: %w", err) + } + defer keyLogFile.Close() + tls.KeyLogWriter = keyLogFile + } + + + quic_conf := quic.Config{ + MaxIdleTimeout: 10 * time.Second, + Allow0RTT: true, + } + + listener, err := quic.ListenAddrEarly(addr, &tls, &quic_conf) + if err != nil { + return fmt.Errorf("listen: %w", err) + } + defer listener.Close() + + l.Log("msg", "listening for clients", "addr", addr) + + wg := sync.WaitGroup{} + + for { + session, err := listener.Accept(ctx) + if err != nil { + wg.Wait() + return fmt.Errorf("accept connection: %w", err) + } + + l := log.With(l, "client", session.RemoteAddr()) + wg.Add(1) + go func() { + handleClient(l, ctx, session, sh, resetThreshold, baton) + wg.Done() + }() + } +} + +func handleClient[T any](l log.Logger, ctx context.Context, + session quic.Connection, sh StreamHandler[T], + resetThreshold uint32, baton T) { + l.Log("msg", "session accepted") + + var ( + err error + wg sync.WaitGroup = sync.WaitGroup{} + ) + + defer func() { + msg := "" + if err != nil { + msg = err.Error() + l.Log("msg", "session failure", "err", err) + } + session.CloseWithError(0, msg) + + l.Log("msg", "session closed") + }() + + for { + stream, err := session.AcceptStream(ctx) + if err != nil { + break + } + + l := log.With(l, "stream_id", stream.StreamID()) + l.Log("msg", "stream accepted") + + wg.Add(1) + go func() { + defer func() { + wg.Done() + l.Log("msg", "stream closed") + }() + + if resetThreshold > 0 { + r := rand.Uint32() + if r < resetThreshold { + stream.CancelRead(DoqUnspecifiedError) + stream.CancelWrite(DoqUnspecifiedError) + stream.Close() + return + } + } + + if err := sh(l, stream, baton); err != nil { + l.Log("msg", "stream failure", "err", err) + } + }() + } + + wg.Wait() +}