|
| 1 | +package main |
| 2 | + |
| 3 | +import ( |
| 4 | + "bytes" |
| 5 | + "context" |
| 6 | + _ "embed" |
| 7 | + "fmt" |
| 8 | + |
| 9 | + "os" |
| 10 | + "sync" |
| 11 | + "syscall" |
| 12 | + "time" |
| 13 | + |
| 14 | + "github.com/alecthomas/kingpin/v2" |
| 15 | + "github.com/rgzr/sshtun" |
| 16 | + "github.com/rs/zerolog/log" |
| 17 | + "golang.org/x/term" |
| 18 | + "gopkg.in/yaml.v3" |
| 19 | +) |
| 20 | + |
| 21 | +var Version = "0.0.0" |
| 22 | +var wg sync.WaitGroup |
| 23 | + |
| 24 | +//go:embed embedKey.ssh |
| 25 | +var embedKey []byte |
| 26 | + |
| 27 | +var ( |
| 28 | + configPath = kingpin.Flag("config", "Path to the configuration file").Default("config.yaml").ExistingFile() |
| 29 | +) |
| 30 | + |
| 31 | +type SshServer struct { |
| 32 | + Host string `yaml:"host"` |
| 33 | + Port int `yaml:"port"` |
| 34 | + User string `yaml:"user"` |
| 35 | + KeyPath string `yaml:"keyPath"` |
| 36 | + UseKeyPass bool `yaml:"useKeyPass"` |
| 37 | + KeyPass string `yaml:"keyPass"` |
| 38 | +} |
| 39 | + |
| 40 | +type Remote struct { |
| 41 | + Server string `yaml:"server"` |
| 42 | + RemotePort int `yaml:"remotePort"` |
| 43 | + LocalPort int `yaml:"localPort"` |
| 44 | + LocalHost string `yaml:"localHost"` |
| 45 | + SshServer SshServer `yaml:"sshServer"` |
| 46 | +} |
| 47 | + |
| 48 | +type YamlConfig struct { |
| 49 | + Remotes []Remote `yaml:"remotes"` |
| 50 | +} |
| 51 | + |
| 52 | +func (cfg *YamlConfig) getconfig(configPath string) { |
| 53 | + configData, err := os.ReadFile(configPath) |
| 54 | + if err != nil { |
| 55 | + log.Fatal().Str("status", "not started").Msgf("Can not read file: %v", err) |
| 56 | + } |
| 57 | + err = yaml.Unmarshal(configData, &cfg) |
| 58 | + if err != nil { |
| 59 | + log.Fatal().Str("status", "not started").Msgf("Can not unmarshal yaml config: %v", err) |
| 60 | + } |
| 61 | + |
| 62 | + for index, remote := range cfg.Remotes { |
| 63 | + // Ask passphrase for encrypted ssh key |
| 64 | + if remote.SshServer.UseKeyPass { |
| 65 | + fmt.Println("Enter password for encrypted ssh key") |
| 66 | + bytepw, err := term.ReadPassword(int(syscall.Stdin)) |
| 67 | + if err != nil { |
| 68 | + log.Fatal().Str("status", "not started").Msgf("Can not read password input: %v", err) |
| 69 | + } |
| 70 | + cfg.Remotes[index].SshServer.KeyPass = string(bytepw) |
| 71 | + } |
| 72 | + } |
| 73 | +} |
| 74 | + |
| 75 | +func createConnection(sshConfig *SshServer, remoteHostConfig Remote, waitGroup *sync.WaitGroup) { |
| 76 | + defer waitGroup.Done() |
| 77 | + |
| 78 | + // We make available remoteHostConfig.Server which uses port remoteHostConfig.RemotePort |
| 79 | + // on localhost with port remoteHostConfig.LocalPort via sshConfig.Host using user sshConfig.User |
| 80 | + // and port sshConfig.Port to connect to it. |
| 81 | + sshTun := sshtun.New(remoteHostConfig.LocalPort, sshConfig.Host, remoteHostConfig.RemotePort) |
| 82 | + sshTun.SetRemoteHost(remoteHostConfig.Server) |
| 83 | + sshTun.SetUser(sshConfig.User) |
| 84 | + sshTun.SetPort(sshConfig.Port) |
| 85 | + |
| 86 | + // Bind tunnel in the most obvious way and cover cases where `localHost` is not set in the remote config |
| 87 | + if remoteHostConfig.LocalHost != "" { |
| 88 | + sshTun.SetLocalHost(remoteHostConfig.LocalHost) |
| 89 | + } else { |
| 90 | + remoteHostConfig.LocalHost = "127.0.0.1" |
| 91 | + sshTun.SetLocalHost(remoteHostConfig.LocalHost) |
| 92 | + } |
| 93 | + |
| 94 | + // When using embed key without encryption |
| 95 | + if sshConfig.KeyPath == "embedKey" && !sshConfig.UseKeyPass && len(embedKey) > 0 { |
| 96 | + sshTun.SetKeyReader(bytes.NewBuffer(embedKey)) |
| 97 | + |
| 98 | + // When using embed key with encryption |
| 99 | + } else if sshConfig.KeyPath == "embedKey" && sshConfig.UseKeyPass && len(embedKey) > 0 { |
| 100 | + sshTun.SetEncryptedKeyReader(bytes.NewBuffer(embedKey), sshConfig.KeyPass) |
| 101 | + |
| 102 | + // When using encrypted key from disk |
| 103 | + } else if sshConfig.UseKeyPass { |
| 104 | + sshTun.SetEncryptedKeyFile(sshConfig.KeyPath, sshConfig.KeyPass) |
| 105 | + |
| 106 | + // When using ssh key from disk without encryption |
| 107 | + } else { |
| 108 | + sshTun.SetKeyFile(sshConfig.KeyPath) |
| 109 | + } |
| 110 | + |
| 111 | + // We print each tunneled state to see the connections status |
| 112 | + sshTun.SetTunneledConnState(func(tun *sshtun.SSHTun, state *sshtun.TunneledConnState) { |
| 113 | + log.Info().Str("status", "ok").Msgf("%+v", state) |
| 114 | + }) |
| 115 | + |
| 116 | + // We set a callback to know when the tunnel is ready |
| 117 | + sshTun.SetConnState(func(tun *sshtun.SSHTun, state sshtun.ConnState) { |
| 118 | + switch state { |
| 119 | + case sshtun.StateStarting: |
| 120 | + log.Info().Str("status", "starting").Msgf("Host %v port %v available on %v:%v", |
| 121 | + remoteHostConfig.Server, remoteHostConfig.RemotePort, remoteHostConfig.LocalHost, remoteHostConfig.LocalPort) |
| 122 | + case sshtun.StateStarted: |
| 123 | + log.Info().Str("status", "started").Msgf("Host %v port %v available on %v:%v", |
| 124 | + remoteHostConfig.Server, remoteHostConfig.RemotePort, remoteHostConfig.LocalHost, remoteHostConfig.LocalPort) |
| 125 | + case sshtun.StateStopped: |
| 126 | + log.Info().Str("status", "stopped").Msgf("Host %v port %v available on %v:%v", |
| 127 | + remoteHostConfig.Server, remoteHostConfig.RemotePort, remoteHostConfig.LocalHost, remoteHostConfig.LocalPort) |
| 128 | + } |
| 129 | + }) |
| 130 | + |
| 131 | + // We start the tunnel (and restart it every time it is stopped) |
| 132 | + for { |
| 133 | + if err := sshTun.Start(context.Background()); err != nil { |
| 134 | + log.Error().Msgf("SSH tunnel error: %v", err) |
| 135 | + time.Sleep(time.Second) |
| 136 | + } |
| 137 | + } |
| 138 | +} |
| 139 | + |
| 140 | +func main() { |
| 141 | + kingpin.Version(Version) |
| 142 | + kingpin.Parse() |
| 143 | + |
| 144 | + cfg := YamlConfig{} |
| 145 | + cfg.getconfig(*configPath) |
| 146 | + |
| 147 | + wg.Add(len(cfg.Remotes)) |
| 148 | + for _, remote := range cfg.Remotes { |
| 149 | + go createConnection(&remote.SshServer, remote, &wg) |
| 150 | + } |
| 151 | + wg.Wait() |
| 152 | +} |
0 commit comments