-
Notifications
You must be signed in to change notification settings - Fork 1
/
webdav_stageout.py
108 lines (84 loc) · 3.4 KB
/
webdav_stageout.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
import os
import tempfile
from airflow.models import Variable
from airflow.decorators import dag, task
from airflow.models.param import Param
import pendulum
from decors import get_connection, get_parameter
from utils import (
RFSC,
walk_dir,
mkdir_rec,
resolve_oid,
get_webdav_client,
get_webdav_prefix,
clean_up_vaultid,
file_exist,
is_dir
)
@dag(
    schedule=None,
    start_date=pendulum.today("UTC"),
    on_success_callback=clean_up_vaultid,
    params={
        "vault_id": Param(default="", type="string"),
        "host": Param(default="", type="string"),
        "port": Param(type="integer", default=22),
        "login": Param(default="", type="string"),
        "path": Param(default="/tmp/", type="string"),
        "oid": Param(default="", type="string"),
        "verify_webdav_cert": Param(default=True, type="boolean")
    },
)
def webdav_stageout():
    """Stage files out from an SSH/SFTP source to a WebDAV storage target.

    A single ``copy`` task downloads ``params['path']`` (one file, or a whole
    directory tree) from the SSH host identified by the connection params and
    uploads every file to the WebDAV location resolved from the datacat
    ``oid``, recreating the relative directory structure remotely.
    """

    @task(multiple_outputs=True)
    def copy(**context):
        """Copy file(s) from SFTP to WebDAV.

        Returns a dict mapping each source path to the WebDAV path it was
        uploaded to (``multiple_outputs=True`` requires a dict).

        Raises:
            ValueError: if ``oid`` is missing, or the source path is invalid.
        """
        oid = get_parameter(parameter="oid", default=False, **context)
        if not oid:
            # BUGFIX: previously this only printed a warning and fell
            # through, so resolve_oid() was called with oid=False.
            # Fail the task explicitly instead.
            print(
                "Missing target storage id (oid) in pipeline parameters. Please provide datacat id"
            )
            raise ValueError("Missing target storage id (oid) in pipeline parameters")

        # Resolve the WebDAV connection + target directory from the datacat id.
        webdav_connid, dirname = resolve_oid(oid=oid, type="storage_target")
        client = get_webdav_client(webdav_connid=webdav_connid)
        client.verify = get_parameter(parameter="verify_webdav_cert", default=True, **context)
        prefix = get_webdav_prefix(client=client, dirname=dirname)
        if not prefix:
            print("Unable to determine common prefix")
        print(f"Determined common prefix: {prefix}")

        params = context["params"]
        # Prefer a vault-backed connection id when vault_id was supplied;
        # otherwise fall back to an explicit connection_id parameter.
        if (s_con_id := params.pop("vault_id")) == "":
            s_con_id = params.get("connection_id", None)
        else:
            s_con_id = f"vault_{s_con_id}"

        source_ssh_hook = get_connection(conn_id=s_con_id, params=params)
        sftp_client = source_ssh_hook.get_conn().open_sftp()
        try:
            sclient = RFSC(sftp_client)
            working_dir = Variable.get("working_dir", default_var="/tmp/")
            copied = {}
            try:
                print('Checking if ', params['path'], ' is a directory...')
                if is_dir(sftp=sftp_client, name=params['path']):
                    print('Positive. Recursive search of all files')
                    mappings = walk_dir(client=sclient, path=params["path"], prefix="")
                else:
                    # Single file: copy it alone, and treat its parent
                    # directory as the base path for relative placement.
                    mappings = [params['path']]
                    params['path'] = os.path.dirname(params['path'])
                    print("Negative will only be copying one file")
            except IOError as err:
                print("Invalid path or file name")
                # BUGFIX: the original returned -1 here, but a
                # multiple_outputs=True task must return a dict — Airflow's
                # XCom push would break on an int. Fail the task instead.
                raise ValueError("Invalid path or file name") from err

            for fname in mappings:
                # Spool each file through a local temp file in working_dir.
                with tempfile.NamedTemporaryFile(dir=working_dir) as tmp:
                    print(f"Getting: {fname}->{tmp.name}")
                    sftp_client.get(remotepath=fname, localpath=tmp.name)
                    # NOTE(review): the +1 assumes params['path'] has no
                    # trailing slash; with one, the slice cuts into the
                    # filename — confirm against walk_dir's output format.
                    directory, fl = os.path.split(fname[len(params["path"]) + 1:])
                    remote_path = os.path.join(dirname, directory)
                    mkdir_rec(client=client, path=remote_path)
                    print(f"Uploading {tmp.name}->{os.path.join(remote_path, fl)}")
                    client.upload_sync(
                        remote_path=os.path.join(remote_path, fl), local_path=tmp.name,
                    )
                    copied[fname] = os.path.join(remote_path, fl)
            return copied
        finally:
            # BUGFIX: close the SFTP session (previously leaked).
            sftp_client.close()

    copy()
# Instantiate the DAG at module level so Airflow's DAG-file parser discovers it.
dag = webdav_stageout()