Skip to content

Commit 1803a56

Browse files
author
Phil Chiu
committed
Use context manager for prepend_path
1 parent bb6bd3f commit 1803a56

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

pyoptsparse/pyOpt_utils.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,8 @@
88
mat = {"csr": [rowp, colind, data], "shape": [nrow, ncols]} # A csr matrix
99
mat = {"csc": [colp, rowind, data], "shape": [nrow, ncols]} # A csc matrix
1010
"""
11-
1211
# Standard Python modules
12+
import contextlib
1313
import importlib
1414
import os
1515
import sys
@@ -576,6 +576,20 @@ def _broadcast_to_array(name: str, value: ArrayType, n_values: int, allow_none:
576576
return value
577577

578578

579+
@contextlib.contextmanager
580+
def _prepend_path(path: Union[str, Sequence[str]]):
581+
"""Context manager which temporarily prepends to `sys.path`."""
582+
if isinstance(path, str):
583+
path = [path]
584+
orig_path = sys.path
585+
if path:
586+
path = [os.path.abspath(os.path.expandvars(os.path.expanduser(p))) for p in path]
587+
sys.path = path + sys.path
588+
yield
589+
sys.path = orig_path
590+
return
591+
592+
579593
def import_module(
580594
module_name: str,
581595
path: Union[str, Sequence[str]] = (),
@@ -603,20 +617,12 @@ def import_module(
603617
if on_error.lower() not in ("raise", "return"):
604618
raise ValueError("`on_error` must be 'raise' or 'return'.")
605619

606-
if isinstance(path, str):
607-
path = [path]
608-
609-
orig_path = sys.path
610-
if path:
611-
path = [os.path.abspath(os.path.expandvars(os.path.expanduser(p))) for p in path]
612-
sys.path = path + sys.path
613-
try:
614-
module = importlib.import_module(module_name)
615-
except ImportError as e:
616-
if on_error.lower() == "raise":
617-
raise e
618-
else:
619-
module = e
620-
finally:
621-
sys.path = orig_path
620+
with _prepend_path(path):
621+
try:
622+
module = importlib.import_module(module_name)
623+
except ImportError as e:
624+
if on_error.lower() == "raise":
625+
raise e
626+
else:
627+
module = e
622628
return module

0 commit comments

Comments
 (0)