Skip to content

Commit 80b2213

Browse files
authored
faster mmdet config download (#285)
1 parent e46b076 commit 80b2213

File tree

1 file changed

+81
-80
lines changed

1 file changed

+81
-80
lines changed

sahi/utils/mmdet.py

+81-80
Original file line numberDiff line numberDiff line change
@@ -92,105 +92,106 @@ def download_mmdet_config(
9292
)
9393
main_config_url = base_config_url + config_file_name
9494

95-
# set config dirs
96-
temp_configs_dir = Path("temp_mmdet_configs")
97-
main_config_dir = temp_configs_dir / model_name
95+
# set final config dirs
96+
configs_dir = Path("mmdet_configs") / mmdet_ver
97+
model_config_dir = configs_dir / model_name
9898

99-
# create config dirs
100-
temp_configs_dir.mkdir(parents=True, exist_ok=True)
101-
main_config_dir.mkdir(parents=True, exist_ok=True)
99+
# create final config dir
100+
configs_dir.mkdir(parents=True, exist_ok=True)
101+
model_config_dir.mkdir(parents=True, exist_ok=True)
102102

103-
# get main config file name
103+
# get final config file name
104104
filename = Path(main_config_url).name
105105

106-
# set main config file path
107-
main_config_path = str(main_config_dir / filename)
108-
109-
# download main config file
110-
urllib.request.urlretrieve(
111-
main_config_url,
112-
main_config_path,
113-
)
106+
# set final config file path
107+
final_config_path = str(model_config_dir / filename)
114108

115-
# read main config file
116-
sys.path.insert(0, str(main_config_dir))
117-
temp_module_name = path.splitext(filename)[0]
118-
mod = import_module(temp_module_name)
119-
sys.path.pop(0)
120-
config_dict = {name: value for name, value in mod.__dict__.items() if not name.startswith("__")}
109+
if not Path(final_config_path).exists():
110+
# set config dirs
111+
temp_configs_dir = Path("temp_mmdet_configs")
112+
main_config_dir = temp_configs_dir / model_name
121113

122-
# handle when config_dict["_base_"] is string
123-
if not isinstance(config_dict["_base_"], list):
124-
config_dict["_base_"] = [config_dict["_base_"]]
114+
# create config dirs
115+
temp_configs_dir.mkdir(parents=True, exist_ok=True)
116+
main_config_dir.mkdir(parents=True, exist_ok=True)
125117

126-
# iterate over secondary config files
127-
for secondary_config_file_path in config_dict["_base_"]:
128-
# set config url
129-
config_url = base_config_url + secondary_config_file_path
130-
config_path = main_config_dir / secondary_config_file_path
118+
# get main config file name
119+
filename = Path(main_config_url).name
131120

132-
# create secondary config dir
133-
config_path.parent.mkdir(parents=True, exist_ok=True)
121+
# set main config file path
122+
main_config_path = str(main_config_dir / filename)
134123

135-
# download secondary config files
124+
# download main config file
136125
urllib.request.urlretrieve(
137-
config_url,
138-
str(config_path),
126+
main_config_url,
127+
main_config_path,
139128
)
140129

141-
# read secondary config file
142-
secondary_config_dir = config_path.parent
143-
sys.path.insert(0, str(secondary_config_dir))
144-
temp_module_name = path.splitext(Path(config_path).name)[0]
130+
# read main config file
131+
sys.path.insert(0, str(main_config_dir))
132+
temp_module_name = path.splitext(filename)[0]
145133
mod = import_module(temp_module_name)
146134
sys.path.pop(0)
147-
secondary_config_dict = {name: value for name, value in mod.__dict__.items() if not name.startswith("__")}
148-
149-
# go deeper if there are more steps
150-
if secondary_config_dict.get("_base_") is not None:
151-
# handle when config_dict["_base_"] is string
152-
if not isinstance(secondary_config_dict["_base_"], list):
153-
secondary_config_dict["_base_"] = [secondary_config_dict["_base_"]]
154-
155-
# iterate over third config files
156-
for third_config_file_path in secondary_config_dict["_base_"]:
157-
# set config url
158-
config_url = base_config_url + third_config_file_path
159-
config_path = main_config_dir / third_config_file_path
160-
161-
# create secondary config dir
162-
config_path.parent.mkdir(parents=True, exist_ok=True)
163-
# download secondary config files
164-
urllib.request.urlretrieve(
165-
config_url,
166-
str(config_path),
167-
)
135+
config_dict = {name: value for name, value in mod.__dict__.items() if not name.startswith("__")}
168136

169-
# set final config dirs
170-
configs_dir = Path("mmdet_configs") / mmdet_ver
171-
model_config_dir = configs_dir / model_name
172-
173-
# create final config dir
174-
configs_dir.mkdir(parents=True, exist_ok=True)
175-
model_config_dir.mkdir(parents=True, exist_ok=True)
176-
177-
# get final config file name
178-
filename = Path(main_config_url).name
179-
180-
# set final config file path
181-
final_config_path = str(model_config_dir / filename)
137+
# handle when config_dict["_base_"] is string
138+
if not isinstance(config_dict["_base_"], list):
139+
config_dict["_base_"] = [config_dict["_base_"]]
182140

183-
# dump final config as single file
184-
from mmcv import Config
141+
# iterate over secondary config files
142+
for secondary_config_file_path in config_dict["_base_"]:
143+
# set config url
144+
config_url = base_config_url + secondary_config_file_path
145+
config_path = main_config_dir / secondary_config_file_path
185146

186-
config = Config.fromfile(main_config_path)
187-
config.dump(final_config_path)
147+
# create secondary config dir
148+
config_path.parent.mkdir(parents=True, exist_ok=True)
188149

189-
if verbose:
190-
print(f"mmdet config file has been downloaded to {path.abspath(final_config_path)}")
150+
# download secondary config files
151+
urllib.request.urlretrieve(
152+
config_url,
153+
str(config_path),
154+
)
191155

192-
# remove temp config dir
193-
shutil.rmtree(temp_configs_dir)
156+
# read secondary config file
157+
secondary_config_dir = config_path.parent
158+
sys.path.insert(0, str(secondary_config_dir))
159+
temp_module_name = path.splitext(Path(config_path).name)[0]
160+
mod = import_module(temp_module_name)
161+
sys.path.pop(0)
162+
secondary_config_dict = {name: value for name, value in mod.__dict__.items() if not name.startswith("__")}
163+
164+
# go deeper if there are more steps
165+
if secondary_config_dict.get("_base_") is not None:
166+
# handle when config_dict["_base_"] is string
167+
if not isinstance(secondary_config_dict["_base_"], list):
168+
secondary_config_dict["_base_"] = [secondary_config_dict["_base_"]]
169+
170+
# iterate over third config files
171+
for third_config_file_path in secondary_config_dict["_base_"]:
172+
# set config url
173+
config_url = base_config_url + third_config_file_path
174+
config_path = main_config_dir / third_config_file_path
175+
176+
# create secondary config dir
177+
config_path.parent.mkdir(parents=True, exist_ok=True)
178+
# download secondary config files
179+
urllib.request.urlretrieve(
180+
config_url,
181+
str(config_path),
182+
)
183+
184+
# dump final config as single file
185+
from mmcv import Config
186+
187+
config = Config.fromfile(main_config_path)
188+
config.dump(final_config_path)
189+
190+
if verbose:
191+
print(f"mmdet config file has been downloaded to {path.abspath(final_config_path)}")
192+
193+
# remove temp config dir
194+
shutil.rmtree(temp_configs_dir)
194195

195196
return path.abspath(final_config_path)
196197

0 commit comments

Comments
 (0)