11import base64
22import logging
3- import shlex
43import subprocess
54
65import lib .config as config
7- from lib .common import HostAddress
86
97from typing import List , Literal , Union , overload
108
@@ -17,14 +15,6 @@ def __init__(self, returncode, stdout, cmd, exception_msg):
1715 self .stdout = stdout
1816 self .cmd = cmd
1917
20- class SSHCommandFailed (BaseCommandFailed ):
21- def __init__ (self , returncode , stdout , cmd ):
22- msg_end = f": { stdout } " if stdout else "."
23- super (SSHCommandFailed , self ).__init__ (
24- returncode , stdout , cmd ,
25- f'SSH command ({ cmd } ) failed with return code { returncode } { msg_end } '
26- )
27-
2818class LocalCommandFailed (BaseCommandFailed ):
2919 def __init__ (self , returncode , stdout , cmd ):
3020 msg_end = f": { stdout } " if stdout else "."
@@ -40,10 +30,6 @@ def __init__(self, returncode, stdout):
4030 self .returncode = returncode
4131 self .stdout = stdout
4232
43- class SSHResult (BaseCmdResult ):
44- def __init__ (self , returncode , stdout ):
45- super (SSHResult , self ).__init__ (returncode , stdout )
46-
4733class LocalCommandResult (BaseCmdResult ):
4834 def __init__ (self , returncode , stdout ):
4935 super (LocalCommandResult , self ).__init__ (returncode , stdout )
@@ -61,191 +47,6 @@ def _ellide_log_lines(log):
6147 reduced_message .append ("(...)" )
6248 return "\n {}" .format ("\n " .join (reduced_message ))
6349
64- def _ssh (hostname_or_ip , cmd , check , simple_output , suppress_fingerprint_warnings ,
65- background , decode , options ) -> Union [SSHResult , SSHCommandFailed , str , bytes , None ]:
66- opts = list (options )
67- opts .append ('-o "BatchMode yes"' )
68- if suppress_fingerprint_warnings :
69- # Suppress warnings and questions related to host key fingerprints
70- # because on a test network IPs get reused, VMs are reinstalled, etc.
71- # Based on https://unix.stackexchange.com/a/365976/257493
72- opts .append ('-o "StrictHostKeyChecking no"' )
73- opts .append ('-o "LogLevel ERROR"' )
74- opts .append ('-o "UserKnownHostsFile /dev/null"' )
75-
76- if isinstance (cmd , str ):
77- command = cmd
78- else :
79- command = " " .join (cmd )
80-
81- ssh_cmd = f"ssh root@{ hostname_or_ip } { ' ' .join (opts )} { shlex .quote (command )} "
82-
83- # Fetch banner and remove it to avoid stdout/stderr pollution.
84- banner_res = None
85- if config .ignore_ssh_banner :
86- banner_res = subprocess .run (
87- "ssh root@%s %s '%s'" % (hostname_or_ip , ' ' .join (opts ), '\n ' ),
88- shell = True ,
89- stdout = subprocess .PIPE ,
90- stderr = subprocess .STDOUT ,
91- check = False
92- )
93-
94- logging .debug (f"[{ hostname_or_ip } ] { command } " )
95- process = subprocess .Popen (
96- ssh_cmd ,
97- shell = True ,
98- stdout = subprocess .PIPE ,
99- stderr = subprocess .STDOUT
100- )
101- if background :
102- return None
103-
104- stdout = []
105- assert process .stdout is not None
106- for line in iter (process .stdout .readline , b'' ):
107- readable_line = line .decode (errors = 'replace' ).strip ()
108- stdout .append (line )
109- logging .debug ("> %s" , readable_line )
110- _ , stderr = process .communicate ()
111- res = subprocess .CompletedProcess (ssh_cmd , process .returncode , b'' .join (stdout ), stderr )
112-
113- # Get a decoded version of the output in any case, replacing potential errors
114- output_for_errors = res .stdout .decode (errors = 'replace' ).strip ()
115-
116- # Even if check is False, we still raise in case of return code 255, which means a SSH error.
117- if res .returncode == 255 :
118- return SSHCommandFailed (255 , "SSH Error: %s" % output_for_errors , command )
119-
120- output : Union [bytes , str ] = res .stdout
121- if banner_res :
122- if banner_res .returncode == 255 :
123- return SSHCommandFailed (255 , "SSH Error: %s" % banner_res .stdout .decode (errors = 'replace' ), command )
124- output = output [len (banner_res .stdout ):]
125-
126- if decode :
127- assert isinstance (output , bytes )
128- output = output .decode ()
129-
130- if res .returncode and check :
131- return SSHCommandFailed (res .returncode , output_for_errors , command )
132-
133- if simple_output :
134- return output .strip ()
135- return SSHResult (res .returncode , output )
136-
137- # The actual code is in _ssh().
138- # This function is kept short for shorter pytest traces upon SSH failures, which are common,
139- # as pytest prints the whole function definition that raised the SSHCommandFailed exception
140- @overload
141- def ssh (hostname_or_ip : HostAddress , cmd : Union [str , List [str ]], * , check : bool = True ,
142- simple_output : Literal [True ] = True ,
143- suppress_fingerprint_warnings : bool = True , background : Literal [False ] = False ,
144- decode : Literal [True ] = True , options : List [str ] = []) -> str :
145- ...
146- @overload
147- def ssh (hostname_or_ip : HostAddress , cmd : Union [str , List [str ]], * , check : bool = True ,
148- simple_output : Literal [True ] = True ,
149- suppress_fingerprint_warnings : bool = True , background : Literal [False ] = False ,
150- decode : Literal [False ], options : List [str ] = []) -> bytes :
151- ...
152- @overload
153- def ssh (hostname_or_ip : HostAddress , cmd : Union [str , List [str ]], * , check : bool = True ,
154- simple_output : Literal [False ],
155- suppress_fingerprint_warnings : bool = True , background : Literal [False ] = False ,
156- decode : bool = True , options : List [str ] = []) -> SSHResult :
157- ...
158- @overload
159- def ssh (hostname_or_ip : HostAddress , cmd : Union [str , List [str ]], * , check : bool = True ,
160- simple_output : Literal [False ],
161- suppress_fingerprint_warnings : bool = True , background : Literal [True ],
162- decode : bool = True , options : List [str ] = []) -> None :
163- ...
164- @overload
165- def ssh (hostname_or_ip : HostAddress , cmd : Union [str , List [str ]], * , check = True ,
166- simple_output : bool = True ,
167- suppress_fingerprint_warnings = True , background : bool = False ,
168- decode : bool = True , options : List [str ] = []) \
169- -> Union [str , bytes , SSHResult , None ]:
170- ...
171- def ssh (hostname_or_ip , cmd , * , check = True , simple_output = True ,
172- suppress_fingerprint_warnings = True ,
173- background = False , decode = True , options = []):
174- result_or_exc = _ssh (hostname_or_ip , cmd , check , simple_output , suppress_fingerprint_warnings ,
175- background , decode , options )
176- if isinstance (result_or_exc , SSHCommandFailed ):
177- raise result_or_exc
178- else :
179- return result_or_exc
180-
181- def ssh_with_result (hostname_or_ip , cmd , suppress_fingerprint_warnings = True ,
182- background = False , decode = True , options = []) -> SSHResult :
183- result_or_exc = _ssh (hostname_or_ip , cmd , False , False , suppress_fingerprint_warnings ,
184- background , decode , options )
185- if isinstance (result_or_exc , SSHCommandFailed ):
186- raise result_or_exc
187- elif isinstance (result_or_exc , SSHResult ):
188- return result_or_exc
189- assert False , "unexpected type"
190-
191- def scp (hostname_or_ip , src , dest , check = True , suppress_fingerprint_warnings = True , local_dest = False ):
192- # local import to avoid cyclic import; lib.netutils also import lib.commands
193- from lib .netutil import wrap_ip
194-
195- opts = '-o "BatchMode yes"'
196- if suppress_fingerprint_warnings :
197- # Suppress warnings and questions related to host key fingerprints
198- # because on a test network IPs get reused, VMs are reinstalled, etc.
199- # Based on https://unix.stackexchange.com/a/365976/257493
200- opts = '-o "StrictHostKeyChecking no" -o "LogLevel ERROR" -o "UserKnownHostsFile /dev/null"'
201-
202- ip = wrap_ip (hostname_or_ip )
203- if local_dest :
204- src = 'root@{}:{}' .format (ip , src )
205- else :
206- dest = 'root@{}:{}' .format (ip , dest )
207-
208- command = "scp {} {} {}" .format (opts , src , dest )
209- res = subprocess .run (
210- command ,
211- shell = True ,
212- stdout = subprocess .PIPE ,
213- stderr = subprocess .STDOUT ,
214- check = False
215- )
216-
217- errorcode_msg = "" if res .returncode == 0 else " - Got error code: %s" % res .returncode
218- logging .debug (f"[{ hostname_or_ip } ] scp: { src } => { dest } { errorcode_msg } " )
219-
220- if check and res .returncode :
221- raise SSHCommandFailed (res .returncode , res .stdout .decode (), command )
222-
223- return res
224-
225- def sftp (hostname_or_ip , cmds , check = True , suppress_fingerprint_warnings = True ):
226- opts = ''
227- if suppress_fingerprint_warnings :
228- # Suppress warnings and questions related to host key fingerprints
229- # because on a test network IPs get reused, VMs are reinstalled, etc.
230- # Based on https://unix.stackexchange.com/a/365976/257493
231- opts = '-o "StrictHostKeyChecking no" -o "LogLevel ERROR" -o "UserKnownHostsFile /dev/null"'
232-
233- args = "sftp {} -b - root@{}" .format (opts , hostname_or_ip )
234- input = bytes ("\n " .join (cmds ), 'utf-8' )
235- res = subprocess .run (
236- args ,
237- input = input ,
238- shell = True ,
239- stdout = subprocess .PIPE ,
240- stderr = subprocess .STDOUT ,
241- check = False
242- )
243-
244- if check and res .returncode :
245- raise SSHCommandFailed (res .returncode , res .stdout .decode (), "{} -- {}" .format (args , cmds ))
246-
247- return res
248-
24950@overload
25051def local_cmd (cmd : Union [str , List [str ]], * , check : bool = True , simple_output : Literal [True ] = True ,
25152 decode : Literal [True ] = True ) -> str :
0 commit comments