|
18 | 18 | from __future__ import division
|
19 | 19 | from __future__ import print_function
|
20 | 20 |
|
| 21 | +import argparse |
21 | 22 | import errno
|
22 | 23 | import os
|
23 | 24 | import platform
|
|
32 | 33 | from distutils.spawn import find_executable as which
|
33 | 34 | # pylint: enable=g-import-not-at-top
|
34 | 35 |
|
35 |
| -_TF_BAZELRC = os.path.join(os.path.dirname(os.path.abspath(__file__)), |
36 |
| - '.tf_configure.bazelrc') |
37 |
| -_TF_WORKSPACE = os.path.join(os.path.dirname(os.path.abspath(__file__)), |
38 |
| - 'WORKSPACE') |
39 | 36 | _DEFAULT_CUDA_VERSION = '9.0'
|
40 | 37 | _DEFAULT_CUDNN_VERSION = '7'
|
41 | 38 | _DEFAULT_CUDA_COMPUTE_CAPABILITIES = '3.5,5.2'
|
|
51 | 48 |
|
52 | 49 | _DEFAULT_PROMPT_ASK_ATTEMPTS = 10
|
53 | 50 |
|
| 51 | +_TF_WORKSPACE_ROOT = os.path.abspath(os.path.dirname(__file__)) |
| 52 | +_TF_BAZELRC_FILENAME = '.tf_configure.bazelrc' |
| 53 | +_TF_BAZELRC = os.path.join(_TF_WORKSPACE_ROOT, _TF_BAZELRC_FILENAME) |
| 54 | +_TF_WORKSPACE = os.path.join(_TF_WORKSPACE_ROOT, 'WORKSPACE') |
| 55 | + |
54 | 56 |
|
55 | 57 | class UserInputError(Exception):
|
56 | 58 | pass
|
@@ -119,22 +121,6 @@ def sed_in_place(filename, old, new):
|
119 | 121 | f.write(newdata)
|
120 | 122 |
|
121 | 123 |
|
122 |
| -def remove_line_with(filename, token): |
123 |
| - """Remove lines that contain token from file. |
124 |
| -
|
125 |
| - Args: |
126 |
| - filename: string for filename. |
127 |
| - token: string token to check if to remove a line from file or not. |
128 |
| - """ |
129 |
| - with open(filename, 'r') as f: |
130 |
| - filedata = f.read() |
131 |
| - |
132 |
| - with open(filename, 'w') as f: |
133 |
| - for line in filedata.strip().split('\n'): |
134 |
| - if token not in line: |
135 |
| - f.write(line + '\n') |
136 |
| - |
137 |
| - |
138 | 124 | def write_to_bazelrc(line):
|
139 | 125 | with open(_TF_BAZELRC, 'a') as f:
|
140 | 126 | f.write(line + '\n')
|
@@ -245,33 +231,35 @@ def setup_python(environ_cp):
|
245 | 231 | environ_cp['PYTHON_BIN_PATH'] = python_bin_path
|
246 | 232 |
|
247 | 233 | # Write tools/python_bin_path.sh
|
248 |
| - with open('tools/python_bin_path.sh', 'w') as f: |
| 234 | + with open(os.path.join( |
| 235 | + _TF_WORKSPACE_ROOT, 'tools', 'python_bin_path.sh'), 'w') as f: |
249 | 236 | f.write('export PYTHON_BIN_PATH="%s"' % python_bin_path)
|
250 | 237 |
|
251 | 238 |
|
252 |
| -def reset_tf_configure_bazelrc(): |
| 239 | +def reset_tf_configure_bazelrc(workspace_path): |
253 | 240 | """Reset file that contains customized config settings."""
|
254 | 241 | open(_TF_BAZELRC, 'w').close()
|
| 242 | + bazelrc_path = os.path.join(workspace_path, '.bazelrc') |
255 | 243 |
|
256 |
| - home = os.path.expanduser('~') |
257 |
| - if not os.path.exists('.bazelrc'): |
258 |
| - if os.path.exists(os.path.join(home, '.bazelrc')): |
259 |
| - with open('.bazelrc', 'a') as f: |
260 |
| - f.write('import %s/.bazelrc\n' % home.replace('\\', '/')) |
261 |
| - else: |
262 |
| - open('.bazelrc', 'w').close() |
263 |
| - |
264 |
| - remove_line_with('.bazelrc', 'tf_configure') |
265 |
| - with open('.bazelrc', 'a') as f: |
266 |
| - f.write('import %workspace%/.tf_configure.bazelrc\n') |
| 244 | + data = [] |
| 245 | + if os.path.exists(bazelrc_path): |
| 246 | + with open(bazelrc_path, 'r') as f: |
| 247 | + data = f.read().splitlines() |
| 248 | + with open(bazelrc_path, 'w') as f: |
| 249 | + for l in data: |
| 250 | + if _TF_BAZELRC_FILENAME in l: |
| 251 | + continue |
| 252 | + f.write('%s\n' % l) |
| 253 | + f.write('import %s\n' % _TF_BAZELRC) |
267 | 254 |
|
268 | 255 |
|
269 | 256 | def cleanup_makefile():
|
270 | 257 | """Delete any leftover BUILD files from the Makefile build.
|
271 | 258 |
|
272 | 259 | These files could interfere with Bazel parsing.
|
273 | 260 | """
|
274 |
| - makefile_download_dir = 'tensorflow/contrib/makefile/downloads' |
| 261 | + makefile_download_dir = os.path.join( |
| 262 | + _TF_WORKSPACE_ROOT, 'tensorflow', 'contrib', 'makefile', 'downloads') |
275 | 263 | if os.path.isdir(makefile_download_dir):
|
276 | 264 | for root, _, filenames in os.walk(makefile_download_dir):
|
277 | 265 | for f in filenames:
|
@@ -1373,13 +1361,20 @@ def config_info_line(name, help_text):
|
1373 | 1361 |
|
1374 | 1362 |
|
1375 | 1363 | def main():
|
| 1364 | + parser = argparse.ArgumentParser() |
| 1365 | + parser.add_argument("--workspace", |
| 1366 | + type=str, |
| 1367 | + default=_TF_WORKSPACE_ROOT, |
| 1368 | + help="The absolute path to your active Bazel workspace.") |
| 1369 | + args = parser.parse_args() |
| 1370 | + |
1376 | 1371 | # Make a copy of os.environ to be clear when functions and getting and setting
|
1377 | 1372 | # environment variables.
|
1378 | 1373 | environ_cp = dict(os.environ)
|
1379 | 1374 |
|
1380 | 1375 | check_bazel_version('0.5.4')
|
1381 | 1376 |
|
1382 |
| - reset_tf_configure_bazelrc() |
| 1377 | + reset_tf_configure_bazelrc(args.workspace) |
1383 | 1378 | cleanup_makefile()
|
1384 | 1379 | setup_python(environ_cp)
|
1385 | 1380 |
|
|
0 commit comments