Skip to content

Commit eec1c5d

Browse files
committed
Pass through remote temp filename with file up/download commands
This means the temporary filename is generated at operation generation stage rather than execution, meaning any in-deploy config overrides are correctly used.
1 parent f5cb24e commit eec1c5d

File tree

10 files changed

+89
-24
lines changed

10 files changed

+89
-24
lines changed

pyinfra/api/command.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -127,10 +127,11 @@ def execute(self, state, host, executor_kwargs):
127127

128128

129129
class FileUploadCommand(PyinfraCommand):
130-
def __init__(self, src, dest, **kwargs):
130+
def __init__(self, src, dest, remote_temp_filename=None, **kwargs):
131131
super(FileUploadCommand, self).__init__(**kwargs)
132132
self.src = src
133133
self.dest = dest
134+
self.remote_temp_filename = remote_temp_filename
134135

135136
def __repr__(self):
136137
return 'FileUploadCommand({0}, {1})'.format(self.src, self.dest)
@@ -140,17 +141,19 @@ def execute(self, state, host, executor_kwargs):
140141

141142
return host.put_file(
142143
self.src, self.dest,
144+
remote_temp_filename=self.remote_temp_filename,
143145
print_output=state.print_output,
144146
print_input=state.print_input,
145147
**executor_kwargs
146148
)
147149

148150

149151
class FileDownloadCommand(PyinfraCommand):
150-
def __init__(self, src, dest, **kwargs):
152+
def __init__(self, src, dest, remote_temp_filename=None, **kwargs):
151153
super(FileDownloadCommand, self).__init__(**kwargs)
152154
self.src = src
153155
self.dest = dest
156+
self.remote_temp_filename = remote_temp_filename
154157

155158
def __repr__(self):
156159
return 'FileDownloadCommand({0}, {1})'.format(self.src, self.dest)
@@ -160,6 +163,7 @@ def execute(self, state, host, executor_kwargs):
160163

161164
return host.get_file(
162165
self.src, self.dest,
166+
remote_temp_filename=self.remote_temp_filename,
163167
print_output=state.print_output,
164168
print_input=state.print_input,
165169
**executor_kwargs

pyinfra/api/connectors/chroot.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@ def put_file(
9393
host,
9494
filename_or_io,
9595
remote_filename,
96+
remote_temp_filename=None, # ignored
9697
print_output=False,
9798
print_input=False,
9899
**kwargs # ignored (sudo/etc)
@@ -146,6 +147,7 @@ def get_file(
146147
host,
147148
remote_filename,
148149
filename_or_io,
150+
remote_temp_filename=None, # ignored
149151
print_output=False,
150152
print_input=False,
151153
**kwargs # ignored (sudo/etc)

pyinfra/api/connectors/docker.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ def run_shell_command(
121121

122122
def put_file(
123123
state, host, filename_or_io, remote_filename,
124+
remote_temp_filename=None, # ignored
124125
print_output=False, print_input=False,
125126
**kwargs # ignored (sudo/etc)
126127
):
@@ -171,6 +172,7 @@ def put_file(
171172

172173
def get_file(
173174
state, host, remote_filename, filename_or_io,
175+
remote_temp_filename=None, # ignored
174176
print_output=False, print_input=False,
175177
**kwargs # ignored (sudo/etc)
176178
):

pyinfra/api/connectors/dockerssh.py

Lines changed: 12 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -138,6 +138,7 @@ def run_shell_command(
138138

139139
def put_file(
140140
state, host, filename_or_io, remote_filename,
141+
remote_temp_filename=None,
141142
print_output=False, print_input=False,
142143
**kwargs # ignored (sudo/etc)
143144
):
@@ -146,12 +147,12 @@ def put_file(
146147
temporary location and then uploading it into the container using ``docker cp``.
147148
'''
148149

149-
fd, temp_filename = mkstemp()
150-
remote_temp_filename = state.get_temp_filename(temp_filename)
150+
fd, local_temp_filename = mkstemp()
151+
remote_temp_filename = remote_temp_filename or state.get_temp_filename(local_temp_filename)
151152

152153
# Load our file or IO object and write it to the temporary file
153154
with get_file_io(filename_or_io) as file_io:
154-
with open(temp_filename, 'wb') as temp_f:
155+
with open(local_temp_filename, 'wb') as temp_f:
155156
data = file_io.read()
156157

157158
if isinstance(data, six.text_type):
@@ -160,7 +161,7 @@ def put_file(
160161
temp_f.write(data)
161162

162163
# upload file to remote server
163-
ssh_status = ssh.put_file(state, host, temp_filename, remote_temp_filename)
164+
ssh_status = ssh.put_file(state, host, local_temp_filename, remote_temp_filename)
164165
if not ssh_status:
165166
raise IOError('Failed to copy file over ssh')
166167

@@ -180,9 +181,9 @@ def put_file(
180181
)
181182
finally:
182183
os.close(fd)
183-
os.remove(temp_filename)
184+
os.remove(local_temp_filename)
184185
remote_remove(
185-
state, host, temp_filename,
186+
state, host, local_temp_filename,
186187
print_output=print_output,
187188
print_input=print_input,
188189
)
@@ -200,6 +201,7 @@ def put_file(
200201

201202
def get_file(
202203
state, host, remote_filename, filename_or_io,
204+
remote_temp_filename=None,
203205
print_output=False, print_input=False,
204206
**kwargs # ignored (sudo/etc)
205207
):
@@ -208,14 +210,14 @@ def get_file(
208210
location and then reading that into our final file/IO object.
209211
'''
210212

211-
temp_filename = state.get_temp_filename(remote_filename)
213+
remote_temp_filename = remote_temp_filename or state.get_temp_filename(remote_filename)
212214

213215
try:
214216
docker_id = host.host_data['docker_container_id']
215217
docker_command = 'docker cp {0}:{1} {2}'.format(
216218
docker_id,
217219
remote_filename,
218-
temp_filename,
220+
remote_temp_filename,
219221
)
220222

221223
status, _, stderr = ssh.run_shell_command(
@@ -225,9 +227,9 @@ def get_file(
225227
print_input=print_input,
226228
)
227229

228-
ssh_status = ssh.get_file(state, host, temp_filename, filename_or_io)
230+
ssh_status = ssh.get_file(state, host, remote_temp_filename, filename_or_io)
229231
finally:
230-
remote_remove(state, host, temp_filename, print_output=print_output,
232+
remote_remove(state, host, remote_temp_filename, print_output=print_output,
231233
print_input=print_input)
232234

233235
if not ssh_status:

pyinfra/api/connectors/local.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,7 @@ def execute_command():
9696

9797
def put_file(
9898
state, host, filename_or_io, remote_filename,
99+
remote_temp_filename=None, # ignored
99100
print_output=False, print_input=False,
100101
**command_kwargs
101102
):
@@ -141,6 +142,7 @@ def put_file(
141142

142143
def get_file(
143144
state, host, remote_filename, filename_or_io,
145+
remote_temp_filename=None, # ignored
144146
print_output=False, print_input=False,
145147
**command_kwargs
146148
):

pyinfra/api/connectors/ssh.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -339,6 +339,7 @@ def _get_file(host, remote_filename, filename_or_io):
339339

340340
def get_file(
341341
state, host, remote_filename, filename_or_io,
342+
remote_temp_filename=None,
342343
sudo=False, sudo_user=None, su_user=None,
343344
print_output=False, print_input=False,
344345
**command_kwargs
@@ -351,7 +352,7 @@ def get_file(
351352

352353
if sudo or su_user:
353354
# Get temp file location
354-
temp_file = state.get_temp_filename(remote_filename)
355+
temp_file = remote_temp_filename or state.get_temp_filename(remote_filename)
355356

356357
# Copy the file to the tempfile location and add read permissions
357358
command = 'cp {0} {1} && chmod +r {0}'.format(remote_filename, temp_file)
@@ -406,6 +407,7 @@ def _put_file(host, filename_or_io, remote_location):
406407

407408
def put_file(
408409
state, host, filename_or_io, remote_filename,
410+
remote_temp_filename=None,
409411
sudo=False, sudo_user=None, su_user=None,
410412
print_output=False, print_input=False,
411413
**command_kwargs
@@ -419,7 +421,7 @@ def put_file(
419421
# user connected, so upload to tmp and copy/chown w/sudo and/or su_user
420422
if sudo or su_user:
421423
# Get temp file location
422-
temp_file = state.get_temp_filename(remote_filename)
424+
temp_file = remote_temp_filename or state.get_temp_filename(remote_filename)
423425
_put_file(host, filename_or_io, temp_file)
424426

425427
# Make sure our sudo/su user can access the file

pyinfra/api/connectors/winrm.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ def run_shell_command(
191191

192192
def get_file(
193193
state, host, remote_filename, filename_or_io,
194+
remote_temp_filename=None,
194195
**command_kwargs
195196
):
196197
raise PyinfraError('Not implemented')
@@ -222,6 +223,7 @@ def _put_file(state, host, filename_or_io, remote_location, chunk_size=2048):
222223
def put_file(
223224
state, host, filename_or_io, remote_filename,
224225
print_output=False, print_input=False,
226+
remote_temp_filename=None, # ignored
225227
**command_kwargs
226228
):
227229
'''

pyinfra/operations/files.py

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -706,11 +706,11 @@ def get(
706706

707707
# No remote file, so assume exists and download it "blind"
708708
if not remote_file or force:
709-
yield FileDownloadCommand(src, dest)
709+
yield FileDownloadCommand(src, dest, remote_temp_filename=state.get_temp_filename(dest))
710710

711711
# No local file, so always download
712712
elif not os_path.exists(dest):
713-
yield FileDownloadCommand(src, dest)
713+
yield FileDownloadCommand(src, dest, remote_temp_filename=state.get_temp_filename(dest))
714714

715715
# Remote file exists - check if it matches our local
716716
else:
@@ -719,7 +719,7 @@ def get(
719719

720720
# Check sha1sum, upload if needed
721721
if local_sum != remote_sum:
722-
yield FileDownloadCommand(src, dest)
722+
yield FileDownloadCommand(src, dest, remote_temp_filename=state.get_temp_filename(dest))
723723

724724

725725
@operation(pipeline_facts={
@@ -799,7 +799,11 @@ def put(
799799

800800
# No remote file, always upload and user/group/mode if supplied
801801
if not remote_file or force:
802-
yield FileUploadCommand(local_file, dest)
802+
yield FileUploadCommand(
803+
local_file,
804+
dest,
805+
remote_temp_filename=state.get_temp_filename(dest),
806+
)
803807

804808
if user or group:
805809
yield chown(dest, user, group)
@@ -813,7 +817,11 @@ def put(
813817

814818
# Check sha1sum, upload if needed
815819
if local_sum != remote_sum:
816-
yield FileUploadCommand(local_file, dest)
820+
yield FileUploadCommand(
821+
local_file,
822+
dest,
823+
remote_temp_filename=state.get_temp_filename(dest),
824+
)
817825

818826
if user or group:
819827
yield chown(dest, user, group)

pyinfra/operations/windows_files.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -209,7 +209,11 @@ def put(
209209

210210
# No remote file, always upload and user/group/mode if supplied
211211
if not remote_file or force:
212-
yield FileUploadCommand(local_file, dest)
212+
yield FileUploadCommand(
213+
local_file,
214+
dest,
215+
remote_temp_filename=state.get_temp_filename(dest),
216+
)
213217

214218
# if user or group:
215219
# yield chown(dest, user, group)
@@ -224,7 +228,11 @@ def put(
224228

225229
# Check sha1sum, upload if needed
226230
if local_sum != remote_sum:
227-
yield FileUploadCommand(local_file, dest)
231+
yield FileUploadCommand(
232+
local_file,
233+
dest,
234+
remote_temp_filename=state.get_temp_filename(dest),
235+
)
228236

229237
# if user or group:
230238
# yield chown(dest, user, group)

tests/test_connectors/test_ssh.py

Lines changed: 36 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -568,9 +568,6 @@ def test_run_shell_command_sudo_password_automatic_prompt(
568568
"sudo -H -A -k sh -c 'echo Šablony'"
569569
), get_pty=False)
570570

571-
# SSH file put/get tests
572-
#
573-
574571
@patch('pyinfra.api.connectors.ssh.SSHClient')
575572
@patch('pyinfra.api.connectors.util._get_sudo_password')
576573
def test_run_shell_command_retry_for_sudo_password(
@@ -606,6 +603,9 @@ def test_run_shell_command_retry_for_sudo_password(
606603
get_pty=False,
607604
)
608605

606+
# SSH file put/get tests
607+
#
608+
609609
@patch('pyinfra.api.connectors.ssh.SSHClient')
610610
@patch('pyinfra.api.connectors.ssh.SFTPClient')
611611
def test_put_file(self, fake_sftp_client, fake_ssh_client):
@@ -727,6 +727,39 @@ def test_put_file_su_user_fail_copy(self, fake_sftp_client, fake_ssh_client):
727727
fake_open(), '/tmp/pyinfra-43db9984686317089fefcf2e38de527e4cb44487',
728728
)
729729

730+
@patch('pyinfra.api.connectors.ssh.SSHClient')
731+
@patch('pyinfra.api.connectors.ssh.SFTPClient')
732+
def test_put_file_sudo_custom_temp_file(self, fake_sftp_client, fake_ssh_client):
733+
inventory = make_inventory(hosts=('anotherhost',))
734+
State(inventory, Config())
735+
host = inventory.get_host('anotherhost')
736+
host.connect()
737+
738+
stdout_mock = MagicMock()
739+
stdout_mock.channel.recv_exit_status.return_value = 0
740+
fake_ssh_client().exec_command.return_value = MagicMock(), stdout_mock, MagicMock()
741+
742+
fake_open = mock_open(read_data='test!')
743+
with patch('pyinfra.api.util.open', fake_open, create=True):
744+
status = host.put_file(
745+
'not-a-file', 'not another file',
746+
print_output=True,
747+
sudo=True,
748+
sudo_user='ubuntu',
749+
remote_temp_filename='/a-different-tempfile',
750+
)
751+
752+
assert status is True
753+
754+
fake_ssh_client().exec_command.assert_called_with((
755+
"sh -c 'rm -f "
756+
"/a-different-tempfile'"
757+
), get_pty=False)
758+
759+
fake_sftp_client.from_transport().putfo.assert_called_with(
760+
fake_open(), '/a-different-tempfile',
761+
)
762+
730763
@patch('pyinfra.api.connectors.ssh.SSHClient')
731764
@patch('pyinfra.api.connectors.ssh.SFTPClient')
732765
def test_get_file(self, fake_sftp_client, fake_ssh_client):

0 commit comments

Comments
 (0)