Skip to content
4 changes: 2 additions & 2 deletions http.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,8 @@ func (p *Proxy) AddHTTPHostRoute(ipPort, httpHost string, dest Target) {
// for any additional routes on ipPort.
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) {
p.addRoute(ipPort, httpHostMatch{match, dest})
func (p *Proxy) AddHTTPHostMatchRoute(ipPort string, match Matcher, dest Target) (routeID int) {
return p.addRoute(ipPort, httpHostMatch{match, dest})
}

type httpHostMatch struct {
Expand Down
17 changes: 8 additions & 9 deletions sni.go
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,8 @@ import (
// with AddStopACMESearch.
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
p.AddSNIMatchRoute(ipPort, equals(sni), dest)
func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) (routeID int) {
return p.AddSNIMatchRoute(ipPort, equals(sni), dest)
}

// AddSNIMatchRoute appends a route to the ipPort listener that routes
Expand All @@ -48,16 +48,15 @@ func (p *Proxy) AddSNIRoute(ipPort, sni string, dest Target) {
// with AddStopACMESearch.
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) {
func (p *Proxy) AddSNIMatchRoute(ipPort string, matcher Matcher, dest Target) (routeID int) {
cfg := p.configFor(ipPort)
if !cfg.stopACME {
if len(cfg.acmeTargets) == 0 {
p.addRoute(ipPort, &acmeMatch{cfg})
}
cfg.acmeTargets = append(cfg.acmeTargets, dest)
}

p.addRoute(ipPort, sniMatch{matcher, dest})
return p.addRoute(ipPort, sniMatch{matcher, dest})
}

// AddStopACMESearch prevents ACME probing of subsequent SNI routes.
Expand All @@ -74,7 +73,7 @@ type sniMatch struct {
}

func (m sniMatch) match(br *bufio.Reader) Target {
if m.matcher(context.TODO(), clientHelloServerName(br)) {
if m.matcher(context.TODO(), ClientHelloServerName(br)) {
return m.target
}
return nil
Expand All @@ -88,7 +87,7 @@ type acmeMatch struct {
}

func (m *acmeMatch) match(br *bufio.Reader) Target {
sni := clientHelloServerName(br)
sni := ClientHelloServerName(br)
if !strings.HasSuffix(sni, ".acme.invalid") {
return nil
}
Expand Down Expand Up @@ -153,10 +152,10 @@ func tryACME(ctx context.Context, ch chan<- Target, dest Target, sni string) {
ret = dest
}

// clientHelloServerName returns the SNI server name inside the TLS ClientHello,
// ClientHelloServerName returns the SNI server name inside the TLS ClientHello,
// without consuming any bytes from br.
// On any error, the empty string is returned.
func clientHelloServerName(br *bufio.Reader) (sni string) {
func ClientHelloServerName(br *bufio.Reader) (sni string) {
const recordHeaderLen = 5
hdr, err := br.Peek(recordHeaderLen)
if err != nil {
Expand Down
85 changes: 67 additions & 18 deletions tcpproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ import (
"io"
"log"
"net"
"sync"
"time"
)

Expand All @@ -79,6 +80,9 @@ type Proxy struct {
// function. If nil, net.Dial is used.
// The provided net is always "tcp".
ListenFunc func(net, laddr string) (net.Listener, error)

// defaultHandler handles unmatched traffic
defaultHandler Target
}

// Matcher reports whether hostname matches the Matcher's criteria.
Expand All @@ -93,7 +97,9 @@ func equals(want string) Matcher {

// config contains the proxying state for one listener.
type config struct {
routes []route
routes *sync.Map // map[int]route
nextRouteID int

acmeTargets []Target // accumulates targets that should be probed for acme.
stopACME bool // if true, AddSNIRoute doesn't add targets to acmeTargets.
}
Expand Down Expand Up @@ -122,25 +128,58 @@ func (p *Proxy) configFor(ipPort string) *config {
p.configs = make(map[string]*config)
}
if p.configs[ipPort] == nil {
p.configs[ipPort] = &config{}
cfg := &config{}
cfg.routes = &sync.Map{}
cfg.nextRouteID = 1
p.configs[ipPort] = cfg
}
return p.configs[ipPort]
}

func (p *Proxy) addRoute(ipPort string, r route) {
cfg := p.configFor(ipPort)
cfg.routes = append(cfg.routes, r)
func (p *Proxy) addRoute(ipPort string, r route) (routeID int) {
var cfg *config
if p.donec != nil {
// NOTE: Do not create config file if the server is listening.
// This saves the handling of bringing up and tearing down
// listeners when add or remove route.
cfg = p.configs[ipPort]
} else {
cfg = p.configFor(ipPort)
}
if cfg != nil {
routeID = cfg.nextRouteID
cfg.nextRouteID++
cfg.routes.Store(routeID, r)
}
return
}

// SetDefaultHandler sets the default handler for proxy.
func (p *Proxy) SetDefaultHandler(t Target) {
p.defaultHandler = t
}

// AddRoute appends an always-matching route to the ipPort listener,
// directing any connection to dest.
// directing any connection to dest. The added route's id is returned
// for future removal. If routeID is zero, the route is not registered.
//
// This is generally used as either the only rule (for simple TCP
// proxies), or as the final fallback rule for an ipPort.
//
// The ipPort is any valid net.Listen TCP address.
func (p *Proxy) AddRoute(ipPort string, dest Target) {
p.addRoute(ipPort, fixedTarget{dest})
func (p *Proxy) AddRoute(ipPort string, dest Target) (routeID int) {
return p.addRoute(ipPort, fixedTarget{dest})
}

// RemoveRoute removes an existing route for ipPort. If the route is
// not found, this is an no-op.
//
// Both AddRoute and RemoveRoute is go-routine safe.
func (p *Proxy) RemoveRoute(ipPort string, routeID int) {
cfg := p.configs[ipPort]
if cfg != nil {
cfg.routes.Delete(routeID)
}
}

type fixedTarget struct {
Expand Down Expand Up @@ -197,7 +236,7 @@ func (p *Proxy) Start() error {
return err
}
p.lns = append(p.lns, ln)
go p.serveListener(errc, ln, config.routes)
go p.serveListener(errc, ln, config)
}
go p.awaitFirstError(errc)
return nil
Expand All @@ -208,22 +247,24 @@ func (p *Proxy) awaitFirstError(errc <-chan error) {
close(p.donec)
}

func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, routes []route) {
func (p *Proxy) serveListener(ret chan<- error, ln net.Listener, cfg *config) {
for {
c, err := ln.Accept()
if err != nil {
ret <- err
return
}
go p.serveConn(c, routes)
go p.serveConn(c, cfg)
}
}

// serveConn runs in its own goroutine and matches c against routes.
// It returns whether it matched purely for testing.
func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
func (p *Proxy) serveConn(c net.Conn, cfg *config) bool {
br := bufio.NewReader(c)
for _, route := range routes {
var handled bool
cfg.routes.Range(func(k, v interface{}) bool {
route := v.(route)
if target := route.match(br); target != nil {
if n := br.Buffered(); n > 0 {
peeked, _ := br.Peek(br.Buffered())
Expand All @@ -233,13 +274,21 @@ func (p *Proxy) serveConn(c net.Conn, routes []route) bool {
}
}
target.HandleConn(c)
return true
handled = true
return false // exit the iteration
}
return true
})
if !handled {
if p.defaultHandler != nil {
p.defaultHandler.HandleConn(c)
} else {
// TODO: hook for this?
log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String())
c.Close()
}
}
// TODO: hook for this?
log.Printf("tcpproxy: no routes matched conn %v/%v; closing", c.RemoteAddr().String(), c.LocalAddr().String())
c.Close()
return false
return handled
}

// Conn is an incoming connection that has had some bytes read from it
Expand Down
Loading