Skip to content

Commit 447cd0d

Browse files
cheuktbenjirewis
andauthored
RSDK-9990- Expose SOCKS proxy as fallback dialer (#414)
Co-authored-by: Benjamin Rewis <[email protected]>
1 parent b46642e commit 447cd0d

File tree

2 files changed

+141
-14
lines changed

2 files changed

+141
-14
lines changed

rpc/const.go

+4
Original file line numberDiff line numberDiff line change
@@ -13,4 +13,8 @@ var (
1313
// proxies to indicate the address through which to route all network traffic
1414
// via SOCKS5.
1515
SocksProxyEnvVar = "SOCKS_PROXY"
16+
17+
// OnlySocksProxyEnvVar is the name of an environment variable used if all network
18+
// traffic should be done through SOCKS5.
19+
OnlySocksProxyEnvVar = "ONLY_SOCKS_PROXY"
1620
)

rpc/dialer.go

+137-14
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"os"
1111
"strings"
1212
"sync"
13+
"time"
1314

1415
"github.com/edaniels/golog"
1516
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
@@ -240,6 +241,134 @@ func DialDirectGRPC(ctx context.Context, address string, logger utils.ZapCompati
240241
return dialInner(ctx, address, logger, dOpts)
241242
}
242243

244+
func socksProxyDialContext(ctx context.Context, network, proxyAddr, addr string) (net.Conn, error) {
245+
dialer, err := proxy.SOCKS5(network, proxyAddr, nil, proxy.Direct)
246+
if err != nil {
247+
return nil, fmt.Errorf("error creating SOCKS proxy dialer to address %q from environment: %w",
248+
proxyAddr, err)
249+
}
250+
return dialer.(proxy.ContextDialer).DialContext(ctx, network, addr)
251+
}
252+
253+
// SocksProxyFallbackDialContext will return nil if SocksProxyEnvVar is not set or if trying to connect to a local address,
254+
// which will allow dialers to use the default DialContext.
255+
// If SocksProxyEnvVar is set, it will prioritize a connection made without a proxy but will fall back to a SOCKS proxy connection.
256+
func SocksProxyFallbackDialContext(
257+
addr string, logger utils.ZapCompatibleLogger,
258+
) func(ctx context.Context, network, addr string) (net.Conn, error) {
259+
// Use SOCKS proxy from environment as gRPC proxy dialer. Do not use SOCKS proxy if trying to connect to a local address.
260+
localAddr := strings.HasPrefix(addr, "[::]") || strings.HasPrefix(addr, "localhost") || strings.HasPrefix(addr, "unix")
261+
proxyAddr := os.Getenv(SocksProxyEnvVar)
262+
if localAddr || proxyAddr == "" {
263+
// return nil in these cases so that the default dialer gets used instead.
264+
return nil
265+
}
266+
267+
return func(ctx context.Context, network, addr string) (net.Conn, error) {
268+
// if ONLY_SOCKS_PROXY specified, no need for a parallel dial - only dial through
269+
// the SOCKS proxy directly.
270+
if os.Getenv(OnlySocksProxyEnvVar) != "" {
271+
logger.Infow("Both SOCKS_PROXY and ONLY_SOCKS_PROXY specified, only SOCKS proxy will be used for outgoing connection")
272+
conn, err := socksProxyDialContext(ctx, network, proxyAddr, addr)
273+
if err == nil {
274+
logger.Infow("connected with SOCKS proxy")
275+
}
276+
return conn, err
277+
}
278+
279+
// the block below heavily references https://go.dev/src/net/dial.go#L585
280+
type dialResult struct {
281+
net.Conn
282+
error
283+
primary bool
284+
done bool
285+
}
286+
results := make(chan dialResult) // unbuffered
287+
288+
var primary, fallback dialResult
289+
var wg sync.WaitGroup
290+
defer wg.Wait()
291+
292+
// otherwise, do a parallel dial with a slight delay for the fallback option.
293+
returned := make(chan struct{})
294+
defer close(returned)
295+
296+
dialer := func(ctx context.Context, dialFunc func(context.Context) (net.Conn, error), primary bool) {
297+
defer wg.Done()
298+
conn, err := dialFunc(ctx)
299+
select {
300+
case results <- dialResult{Conn: conn, error: err, primary: primary, done: true}:
301+
case <-returned:
302+
if conn != nil {
303+
utils.UncheckedError(conn.Close())
304+
}
305+
}
306+
}
307+
308+
logger.Infow("SOCKS_PROXY specified, SOCKS proxy will be used as a fallback for outgoing connection")
309+
// start the main dial attempt.
310+
primaryCtx, primaryCancel := context.WithCancel(ctx)
311+
defer primaryCancel()
312+
wg.Add(1)
313+
primaryDial := func(ctx context.Context) (net.Conn, error) {
314+
// create a zero-valued net.Dialer to use net.Dialer's default DialContext method
315+
var zeroDialer net.Dialer
316+
return zeroDialer.DialContext(ctx, network, addr)
317+
}
318+
go dialer(primaryCtx, primaryDial, true)
319+
320+
// wait a small amount before starting the fallback dial (to prioritize the primary connection method).
321+
fallbackTimer := time.NewTimer(300 * time.Millisecond)
322+
defer fallbackTimer.Stop()
323+
324+
// fallbackCtx is defined here because this fails `go vet` otherwise. The intent is for fallbackCancel
325+
// to be called as this function exits, which will cancel the ongoing SOCKS proxy if it is still running.
326+
fallbackCtx, fallbackCancel := context.WithCancel(ctx)
327+
defer fallbackCancel()
328+
329+
// a for loop is used here so that we wait on both results and the fallback timer at the same time.
330+
// if the timer expires, we should start the fallback dial and then wait for results.
331+
// if the results channel receives a message, the message should be processed and either return
332+
// or continue waiting (and reset the timer if it hasn't already expired).
333+
for {
334+
select {
335+
case <-fallbackTimer.C:
336+
wg.Add(1)
337+
fallbackDial := func(ctx context.Context) (net.Conn, error) {
338+
return socksProxyDialContext(ctx, network, proxyAddr, addr)
339+
}
340+
go dialer(fallbackCtx, fallbackDial, false)
341+
case res := <-results:
342+
if res.error == nil {
343+
if res.primary {
344+
logger.Infow("connected with ethernet/wifi")
345+
} else {
346+
logger.Infow("connected with SOCKS proxy")
347+
}
348+
return res.Conn, nil
349+
}
350+
if res.primary {
351+
primary = res
352+
} else {
353+
fallback = res
354+
}
355+
// if both primary and fallback are done with errors, this means neither connection attempt succeeded.
356+
// return the error from the primary dial attempt in that case.
357+
if primary.done && fallback.done {
358+
return nil, primary.error
359+
}
360+
if res.primary && fallbackTimer.Stop() {
361+
// If we were able to stop the timer, that means it
362+
// was running (hadn't yet started the fallback), but
363+
// we just got an error on the primary path, so start
364+
// the fallback immediately (in 0 nanoseconds).
365+
fallbackTimer.Reset(0)
366+
}
367+
}
368+
}
369+
}
370+
}
371+
243372
// dialDirectGRPC dials a gRPC server directly.
244373
func dialDirectGRPC(ctx context.Context, address string, dOpts dialOptions, logger utils.ZapCompatibleLogger) (ClientConn, bool, error) {
245374
dialOpts := []grpc.DialOption{
@@ -251,20 +380,14 @@ func dialDirectGRPC(ctx context.Context, address string, dOpts dialOptions, logg
251380
}),
252381
}
253382

254-
// Use SOCKS proxy from environment as gRPC proxy dialer. Do not use
255-
// if trying to connect to a local address.
256-
if proxyAddr := os.Getenv(SocksProxyEnvVar); proxyAddr != "" &&
257-
!(strings.HasPrefix(address, "[::]") || strings.HasPrefix(address, "localhost") ||
258-
strings.HasPrefix(address, "unix")) {
259-
dialer, err := proxy.SOCKS5("tcp", proxyAddr, nil, proxy.Direct)
260-
if err != nil {
261-
return nil, false, fmt.Errorf("error creating SOCKS proxy dialer to address %q from environment: %w",
262-
proxyAddr, err)
263-
}
264-
265-
dialOpts = append(dialOpts, grpc.WithContextDialer(func(_ context.Context, addr string) (net.Conn, error) {
266-
logger.Info("behind SOCKS proxy; routing direct dial through proxy")
267-
return dialer.Dial("tcp", addr)
383+
// check if we should use a custom dialer that will use the SOCKS proxy as a fallback. Only attach a new context dialer
384+
// if the returned function is not nil.
385+
//
386+
// use "tcp" since gRPC uses HTTP/2, which is built on top of TCP.
387+
socksProxyDialContext := SocksProxyFallbackDialContext(address, logger)
388+
if socksProxyDialContext != nil {
389+
dialOpts = append(dialOpts, grpc.WithContextDialer(func(ctx context.Context, addr string) (net.Conn, error) {
390+
return socksProxyDialContext(ctx, "tcp", addr)
268391
}))
269392
}
270393

0 commit comments

Comments
 (0)