Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 9a7b75c

Browse files
Brett Zachary Allenfacebook-github-bot
Brett Zachary Allen
authored andcommittedApr 5, 2021
[feat] Inference CLI (#830)
Summary: Created script to allow users to run inference using various models, text inputs, and image inputs. All that is required is using mmf_inference and passing some command line args. Here is how you use it: ``` mmf_interactive checkpoint_path=/checkpoint/brettallen/mmft/model ``` This will initiate an interactive inference script that will ask for image URL and text input to run on. It will keep prompting the user for more images and text until they say `exit`. I also added the ability to just say `same` to use the same image because I figured it was inconvenient for the user to constantly have to copy over the image URL every time they want to ask a new question. Pull Request resolved: #830 Reviewed By: apsdehal Differential Revision: D27339291 Pulled By: brettallenyo fbshipit-source-id: db190c40626229900a3530da75b61d6e6b63cdd0
1 parent 99b1f74 commit 9a7b75c

File tree

3 files changed

+177
-111
lines changed

3 files changed

+177
-111
lines changed
 

‎mmf/utils/configuration.py

+112-111
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,116 @@ def get_mmf_env(key=None):
163163
return config.env
164164

165165

166+
def _merge_with_dotlist(
167+
config: DictConfig,
168+
opts: List[str],
169+
skip_missing: bool = False,
170+
log_info: bool = True,
171+
):
172+
# TODO: To remove technical debt, a possible solution is to use
173+
# struct mode to update with dotlist OmegaConf node. Look into this
174+
# in next iteration
175+
# TODO: Simplify this function
176+
if opts is None:
177+
opts = []
178+
179+
if len(opts) == 0:
180+
return config
181+
182+
# Support equal e.g. model=visual_bert for better future hydra support
183+
has_equal = opts[0].find("=") != -1
184+
if has_equal:
185+
opt_values = [opt.split("=", maxsplit=1) for opt in opts]
186+
if not all(len(opt) == 2 for opt in opt_values):
187+
for opt in opt_values:
188+
assert len(opt) == 2, f"{opt} has no value"
189+
else:
190+
assert len(opts) % 2 == 0, "Number of opts should be multiple of 2"
191+
opt_values = zip(opts[0::2], opts[1::2])
192+
193+
for opt, value in opt_values:
194+
if opt == "dataset":
195+
opt = "datasets"
196+
197+
splits = opt.split(".")
198+
current = config
199+
for idx, field in enumerate(splits):
200+
array_index = -1
201+
if field.find("[") != -1 and field.find("]") != -1:
202+
stripped_field = field[: field.find("[")]
203+
array_index = int(field[field.find("[") + 1 : field.find("]")])
204+
else:
205+
stripped_field = field
206+
if stripped_field not in current:
207+
if skip_missing is True:
208+
break
209+
raise AttributeError(
210+
"While updating configuration"
211+
" option {} is missing from"
212+
" configuration at field {}".format(opt, stripped_field)
213+
)
214+
if isinstance(current[stripped_field], collections.abc.Mapping):
215+
current = current[stripped_field]
216+
elif (
217+
isinstance(current[stripped_field], collections.abc.Sequence)
218+
and array_index != -1
219+
):
220+
try:
221+
current_value = current[stripped_field][array_index]
222+
except OCErrors.ConfigIndexError:
223+
if skip_missing:
224+
break
225+
raise
226+
227+
# Case where array element to be updated is last element
228+
if (
229+
not isinstance(
230+
current_value,
231+
(collections.abc.Mapping, collections.abc.Sequence),
232+
)
233+
or idx == len(splits) - 1
234+
):
235+
if log_info:
236+
logger.info(f"Overriding option {opt} to {value}")
237+
current[stripped_field][array_index] = _decode_value(value)
238+
else:
239+
# Otherwise move on down the chain
240+
current = current_value
241+
else:
242+
if idx == len(splits) - 1:
243+
if log_info:
244+
logger.info(f"Overriding option {opt} to {value}")
245+
current[stripped_field] = _decode_value(value)
246+
else:
247+
if skip_missing:
248+
break
249+
250+
raise AttributeError(
251+
"While updating configuration",
252+
"option {} is not present "
253+
"after field {}".format(opt, stripped_field),
254+
)
255+
256+
return config
257+
258+
259+
def _decode_value(value):
260+
# https://github.com/rbgirshick/yacs/blob/master/yacs/config.py#L400
261+
if not isinstance(value, str):
262+
return value
263+
264+
if value == "None":
265+
value = None
266+
267+
try:
268+
value = literal_eval(value)
269+
except ValueError:
270+
pass
271+
except SyntaxError:
272+
pass
273+
return value
274+
275+
166276
def resolve_cache_dir(env_variable="MMF_CACHE_DIR", default="mmf"):
167277
# Some of this follow what "transformers" does for there cache resolving
168278
try:
@@ -217,7 +327,7 @@ def __init__(self, args=None, default_only=False):
217327

218328
# Initially, silently add opts so that some of the overrides for the defaults
219329
# from command line required for setup can be honored
220-
self._default_config = self._merge_with_dotlist(
330+
self._default_config = _merge_with_dotlist(
221331
self._default_config, args.opts, skip_missing=True, log_info=False
222332
)
223333
# Register the config and configuration for setup
@@ -231,7 +341,7 @@ def __init__(self, args=None, default_only=False):
231341

232342
self.config = OmegaConf.merge(self._default_config, other_configs)
233343

234-
self.config = self._merge_with_dotlist(self.config, args.opts)
344+
self.config = _merge_with_dotlist(self.config, args.opts)
235345
self._update_specific(self.config)
236346
self.upgrade(self.config)
237347
# Resolve the config here itself after full creation so that spawned workers
@@ -382,115 +492,6 @@ def _register_resolvers(self):
382492
OmegaConf.register_resolver("resolve_cache_dir", resolve_cache_dir)
383493
OmegaConf.register_resolver("resolve_dir", resolve_dir)
384494

385-
def _merge_with_dotlist(
386-
self,
387-
config: DictConfig,
388-
opts: List[str],
389-
skip_missing: bool = False,
390-
log_info: bool = True,
391-
):
392-
# TODO: To remove technical debt, a possible solution is to use
393-
# struct mode to update with dotlist OmegaConf node. Look into this
394-
# in next iteration
395-
# TODO: Simplify this function
396-
if opts is None:
397-
opts = []
398-
399-
if len(opts) == 0:
400-
return config
401-
402-
# Support equal e.g. model=visual_bert for better future hydra support
403-
has_equal = opts[0].find("=") != -1
404-
if has_equal:
405-
opt_values = [opt.split("=", maxsplit=1) for opt in opts]
406-
if not all(len(opt) == 2 for opt in opt_values):
407-
for opt in opt_values:
408-
assert len(opt) == 2, "{} has no value".format(opt)
409-
else:
410-
assert len(opts) % 2 == 0, "Number of opts should be multiple of 2"
411-
opt_values = zip(opts[0::2], opts[1::2])
412-
413-
for opt, value in opt_values:
414-
if opt == "dataset":
415-
opt = "datasets"
416-
417-
splits = opt.split(".")
418-
current = config
419-
for idx, field in enumerate(splits):
420-
array_index = -1
421-
if field.find("[") != -1 and field.find("]") != -1:
422-
stripped_field = field[: field.find("[")]
423-
array_index = int(field[field.find("[") + 1 : field.find("]")])
424-
else:
425-
stripped_field = field
426-
if stripped_field not in current:
427-
if skip_missing is True:
428-
break
429-
raise AttributeError(
430-
"While updating configuration"
431-
" option {} is missing from"
432-
" configuration at field {}".format(opt, stripped_field)
433-
)
434-
if isinstance(current[stripped_field], collections.abc.Mapping):
435-
current = current[stripped_field]
436-
elif (
437-
isinstance(current[stripped_field], collections.abc.Sequence)
438-
and array_index != -1
439-
):
440-
try:
441-
current_value = current[stripped_field][array_index]
442-
except OCErrors.ConfigIndexError:
443-
if skip_missing:
444-
break
445-
raise
446-
447-
# Case where array element to be updated is last element
448-
if (
449-
not isinstance(
450-
current_value,
451-
(collections.abc.Mapping, collections.abc.Sequence),
452-
)
453-
or idx == len(splits) - 1
454-
):
455-
if log_info:
456-
logger.info(f"Overriding option {opt} to {value}")
457-
current[stripped_field][array_index] = self._decode_value(value)
458-
else:
459-
# Otherwise move on down the chain
460-
current = current_value
461-
else:
462-
if idx == len(splits) - 1:
463-
if log_info:
464-
logger.info(f"Overriding option {opt} to {value}")
465-
current[stripped_field] = self._decode_value(value)
466-
else:
467-
if skip_missing:
468-
break
469-
470-
raise AttributeError(
471-
"While updating configuration",
472-
"option {} is not present "
473-
"after field {}".format(opt, stripped_field),
474-
)
475-
476-
return config
477-
478-
def _decode_value(self, value):
479-
# https://github.com/rbgirshick/yacs/blob/master/yacs/config.py#L400
480-
if not isinstance(value, str):
481-
return value
482-
483-
if value == "None":
484-
value = None
485-
486-
try:
487-
value = literal_eval(value)
488-
except ValueError:
489-
pass
490-
except SyntaxError:
491-
pass
492-
return value
493-
494495
def freeze(self):
495496
OmegaConf.set_struct(self.config, True)
496497

‎mmf_cli/interactive.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
#!/usr/bin/env python3 -u
2+
# Copyright (c) Facebook, Inc. and its affiliates.
3+
import argparse
4+
import logging
5+
import typing
6+
7+
from mmf.utils.configuration import _merge_with_dotlist
8+
from mmf.utils.flags import flags
9+
from mmf.utils.inference import Inference
10+
from mmf.utils.logger import setup_logger
11+
from omegaconf import OmegaConf
12+
13+
14+
def construct_config(opts: typing.List[str]):
15+
config = OmegaConf.create({"checkpoint_path": ""})
16+
return _merge_with_dotlist(config, opts)
17+
18+
19+
def interactive(opts: typing.Optional[typing.List[str]] = None):
20+
"""Inference runs inference on an image and text provided by the user.
21+
You can optionally run inference programmatically by passing an optlist as opts.
22+
23+
Args:
24+
opts (typing.Optional[typing.List[str]], optional): Optlist which can be used.
25+
to override opts programmatically. For e.g. if you pass
26+
opts = ["checkpoint_path=my/directory"], this will set the checkpoint.
27+
"""
28+
if opts is None:
29+
parser = flags.get_parser()
30+
args = parser.parse_args()
31+
else:
32+
args = argparse.Namespace(config_override=None)
33+
args.opts = opts
34+
35+
setup_logger()
36+
logger = logging.getLogger("mmf_cli.interactive")
37+
38+
config = construct_config(args.opts)
39+
inference = Inference(checkpoint_path=config.checkpoint_path)
40+
logger.info("Enter 'exit' at any point to terminate.")
41+
logger.info("Enter an image URL:")
42+
image_url = input()
43+
logger.info("Got image URL.")
44+
logger.info("Enter text:")
45+
text = input()
46+
logger.info("Got text input.")
47+
while text != "exit":
48+
logger.info("Running inference on image and text input.")
49+
answer = inference.forward(image_url, {"text": text}, image_format="url")
50+
logger.info("Model response: " + answer)
51+
logger.info(
52+
f"Enter another image URL or leave it blank to continue using {image_url}:"
53+
)
54+
new_image_url = input()
55+
if new_image_url != "":
56+
image_url = new_image_url
57+
if new_image_url == "exit":
58+
break
59+
logger.info("Enter another text input:")
60+
text = input()
61+
62+
63+
if __name__ == "__main__":
64+
interactive()

‎setup.py

+1
Original file line numberDiff line numberDiff line change
@@ -156,6 +156,7 @@ def fetch_package_data():
156156
"mmf_run = mmf_cli.run:run",
157157
"mmf_predict = mmf_cli.predict:predict",
158158
"mmf_convert_hm = mmf_cli.hm_convert:main",
159+
"mmf_interactive = mmf_cli.interactive:interactive",
159160
]
160161
},
161162
)

0 commit comments

Comments
 (0)
Please sign in to comment.