@@ -10,6 +10,7 @@ import (
10
10
"os"
11
11
"strings"
12
12
"sync"
13
+ "time"
13
14
14
15
"github.com/edaniels/golog"
15
16
grpc_middleware "github.com/grpc-ecosystem/go-grpc-middleware"
@@ -240,6 +241,134 @@ func DialDirectGRPC(ctx context.Context, address string, logger utils.ZapCompati
240
241
return dialInner (ctx , address , logger , dOpts )
241
242
}
242
243
244
+ func socksProxyDialContext (ctx context.Context , network , proxyAddr , addr string ) (net.Conn , error ) {
245
+ dialer , err := proxy .SOCKS5 (network , proxyAddr , nil , proxy .Direct )
246
+ if err != nil {
247
+ return nil , fmt .Errorf ("error creating SOCKS proxy dialer to address %q from environment: %w" ,
248
+ proxyAddr , err )
249
+ }
250
+ return dialer .(proxy.ContextDialer ).DialContext (ctx , network , addr )
251
+ }
252
+
253
+ // SocksProxyFallbackDialContext will return nil if SocksProxyEnvVar is not set or if trying to connect to a local address,
254
+ // which will allow dialers to use the default DialContext.
255
+ // If SocksProxyEnvVar is set, it will prioritize a connection made without a proxy but will fall back to a SOCKS proxy connection.
256
+ func SocksProxyFallbackDialContext (
257
+ addr string , logger utils.ZapCompatibleLogger ,
258
+ ) func (ctx context.Context , network , addr string ) (net.Conn , error ) {
259
+ // Use SOCKS proxy from environment as gRPC proxy dialer. Do not use SOCKS proxy if trying to connect to a local address.
260
+ localAddr := strings .HasPrefix (addr , "[::]" ) || strings .HasPrefix (addr , "localhost" ) || strings .HasPrefix (addr , "unix" )
261
+ proxyAddr := os .Getenv (SocksProxyEnvVar )
262
+ if localAddr || proxyAddr == "" {
263
+ // return nil in these cases so that the default dialer gets used instead.
264
+ return nil
265
+ }
266
+
267
+ return func (ctx context.Context , network , addr string ) (net.Conn , error ) {
268
+ // if ONLY_SOCKS_PROXY specified, no need for a parallel dial - only dial through
269
+ // the SOCKS proxy directly.
270
+ if os .Getenv (OnlySocksProxyEnvVar ) != "" {
271
+ logger .Infow ("Both SOCKS_PROXY and ONLY_SOCKS_PROXY specified, only SOCKS proxy will be used for outgoing connection" )
272
+ conn , err := socksProxyDialContext (ctx , network , proxyAddr , addr )
273
+ if err == nil {
274
+ logger .Infow ("connected with SOCKS proxy" )
275
+ }
276
+ return conn , err
277
+ }
278
+
279
+ // the block below heavily references https://go.dev/src/net/dial.go#L585
280
+ type dialResult struct {
281
+ net.Conn
282
+ error
283
+ primary bool
284
+ done bool
285
+ }
286
+ results := make (chan dialResult ) // unbuffered
287
+
288
+ var primary , fallback dialResult
289
+ var wg sync.WaitGroup
290
+ defer wg .Wait ()
291
+
292
+ // otherwise, do a parallel dial with a slight delay for the fallback option.
293
+ returned := make (chan struct {})
294
+ defer close (returned )
295
+
296
+ dialer := func (ctx context.Context , dialFunc func (context.Context ) (net.Conn , error ), primary bool ) {
297
+ defer wg .Done ()
298
+ conn , err := dialFunc (ctx )
299
+ select {
300
+ case results <- dialResult {Conn : conn , error : err , primary : primary , done : true }:
301
+ case <- returned :
302
+ if conn != nil {
303
+ utils .UncheckedError (conn .Close ())
304
+ }
305
+ }
306
+ }
307
+
308
+ logger .Infow ("SOCKS_PROXY specified, SOCKS proxy will be used as a fallback for outgoing connection" )
309
+ // start the main dial attempt.
310
+ primaryCtx , primaryCancel := context .WithCancel (ctx )
311
+ defer primaryCancel ()
312
+ wg .Add (1 )
313
+ primaryDial := func (ctx context.Context ) (net.Conn , error ) {
314
+ // create a zero-valued net.Dialer to use net.Dialer's default DialContext method
315
+ var zeroDialer net.Dialer
316
+ return zeroDialer .DialContext (ctx , network , addr )
317
+ }
318
+ go dialer (primaryCtx , primaryDial , true )
319
+
320
+ // wait a small amount before starting the fallback dial (to prioritize the primary connection method).
321
+ fallbackTimer := time .NewTimer (300 * time .Millisecond )
322
+ defer fallbackTimer .Stop ()
323
+
324
+ // fallbackCtx is defined here because this fails `go vet` otherwise. The intent is for fallbackCancel
325
+ // to be called as this function exits, which will cancel the ongoing SOCKS proxy if it is still running.
326
+ fallbackCtx , fallbackCancel := context .WithCancel (ctx )
327
+ defer fallbackCancel ()
328
+
329
+ // a for loop is used here so that we wait on both results and the fallback timer at the same time.
330
+ // if the timer expires, we should start the fallback dial and then wait for results.
331
+ // if the results channel receives a message, the message should be processed and either return
332
+ // or continue waiting (and reset the timer if it hasn't already expired).
333
+ for {
334
+ select {
335
+ case <- fallbackTimer .C :
336
+ wg .Add (1 )
337
+ fallbackDial := func (ctx context.Context ) (net.Conn , error ) {
338
+ return socksProxyDialContext (ctx , network , proxyAddr , addr )
339
+ }
340
+ go dialer (fallbackCtx , fallbackDial , false )
341
+ case res := <- results :
342
+ if res .error == nil {
343
+ if res .primary {
344
+ logger .Infow ("connected with ethernet/wifi" )
345
+ } else {
346
+ logger .Infow ("connected with SOCKS proxy" )
347
+ }
348
+ return res .Conn , nil
349
+ }
350
+ if res .primary {
351
+ primary = res
352
+ } else {
353
+ fallback = res
354
+ }
355
+ // if both primary and fallback are done with errors, this means neither connection attempt succeeded.
356
+ // return the error from the primary dial attempt in that case.
357
+ if primary .done && fallback .done {
358
+ return nil , primary .error
359
+ }
360
+ if res .primary && fallbackTimer .Stop () {
361
+ // If we were able to stop the timer, that means it
362
+ // was running (hadn't yet started the fallback), but
363
+ // we just got an error on the primary path, so start
364
+ // the fallback immediately (in 0 nanoseconds).
365
+ fallbackTimer .Reset (0 )
366
+ }
367
+ }
368
+ }
369
+ }
370
+ }
371
+
243
372
// dialDirectGRPC dials a gRPC server directly.
244
373
func dialDirectGRPC (ctx context.Context , address string , dOpts dialOptions , logger utils.ZapCompatibleLogger ) (ClientConn , bool , error ) {
245
374
dialOpts := []grpc.DialOption {
@@ -251,20 +380,14 @@ func dialDirectGRPC(ctx context.Context, address string, dOpts dialOptions, logg
251
380
}),
252
381
}
253
382
254
- // Use SOCKS proxy from environment as gRPC proxy dialer. Do not use
255
- // if trying to connect to a local address.
256
- if proxyAddr := os .Getenv (SocksProxyEnvVar ); proxyAddr != "" &&
257
- ! (strings .HasPrefix (address , "[::]" ) || strings .HasPrefix (address , "localhost" ) ||
258
- strings .HasPrefix (address , "unix" )) {
259
- dialer , err := proxy .SOCKS5 ("tcp" , proxyAddr , nil , proxy .Direct )
260
- if err != nil {
261
- return nil , false , fmt .Errorf ("error creating SOCKS proxy dialer to address %q from environment: %w" ,
262
- proxyAddr , err )
263
- }
264
-
265
- dialOpts = append (dialOpts , grpc .WithContextDialer (func (_ context.Context , addr string ) (net.Conn , error ) {
266
- logger .Info ("behind SOCKS proxy; routing direct dial through proxy" )
267
- return dialer .Dial ("tcp" , addr )
383
+ // check if we should use a custom dialer that will use the SOCKS proxy as a fallback. Only attach a new context dialer
384
+ // if the returned function is not nil.
385
+ //
386
+ // use "tcp" since gRPC uses HTTP/2, which is built on top of TCP.
387
+ socksProxyDialContext := SocksProxyFallbackDialContext (address , logger )
388
+ if socksProxyDialContext != nil {
389
+ dialOpts = append (dialOpts , grpc .WithContextDialer (func (ctx context.Context , addr string ) (net.Conn , error ) {
390
+ return socksProxyDialContext (ctx , "tcp" , addr )
268
391
}))
269
392
}
270
393
0 commit comments