Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

RSDK-9990- Expose SOCKS proxy as fallback dialer #414

Merged
merged 8 commits into from
Mar 12, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions rpc/const.go
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,8 @@ var (
// proxies to indicate the address through which to route all network traffic
// via SOCKS5.
SocksProxyEnvVar = "SOCKS_PROXY"

// OnlySocksProxyEnvVar is the name of an environment variable used if all network
// traffic should be done through SOCKS5.
OnlySocksProxyEnvVar = "ONLY_SOCKS_PROXY"
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[q] Will there be an equivalent ONLY_INTERNET? Or, is the idea that having SOCKS_PROXY unset means "internet only"?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think that should be ok, + internet will almost always be prioritized/chosen even without this knob

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Discussed offline; this is also helpful for development, as we can force all traffic through the SOCKS proxy even with a WiFi connection (ssh still possible.)

)
151 changes: 137 additions & 14 deletions rpc/dialer.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"os"
"strings"
"sync"
"time"

"github.com/edaniels/golog"
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
Expand Down Expand Up @@ -240,6 +241,134 @@ func DialDirectGRPC(ctx context.Context, address string, logger utils.ZapCompati
return dialInner(ctx, address, logger, dOpts)
}

func socksProxyDialContext(ctx context.Context, network, proxyAddr, addr string) (net.Conn, error) {
dialer, err := proxy.SOCKS5(network, proxyAddr, nil, proxy.Direct)
if err != nil {
return nil, fmt.Errorf("error creating SOCKS proxy dialer to address %q from environment: %w",
proxyAddr, err)
}
return dialer.(proxy.ContextDialer).DialContext(ctx, network, addr)
}

// SocksProxyFallbackDialContext will return nil if SocksProxyEnvVar is not set or if trying to connect to a local address,
// which will allow dialers to use the default DialContext.
// If SocksProxyEnvVar is set, it will prioritize a connection made without a proxy but will fall back to a SOCKS proxy connection.
func SocksProxyFallbackDialContext(
addr string, logger utils.ZapCompatibleLogger,
) func(ctx context.Context, network, addr string) (net.Conn, error) {
// Use SOCKS proxy from environment as gRPC proxy dialer. Do not use SOCKS proxy if trying to connect to a local address.
localAddr := strings.HasPrefix(addr, "[::]") || strings.HasPrefix(addr, "localhost") || strings.HasPrefix(addr, "unix")
proxyAddr := os.Getenv(SocksProxyEnvVar)
if localAddr || proxyAddr == "" {
// return nil in these cases so that the default dialer gets used instead.
return nil
}

return func(ctx context.Context, network, addr string) (net.Conn, error) {
// if ONLY_SOCKS_PROXY specified, no need for a parallel dial - only dial through
// the SOCKS proxy directly.
if os.Getenv(OnlySocksProxyEnvVar) != "" {
logger.Infow("Both SOCKS_PROXY and ONLY_SOCKS_PROXY specified, only SOCKS proxy will be used for outgoing connection")
conn, err := socksProxyDialContext(ctx, network, proxyAddr, addr)
if err == nil {
logger.Infow("connected with SOCKS proxy")
}
return conn, err
}

// the block below heavily references https://go.dev/src/net/dial.go#L585
type dialResult struct {
net.Conn
error
primary bool
done bool
}
results := make(chan dialResult) // unbuffered

var primary, fallback dialResult
var wg sync.WaitGroup
defer wg.Wait()

// otherwise, do a parallel dial with a slight delay for the fallback option.
returned := make(chan struct{})
defer close(returned)

dialer := func(ctx context.Context, dialFunc func(context.Context) (net.Conn, error), primary bool) {
defer wg.Done()
conn, err := dialFunc(ctx)
select {
case results <- dialResult{Conn: conn, error: err, primary: primary, done: true}:
case <-returned:
if conn != nil {
utils.UncheckedError(conn.Close())
}
}
}

logger.Infow("SOCKS_PROXY specified, SOCKS proxy will be used as a fallback for outgoing connection")
// start the main dial attempt.
primaryCtx, primaryCancel := context.WithCancel(ctx)
defer primaryCancel()
wg.Add(1)
primaryDial := func(ctx context.Context) (net.Conn, error) {
// create a zero-valued net.Dialer to use net.Dialer's default DialContext method
var zeroDialer net.Dialer
return zeroDialer.DialContext(ctx, network, addr)
}
go dialer(primaryCtx, primaryDial, true)

// wait a small amount before starting the fallback dial (to prioritize the primary connection method).
fallbackTimer := time.NewTimer(300 * time.Millisecond)
defer fallbackTimer.Stop()

// fallbackCtx is defined here because this fails `go vet` otherwise. The intent is for fallbackCancel
// to be called as this function exits, which will cancel the ongoing SOCKS proxy if it is still running.
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
defer fallbackCancel()

// a for loop is used here so that we wait on both results and the fallback timer at the same time.
// if the timer expires, we should start the fallback dial and then wait for results.
// if the results channel receives a message, the message should be processed and either return
// or continue waiting (and reset the timer if it hasn't already expired).
for {
select {
case <-fallbackTimer.C:
wg.Add(1)
fallbackDial := func(ctx context.Context) (net.Conn, error) {
return socksProxyDialContext(ctx, network, proxyAddr, addr)
}
go dialer(fallbackCtx, fallbackDial, false)
case res := <-results:
if res.error == nil {
if res.primary {
logger.Infow("connected with ethernet/wifi")
} else {
logger.Infow("connected with SOCKS proxy")
}
return res.Conn, nil
}
if res.primary {
primary = res
} else {
fallback = res
}
// if both primary and fallback are done with errors, this means neither connection attempt succeeded.
// return the error from the primary dial attempt in that case.
if primary.done && fallback.done {
return nil, primary.error
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[q] Can you explain what's happening in this case? No connection can be returned even though they both completed?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if both primary and fallback are done with errors, this means neither connection attempt succeeded. return the error from the primary dial attempt in that case.

}
if res.primary && fallbackTimer.Stop() {
// If we were able to stop the timer, that means it
// was running (hadn't yet started the fallback), but
// we just got an error on the primary path, so start
// the fallback immediately (in 0 nanoseconds).
fallbackTimer.Reset(0)
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should there not be a

case <-ctx.Done():

here? Will this not spawn infinite SOCKS-proxy-dialing goroutines if neither WiFi not SOCKS proxy are available?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

once primary and fallback dial functions are done, we will exit because of L359. If ctx gets cancelled, the two dials will end and also end up at L359

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool; assume you mean L364 in the new line numbers.

}
}
}

// dialDirectGRPC dials a gRPC server directly.
func dialDirectGRPC(ctx context.Context, address string, dOpts dialOptions, logger utils.ZapCompatibleLogger) (ClientConn, bool, error) {
dialOpts := []grpc.DialOption{
Expand All @@ -251,20 +380,14 @@ func dialDirectGRPC(ctx context.Context, address string, dOpts dialOptions, logg
}),
}

// Use SOCKS proxy from environment as gRPC proxy dialer. Do not use
// if trying to connect to a local address.
if proxyAddr := os.Getenv(SocksProxyEnvVar); proxyAddr != "" &&
!(strings.HasPrefix(address, "[::]") || strings.HasPrefix(address, "localhost") ||
strings.HasPrefix(address, "unix")) {
dialer, err := proxy.SOCKS5("tcp", proxyAddr, nil, proxy.Direct)
if err != nil {
return nil, false, fmt.Errorf("error creating SOCKS proxy dialer to address %q from environment: %w",
proxyAddr, err)
}

dialOpts = append(dialOpts, grpc.WithContextDialer(func(_ context.Context, addr string) (net.Conn, error) {
logger.Info("behind SOCKS proxy; routing direct dial through proxy")
return dialer.Dial("tcp", addr)
// check if we should use a custom dialer that will use the SOCKS proxy as a fallback. Only attach a new context dialer
// if the returned function is not nil.
//
// use "tcp" since gRPC uses HTTP/2, which is built on top of TCP.
socksProxyDialContext := SocksProxyFallbackDialContext(address, logger)
if socksProxyDialContext != nil {
dialOpts = append(dialOpts, grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
return socksProxyDialContext(ctx, "tcp", addr)
}))
}

Expand Down
Loading