diff --git a/overlord/conn_server.go b/overlord/conn_server.go index 1e7bb49..323418b 100644 --- a/overlord/conn_server.go +++ b/overlord/conn_server.go @@ -194,7 +194,7 @@ func (c *ConnServer) handleOverlordRequest(obj interface{}) { c.SpawnFileServer(v.Sid, v.TerminalSid, v.Action, v.Filename, v.Dest, v.Perm, v.CheckOnly) case SpawnModeForwarderCmd: - c.SpawnModeForwarder(v.Sid, v.Port) + c.SpawnModeForwarder(v.Sid, v.Host, v.Port) } } @@ -459,9 +459,10 @@ func (c *ConnServer) SendClearToDownload() { // SpawnModeForwarder spawns a forwarder connection (a ghost with mode ModeForward). // sid is the session ID, which will be used as the session ID of the new ghost. -func (c *ConnServer) SpawnModeForwarder(sid string, port int) { +func (c *ConnServer) SpawnModeForwarder(sid string, host string, port int) { req := NewRequest("forward", map[string]interface{}{ "sid": sid, + "host": host, "port": port, }) c.SendRequest(req, c.getHandler("SpawnModeForwarder")) diff --git a/overlord/ghost.go b/overlord/ghost.go index 40c805f..8906a06 100644 --- a/overlord/ghost.go +++ b/overlord/ghost.go @@ -14,7 +14,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net" "net/http" @@ -163,7 +162,7 @@ func (t *tlsSettings) updateContext() { } if t.tlsCertFile != "" { log.Println("TLSSettings: using user-supplied ca-certificate") - cert, err := ioutil.ReadFile(t.tlsCertFile) + cert, err := os.ReadFile(t.tlsCertFile) if err != nil { log.Fatalln(err) } @@ -219,6 +218,7 @@ type Ghost struct { fileOp fileOperation // File operation name downloadQueue chan downloadInfo // Download queue uploadContext fileUploadContext // File upload context + host string // Host to forward port int // Port number to forward tlsMode int // TLS mode } @@ -295,8 +295,9 @@ func (ghost *Ghost) SetFileOp(operation, filename string, perm int) *Ghost { return ghost } -// SetModeForwardPort sets the port to forward. -func (ghost *Ghost) SetModeForwardPort(port int) *Ghost { +// SetForwardTarget sets the host and port to forward. +func (ghost *Ghost) SetForwardTarget(host string, port int) *Ghost { + ghost.host = host ghost.port = port return ghost } @@ -321,7 +322,7 @@ func (ghost *Ghost) loadProperties() { return } - bytes, err := ioutil.ReadFile(ghost.propFile) + bytes, err := os.ReadFile(ghost.propFile) if err != nil { log.Printf("loadProperties: %s\n", err) return @@ -623,6 +624,7 @@ func (ghost *Ghost) handleFileUploadRequest(req *Request) error { func (ghost *Ghost) handleModeForwardRequest(req *Request) error { type RequestParams struct { Sid string `json:"sid"` + Host string `json:"host"` Port int `json:"port"` } @@ -635,7 +637,7 @@ func (ghost *Ghost) handleModeForwardRequest(req *Request) error { log.Printf("Received forward command, ModeForward agent %s spawned\n", params.Sid) addrs := []string{ghost.connectedAddr} g := NewGhost(addrs, ghost.tls, ModeForward, RandomMID).SetSid( - params.Sid).SetModeForwardPort(params.Port) + params.Sid).SetForwardTarget(params.Host, params.Port) g.Start(false, false) }() @@ -958,7 +960,8 @@ func (ghost *Ghost) SpawnTTYServer(res *Response) error { } } -// SpawnShellServer spawns a Shell server and forward input/output from/to the TCP socket. +// SpawnShellServer spawns a Shell server and forward input/output from/to the +// TCP socket. func (ghost *Ghost) SpawnShellServer(res *Response) error { log.Println("SpawnShellServer: started") @@ -1023,9 +1026,9 @@ func (ghost *Ghost) SpawnShellServer(res *Response) error { process := (*PollableProcess)(cmd.Process) _, err = process.Poll() - // Check if the process is terminated. If not, send SIGlogcatTypeVT100 to the process, - // then wait for 1 second. Send another SIGKILL to make sure the process is - // terminated. + // Check if the process is terminated. If not, send SIGlogcatTypeVT100 to + // the process, then wait for 1 second. Send another SIGKILL to make sure + // the process is terminated. if err != nil { cmd.Process.Signal(syscall.SIGTERM) time.Sleep(time.Second) @@ -1090,10 +1093,10 @@ func (ghost *Ghost) InitiatefileOperation(res *Response) error { return errors.New("InitiatefileOperation: unknown file operation, ignored") } -// SpawnPortModeForwardServer spawns a port forwarding server and forward I/O to +// SpawnPortForwardServer spawns a port forwarding server and forward I/O to // the TCP socket. -func (ghost *Ghost) SpawnPortModeForwardServer(res *Response) error { - log.Println("SpawnPortModeForwardServer: started") +func (ghost *Ghost) SpawnPortForwardServer(res *Response) error { + log.Println("SpawnPortForwardServer: started") var err error @@ -1103,11 +1106,11 @@ func (ghost *Ghost) SpawnPortModeForwardServer(res *Response) error { ghost.Conn.Write([]byte(err.Error() + "\n")) } ghost.Conn.Close() - log.Println("SpawnPortModeForwardServer: terminated") + log.Println("SpawnPortForwardServer: terminated") }() - conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", ghost.port), - connectTimeout) + conn, err := net.DialTimeout("tcp", + fmt.Sprintf("%s:%d", ghost.host, ghost.port), connectTimeout) if err != nil { return err } @@ -1131,7 +1134,7 @@ func (ghost *Ghost) SpawnPortModeForwardServer(res *Response) error { conn.Write(buf) case err := <-ghost.readErrChan: if err == io.EOF { - log.Println("SpawnPortModeForwardServer: connection terminated") + log.Println("SpawnPortForwardServer: connection terminated") return nil } return err @@ -1236,7 +1239,7 @@ func (ghost *Ghost) Register() error { case ModeFile: handler = ghost.InitiatefileOperation case ModeForward: - handler = ghost.SpawnPortModeForwardServer + handler = ghost.SpawnPortForwardServer } err = ghost.SendRequest(req, handler) return nil @@ -1485,7 +1488,8 @@ func (ghost *Ghost) Start(lanDisc bool, RPCServer bool) { } } -// Returns a ghostRPCStub client object which can be used to call ghostRPCStub methods. +// Returns a ghostRPCStub client object which can be used to call ghostRPCStub +// methods. func ghostRPCStubServer() (*rpc.Client, error) { conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ghostRPCStubPort)) if err != nil { diff --git a/overlord/overlord.go b/overlord/overlord.go index 43cba77..82a19d9 100644 --- a/overlord/overlord.go +++ b/overlord/overlord.go @@ -9,7 +9,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net" "net/http" @@ -29,8 +28,9 @@ import ( ) const ( - systemAppDir = "../share/overlord" - ldInterval = 5 + defaultForwardHost = "127.0.0.1" + ldInterval = 5 + systemAppDir = "../share/overlord" ) // SpawnTerminalCmd is an overlord intend to launch a terminal. @@ -59,6 +59,7 @@ type SpawnFileCmd struct { // SpawnModeForwarderCmd is an overlord intend to perform port forwarding. type SpawnModeForwarderCmd struct { Sid string // Session ID + Host string // Host to forward Port int // Port to forward } @@ -377,7 +378,7 @@ func (ovl *Overlord) GetAppNames(ignoreSpecial bool) ([]string, error) { return false } - apps, err := ioutil.ReadDir(ovl.GetAppDir()) + apps, err := os.ReadDir(ovl.GetAppDir()) if err != nil { return nil, nil } @@ -850,10 +851,16 @@ func (ovl *Overlord) RegisterHTTPHandlers() { return } + var host string = defaultForwardHost var port int vars := mux.Vars(r) mid := vars["mid"] + // default thost to 127.0.0.1 if not specified + if _host, ok := r.URL.Query()["host"]; ok { + host = _host[0] + } + if _port, ok := r.URL.Query()["port"]; ok { if port, err = strconv.Atoi(_port[0]); err != nil { WebSocketSendError(conn, "invalid port") @@ -869,7 +876,7 @@ func (ovl *Overlord) RegisterHTTPHandlers() { ovl.agentsMu.Unlock() wc := newWebsocketContext(conn) ovl.AddWebsocketContext(wc) - agent.Command <- SpawnModeForwarderCmd{wc.Sid, port} + agent.Command <- SpawnModeForwarderCmd{wc.Sid, host, port} if res := <-agent.Response; res != "" { WebSocketSendError(conn, res) } diff --git a/overlord/sysutils_linux.go b/overlord/sysutils_linux.go index b737cf8..a0ec4bf 100644 --- a/overlord/sysutils_linux.go +++ b/overlord/sysutils_linux.go @@ -8,7 +8,6 @@ import ( "bufio" "encoding/hex" "fmt" - "io/ioutil" "os" "strings" @@ -62,7 +61,7 @@ func GetMachineID() (string, error) { } } - interfaces, err := ioutil.ReadDir("/sys/class/net") + interfaces, err := os.ReadDir("/sys/class/net") if err == nil { mid := "" for _, iface := range interfaces { diff --git a/scripts/ghost.py b/scripts/ghost.py index 58318d9..811d9bd 100755 --- a/scripts/ghost.py +++ b/scripts/ghost.py @@ -45,7 +45,7 @@ _BUFSIZE = 8192 _RETRY_INTERVAL = 2 _SEPARATOR = b'\r\n' -_PING_TIMEOUT = 3 +_PING_TIMEOUT = 10 _PING_INTERVAL = 5 _REQUEST_TIMEOUT_SECS = 60 _SHELL = os.getenv('SHELL', '/bin/sh') @@ -66,6 +66,8 @@ FAILED = 'failed' DISCONNECTED = 'disconnected' +_DEFAULT_FORWARD_HOST = '127.0.0.1' + class PingTimeoutError(Exception): pass @@ -198,7 +200,7 @@ class Ghost: def __init__(self, overlord_addrs, tls_settings=None, mode=AGENT, mid=None, sid=None, prop_file=None, terminal_sid=None, tty_device=None, - command=None, file_op=None, port=None, tls_mode=None): + command=None, file_op=None, host=None, port=None, tls_mode=None): """Constructor. Args: @@ -259,6 +261,7 @@ def __init__(self, overlord_addrs, tls_settings=None, mode=AGENT, mid=None, self._shell_command = command self._file_op = file_op self._download_queue = queue.Queue() + self._host = host self._port = port def SetIgnoreChild(self, status): @@ -390,7 +393,7 @@ def CloseSockets(self): pass def SpawnGhost(self, mode, sid=None, terminal_sid=None, tty_device=None, - command=None, file_op=None, port=None): + command=None, file_op=None, host=None, port=None): """Spawn a child ghost with specific mode. Returns: @@ -405,7 +408,7 @@ def SpawnGhost(self, mode, sid=None, terminal_sid=None, tty_device=None, g = Ghost([self._connected_addr], tls_settings=self._tls_settings, mode=mode, mid=Ghost.RANDOM_MID, sid=sid, terminal_sid=terminal_sid, tty_device=tty_device, - command=command, file_op=file_op, port=port) + command=command, file_op=file_op, host=host, port=port) g.Start() sys.exit(0) else: @@ -819,7 +822,7 @@ def SpawnPortForwardServer(self, unused_var): try: src_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) src_sock.settimeout(_CONNECT_TIMEOUT) - src_sock.connect(('127.0.0.1', self._port)) + src_sock.connect((self._host, self._port)) src_sock.send(self._sock.RecvBuf()) @@ -853,7 +856,7 @@ def timeout_handler(x): raise PingTimeoutError self._last_ping = self.Timestamp() - self.SendRequest('ping', {}, timeout_handler, 5) + self.SendRequest('ping', {}, timeout_handler, _PING_TIMEOUT) def HandleFileDownloadRequest(self, msg): params = msg['params'] @@ -939,7 +942,9 @@ def HandleRequest(self, msg): elif command == 'file_upload': self.HandleFileUploadRequest(msg) elif command == 'forward': - self.SpawnGhost(self.FORWARD, params['sid'], port=params['port']) + self.SpawnGhost(self.FORWARD, params['sid'], + host=params.get('host', _DEFAULT_FORWARD_HOST), + port=params['port']) self.SendResponse(msg, SUCCESS) def HandleResponse(self, response): diff --git a/scripts/ovl.py b/scripts/ovl.py index 58037c2..1a0d9ac 100755 --- a/scripts/ovl.py +++ b/scripts/ovl.py @@ -1236,7 +1236,7 @@ def _push_single_target(src, dst): # If dest_dir does not exist, the resulting directory structure should # be: # dest_dir/A - dst_root = root if dst_exists else root[len(src):].lstrip('/') + dst_root = os.path.basename(root) if dst_exists else '' for name in files: _push(os.path.join(root, name), os.path.join(dst, dst_root, name)) @@ -1354,18 +1354,18 @@ def _pull(src, dst, ftype, perm=0o644, link=None): help='remove port forwarding for local port LOCAL_PORT'), Arg('--remove-all', dest='remove_all', action='store_true', default=False, help='remove all port forwarding'), - Arg('remote', metavar='REMOTE_PORT', type=int, nargs='?'), - Arg('local', metavar='LOCAL_PORT', type=int, nargs='?')]) + Arg('remote', metavar='[HOST:]REMOTE_PORT', type=str, nargs='?'), + Arg('local_port', metavar='LOCAL_PORT', type=int, nargs='?')]) def Forward(self, args): if args.list_all: max_len = 10 if self._state.forwards: max_len = max([len(v[0]) for v in self._state.forwards.values()]) - print('%-*s %-8s %-8s' % (max_len, 'Client', 'Remote', 'Local')) + print('%-*s %-23s %-8s' % (max_len, 'Client', 'Remote', 'Local')) for local in sorted(self._state.forwards.keys()): value = self._state.forwards[local] - print('%-*s %-8s %-8s' % (max_len, value[0], value[1], local)) + print('%-*s %-23s %-8s' % (max_len, value[0], value[1], local)) return if args.remove_all: @@ -1379,12 +1379,24 @@ def Forward(self, args): self.CheckClient() if args.remote is None: - raise RuntimeError('remote port not specified') + raise RuntimeError('remote target not specified') - if args.local is None: - args.local = args.remote - remote = int(args.remote) - local = int(args.local) + remote_parts = args.remote.split(':') + if len(remote_parts) == 1: + try: + remote_port = int(remote_parts[0]) + except ValueError: + raise RuntimeError('invalid remote port') + elif len(remote_parts) == 2: + remote_host = remote_parts[0] + remote_port = int(remote_parts[1]) + else: + raise RuntimeError('invalid remote target') + + if args.local_port is None: + args.local_port = remote_port + + remote = remote_port def HandleConnection(conn): headers = [] @@ -1395,9 +1407,10 @@ def HandleConnection(conn): scheme = 'ws%s://' % ('s' if self._state.ssl else '') ws = ForwarderWebSocketClient( self._state, conn, - scheme + '%s:%d/api/agent/forward/%s?port=%d' % ( + scheme + '%s:%d/api/agent/forward/%s?host=%s&port=%d' % ( self._state.host, self._state.port, - urllib.parse.quote(self._selected_mid), remote), + urllib.parse.quote(self._selected_mid), + remote_host, remote_port), headers=headers) try: ws.connect() @@ -1409,7 +1422,7 @@ def HandleConnection(conn): server = socket.socket(socket.AF_INET, socket.SOCK_STREAM) server.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - server.bind(('0.0.0.0', local)) + server.bind(('0.0.0.0', args.local_port)) server.listen(5) pid = os.fork() @@ -1420,7 +1433,8 @@ def HandleConnection(conn): t.daemon = True t.start() else: - self._server.AddForward(self._selected_mid, remote, local, pid) + self._server.AddForward(self._selected_mid, args.remote, args.local_port, + pid) def main():