Skip to content

Commit

Permalink
overlord: allow specifying host target to forward
Browse files Browse the repository at this point in the history
  • Loading branch information
aitjcize committed May 11, 2024
1 parent 3611689 commit aedbb54
Show file tree
Hide file tree
Showing 6 changed files with 79 additions and 49 deletions.
5 changes: 3 additions & 2 deletions overlord/conn_server.go
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ func (c *ConnServer) handleOverlordRequest(obj interface{}) {
c.SpawnFileServer(v.Sid, v.TerminalSid, v.Action, v.Filename, v.Dest,
v.Perm, v.CheckOnly)
case SpawnModeForwarderCmd:
c.SpawnModeForwarder(v.Sid, v.Port)
c.SpawnModeForwarder(v.Sid, v.Host, v.Port)
}
}

Expand Down Expand Up @@ -459,9 +459,10 @@ func (c *ConnServer) SendClearToDownload() {

// SpawnModeForwarder spawns a forwarder connection (a ghost with mode ModeForward).
// sid is the session ID, which will be used as the session ID of the new ghost.
func (c *ConnServer) SpawnModeForwarder(sid string, port int) {
func (c *ConnServer) SpawnModeForwarder(sid string, host string, port int) {
req := NewRequest("forward", map[string]interface{}{
"sid": sid,
"host": host,
"port": port,
})
c.SendRequest(req, c.getHandler("SpawnModeForwarder"))
Expand Down
42 changes: 23 additions & 19 deletions overlord/ghost.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
Expand Down Expand Up @@ -163,7 +162,7 @@ func (t *tlsSettings) updateContext() {
}
if t.tlsCertFile != "" {
log.Println("TLSSettings: using user-supplied ca-certificate")
cert, err := ioutil.ReadFile(t.tlsCertFile)
cert, err := os.ReadFile(t.tlsCertFile)
if err != nil {
log.Fatalln(err)
}
Expand Down Expand Up @@ -219,6 +218,7 @@ type Ghost struct {
fileOp fileOperation // File operation name
downloadQueue chan downloadInfo // Download queue
uploadContext fileUploadContext // File upload context
host string // Host to forward
port int // Port number to forward
tlsMode int // TLS mode
}
Expand Down Expand Up @@ -295,8 +295,9 @@ func (ghost *Ghost) SetFileOp(operation, filename string, perm int) *Ghost {
return ghost
}

// SetModeForwardPort sets the port to forward.
func (ghost *Ghost) SetModeForwardPort(port int) *Ghost {
// SetForwardTarget sets the host and port to forward.
func (ghost *Ghost) SetForwardTarget(host string, port int) *Ghost {
ghost.host = host
ghost.port = port
return ghost
}
Expand All @@ -321,7 +322,7 @@ func (ghost *Ghost) loadProperties() {
return
}

bytes, err := ioutil.ReadFile(ghost.propFile)
bytes, err := os.ReadFile(ghost.propFile)
if err != nil {
log.Printf("loadProperties: %s\n", err)
return
Expand Down Expand Up @@ -623,6 +624,7 @@ func (ghost *Ghost) handleFileUploadRequest(req *Request) error {
func (ghost *Ghost) handleModeForwardRequest(req *Request) error {
type RequestParams struct {
Sid string `json:"sid"`
Host string `json:"host"`
Port int `json:"port"`
}

Expand All @@ -635,7 +637,7 @@ func (ghost *Ghost) handleModeForwardRequest(req *Request) error {
log.Printf("Received forward command, ModeForward agent %s spawned\n", params.Sid)
addrs := []string{ghost.connectedAddr}
g := NewGhost(addrs, ghost.tls, ModeForward, RandomMID).SetSid(
params.Sid).SetModeForwardPort(params.Port)
params.Sid).SetForwardTarget(params.Host, params.Port)
g.Start(false, false)
}()

Expand Down Expand Up @@ -958,7 +960,8 @@ func (ghost *Ghost) SpawnTTYServer(res *Response) error {
}
}

// SpawnShellServer spawns a Shell server and forward input/output from/to the TCP socket.
// SpawnShellServer spawns a Shell server and forward input/output from/to the
// TCP socket.
func (ghost *Ghost) SpawnShellServer(res *Response) error {
log.Println("SpawnShellServer: started")

Expand Down Expand Up @@ -1023,9 +1026,9 @@ func (ghost *Ghost) SpawnShellServer(res *Response) error {

process := (*PollableProcess)(cmd.Process)
_, err = process.Poll()
// Check if the process is terminated. If not, send SIGlogcatTypeVT100 to the process,
// then wait for 1 second. Send another SIGKILL to make sure the process is
// terminated.
// Check if the process is terminated. If not, send SIGlogcatTypeVT100 to
// the process, then wait for 1 second. Send another SIGKILL to make sure
// the process is terminated.
if err != nil {
cmd.Process.Signal(syscall.SIGTERM)
time.Sleep(time.Second)
Expand Down Expand Up @@ -1090,10 +1093,10 @@ func (ghost *Ghost) InitiatefileOperation(res *Response) error {
return errors.New("InitiatefileOperation: unknown file operation, ignored")
}

// SpawnPortModeForwardServer spawns a port forwarding server and forward I/O to
// SpawnPortForwardServer spawns a port forwarding server and forward I/O to
// the TCP socket.
func (ghost *Ghost) SpawnPortModeForwardServer(res *Response) error {
log.Println("SpawnPortModeForwardServer: started")
func (ghost *Ghost) SpawnPortForwardServer(res *Response) error {
log.Println("SpawnPortForwardServer: started")

var err error

Expand All @@ -1103,11 +1106,11 @@ func (ghost *Ghost) SpawnPortModeForwardServer(res *Response) error {
ghost.Conn.Write([]byte(err.Error() + "\n"))
}
ghost.Conn.Close()
log.Println("SpawnPortModeForwardServer: terminated")
log.Println("SpawnPortForwardServer: terminated")
}()

conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", ghost.port),
connectTimeout)
conn, err := net.DialTimeout("tcp",
fmt.Sprintf("%s:%d", ghost.host, ghost.port), connectTimeout)
if err != nil {
return err
}
Expand All @@ -1131,7 +1134,7 @@ func (ghost *Ghost) SpawnPortModeForwardServer(res *Response) error {
conn.Write(buf)
case err := <-ghost.readErrChan:
if err == io.EOF {
log.Println("SpawnPortModeForwardServer: connection terminated")
log.Println("SpawnPortForwardServer: connection terminated")
return nil
}
return err
Expand Down Expand Up @@ -1236,7 +1239,7 @@ func (ghost *Ghost) Register() error {
case ModeFile:
handler = ghost.InitiatefileOperation
case ModeForward:
handler = ghost.SpawnPortModeForwardServer
handler = ghost.SpawnPortForwardServer
}
err = ghost.SendRequest(req, handler)
return nil
Expand Down Expand Up @@ -1485,7 +1488,8 @@ func (ghost *Ghost) Start(lanDisc bool, RPCServer bool) {
}
}

// Returns a ghostRPCStub client object which can be used to call ghostRPCStub methods.
// Returns a ghostRPCStub client object which can be used to call ghostRPCStub
// methods.
func ghostRPCStubServer() (*rpc.Client, error) {
conn, err := net.Dial("tcp", fmt.Sprintf("127.0.0.1:%d", ghostRPCStubPort))
if err != nil {
Expand Down
17 changes: 12 additions & 5 deletions overlord/overlord.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"net/http"
Expand All @@ -29,8 +28,9 @@ import (
)

const (
systemAppDir = "../share/overlord"
ldInterval = 5
defaultForwardHost = "127.0.0.1"
ldInterval = 5
systemAppDir = "../share/overlord"
)

// SpawnTerminalCmd is an overlord intend to launch a terminal.
Expand Down Expand Up @@ -59,6 +59,7 @@ type SpawnFileCmd struct {
// SpawnModeForwarderCmd is an overlord intend to perform port forwarding.
type SpawnModeForwarderCmd struct {
Sid string // Session ID
Host string // Host to forward
Port int // Port to forward
}

Expand Down Expand Up @@ -377,7 +378,7 @@ func (ovl *Overlord) GetAppNames(ignoreSpecial bool) ([]string, error) {
return false
}

apps, err := ioutil.ReadDir(ovl.GetAppDir())
apps, err := os.ReadDir(ovl.GetAppDir())
if err != nil {
return nil, nil
}
Expand Down Expand Up @@ -850,10 +851,16 @@ func (ovl *Overlord) RegisterHTTPHandlers() {
return
}

var host string = defaultForwardHost
var port int

vars := mux.Vars(r)
mid := vars["mid"]
// default thost to 127.0.0.1 if not specified
if _host, ok := r.URL.Query()["host"]; ok {
host = _host[0]
}

if _port, ok := r.URL.Query()["port"]; ok {
if port, err = strconv.Atoi(_port[0]); err != nil {
WebSocketSendError(conn, "invalid port")
Expand All @@ -869,7 +876,7 @@ func (ovl *Overlord) RegisterHTTPHandlers() {
ovl.agentsMu.Unlock()
wc := newWebsocketContext(conn)
ovl.AddWebsocketContext(wc)
agent.Command <- SpawnModeForwarderCmd{wc.Sid, port}
agent.Command <- SpawnModeForwarderCmd{wc.Sid, host, port}
if res := <-agent.Response; res != "" {
WebSocketSendError(conn, res)
}
Expand Down
3 changes: 1 addition & 2 deletions overlord/sysutils_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"bufio"
"encoding/hex"
"fmt"
"io/ioutil"
"os"
"strings"

Expand Down Expand Up @@ -62,7 +61,7 @@ func GetMachineID() (string, error) {
}
}

interfaces, err := ioutil.ReadDir("/sys/class/net")
interfaces, err := os.ReadDir("/sys/class/net")
if err == nil {
mid := ""
for _, iface := range interfaces {
Expand Down
19 changes: 12 additions & 7 deletions scripts/ghost.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@
_BUFSIZE = 8192
_RETRY_INTERVAL = 2
_SEPARATOR = b'\r\n'
_PING_TIMEOUT = 3
_PING_TIMEOUT = 10
_PING_INTERVAL = 5
_REQUEST_TIMEOUT_SECS = 60
_SHELL = os.getenv('SHELL', '/bin/sh')
Expand All @@ -66,6 +66,8 @@
FAILED = 'failed'
DISCONNECTED = 'disconnected'

_DEFAULT_FORWARD_HOST = '127.0.0.1'


class PingTimeoutError(Exception):
pass
Expand Down Expand Up @@ -198,7 +200,7 @@ class Ghost:

def __init__(self, overlord_addrs, tls_settings=None, mode=AGENT, mid=None,
sid=None, prop_file=None, terminal_sid=None, tty_device=None,
command=None, file_op=None, port=None, tls_mode=None):
command=None, file_op=None, host=None, port=None, tls_mode=None):
"""Constructor.
Args:
Expand Down Expand Up @@ -259,6 +261,7 @@ def __init__(self, overlord_addrs, tls_settings=None, mode=AGENT, mid=None,
self._shell_command = command
self._file_op = file_op
self._download_queue = queue.Queue()
self._host = host
self._port = port

def SetIgnoreChild(self, status):
Expand Down Expand Up @@ -390,7 +393,7 @@ def CloseSockets(self):
pass

def SpawnGhost(self, mode, sid=None, terminal_sid=None, tty_device=None,
command=None, file_op=None, port=None):
command=None, file_op=None, host=None, port=None):
"""Spawn a child ghost with specific mode.
Returns:
Expand All @@ -405,7 +408,7 @@ def SpawnGhost(self, mode, sid=None, terminal_sid=None, tty_device=None,
g = Ghost([self._connected_addr], tls_settings=self._tls_settings,
mode=mode, mid=Ghost.RANDOM_MID, sid=sid,
terminal_sid=terminal_sid, tty_device=tty_device,
command=command, file_op=file_op, port=port)
command=command, file_op=file_op, host=host, port=port)
g.Start()
sys.exit(0)
else:
Expand Down Expand Up @@ -819,7 +822,7 @@ def SpawnPortForwardServer(self, unused_var):
try:
src_sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
src_sock.settimeout(_CONNECT_TIMEOUT)
src_sock.connect(('127.0.0.1', self._port))
src_sock.connect((self._host, self._port))

src_sock.send(self._sock.RecvBuf())

Expand Down Expand Up @@ -853,7 +856,7 @@ def timeout_handler(x):
raise PingTimeoutError

self._last_ping = self.Timestamp()
self.SendRequest('ping', {}, timeout_handler, 5)
self.SendRequest('ping', {}, timeout_handler, _PING_TIMEOUT)

def HandleFileDownloadRequest(self, msg):
params = msg['params']
Expand Down Expand Up @@ -939,7 +942,9 @@ def HandleRequest(self, msg):
elif command == 'file_upload':
self.HandleFileUploadRequest(msg)
elif command == 'forward':
self.SpawnGhost(self.FORWARD, params['sid'], port=params['port'])
self.SpawnGhost(self.FORWARD, params['sid'],
host=params.get('host', _DEFAULT_FORWARD_HOST),
port=params['port'])
self.SendResponse(msg, SUCCESS)

def HandleResponse(self, response):
Expand Down
Loading

0 comments on commit aedbb54

Please sign in to comment.