diff --git a/tf_trt_models/detection.py b/tf_trt_models/detection.py index 88a1263..76d2305 100644 --- a/tf_trt_models/detection.py +++ b/tf_trt_models/detection.py @@ -4,6 +4,7 @@ import os import tarfile import subprocess +import warnings from google.protobuf import text_format @@ -86,7 +87,7 @@ def download_detection_model(model, output_dir='.'): return config_path, checkpoint_path -def build_detection_graph(config, checkpoint): +def build_detection_graph(config, checkpoint, score_threshold=0.3): """Build an object detection model from the TensorFlow model zoo. This function creates an object detection model, sourced from the @@ -112,6 +113,8 @@ def build_detection_graph(config, checkpoint): :type config: string :param checkpoint: path to the checkpoint files prefix containing trained model params :type checkpoint: string + :score_threshold: NonMaxSuppression score_threshold (default 0.3) + :type score_threshold: float :returns: the configured frozen graph representing object detection model :rtype: a tensorflow GraphDef """ @@ -123,6 +126,13 @@ def build_detection_graph(config, checkpoint): config = TrainEvalPipelineConfig() text_format.Merge(config_str, config) + try: + old_score = config.model.ssd.post_processing.batch_non_max_suppression.score_threshold + config.model.ssd.post_processing.batch_non_max_suppression.score_threshold=score_threshold + warnings.warn("The score threshold of NonMaxSuppression was set from "+str(old_score)+" to "+str(score_threshold)) + except AttributeError: + warnings.warn("The score threshold of NonMaxSuppression can not be reconfigured") + pass tf_config = tf.ConfigProto() tf_config.gpu_options.allow_growth = True