@@ -56,6 +56,8 @@ def _validate_ref_impl_exists() -> None:
5656
5757 if op_name_clean not in ref_impls :
5858 if op_name not in _SKIP_OPS :
59+ print ("*" * 100 )
60+ print (op_name_clean )
5961 error_impls .append (op_name )
6062
6163 if error_impls :
@@ -81,6 +83,13 @@ def register_fake(
8183 _REGISTERED_META_KERNELS .add (op_name )
8284 return _register_fake_original (op_name )
8385
# Schema for the out-variant of box_with_nms_limit: same positional arguments
# as the functional op below, plus six keyword-only, mutable output tensors
# that are aliased to the six returned tensors.
lib.define(
    "box_with_nms_limit.out("
    "Tensor scores, Tensor boxes, Tensor batch_splits, "
    "float score_thresh, float nms, int detections_per_im, "
    "bool soft_nms_enabled, str soft_nms_method, float soft_nms_sigma, "
    "float soft_nms_min_score_thres, bool rotated, bool cls_agnostic_bbox_reg, "
    "bool input_boxes_include_bg_cls, bool output_classes_include_bg_cls, "
    "bool legacy_plus_one, Tensor[]? _caffe2_preallocated_outputs=None, *, "
    "Tensor(a!) out_scores, Tensor(b!) out_boxes, Tensor(c!) out_classes, "
    "Tensor(d!) batch_splits_out, Tensor(e!) out_keeps, Tensor(f!) out_keeps_size) "
    "-> (Tensor(a!) scores, Tensor(b!) boxes, Tensor(c!) classes, "
    "Tensor(d!) batch_splits, Tensor(e!) keeps, Tensor(f!) keeps_size)"
)

# Schema for the functional variant: returns six freshly allocated tensors.
lib.define(
    "box_with_nms_limit("
    "Tensor scores, Tensor boxes, Tensor batch_splits, "
    "float score_thresh, float nms, int detections_per_im, "
    "bool soft_nms_enabled, str soft_nms_method, float soft_nms_sigma, "
    "float soft_nms_min_score_thres, bool rotated, bool cls_agnostic_bbox_reg, "
    "bool input_boxes_include_bg_cls, bool output_classes_include_bg_cls, "
    "bool legacy_plus_one, Tensor[]? _caffe2_preallocated_outputs=None) "
    "-> (Tensor scores, Tensor boxes, Tensor classes, "
    "Tensor batch_splits, Tensor keeps, Tensor keeps_size)"
)
8493
8594lib .define (
8695 "quantize_per_tensor(Tensor input, float scale, int zero_point, int quant_min, int quant_max, ScalarType dtype) -> (Tensor Z)"
@@ -2734,6 +2743,48 @@ def quantized_w8a32_gru_meta(
27342743 return hidden .new_empty ((2 , hidden .shape [- 1 ]), dtype = torch .float32 )
27352744
27362745
2746+
@register_fake("cadence::box_with_nms_limit")
def box_with_nms_limit_meta(
    tscores: torch.Tensor,
    tboxes: torch.Tensor,
    tbatch_splits: torch.Tensor,
    score_thres: float,
    nms_thres: float,
    detections_per_im: int,
    soft_nms_enabled: bool,
    soft_nms_method_str: str,
    soft_nms_sigma: float,
    soft_nms_min_score_thres: float,
    rotated: bool,
    cls_agnostic_bbox_reg: bool,
    input_boxes_include_bg_cls: bool,
    output_classes_include_bg_cls: bool,
    legacy_plus_one: bool,
    optional_tensor_list: Optional[list[torch.Tensor]] = None,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor
]:
    """Fake (meta) kernel for ``cadence::box_with_nms_limit``.

    Only output shapes and dtypes are produced; no detection logic runs here.
    All outputs are allocated via ``tscores.new_empty`` and are uninitialized.

    Only ``tscores``, ``tbatch_splits``, ``detections_per_im`` and ``rotated``
    influence the result; the remaining arguments are accepted to mirror the
    operator signature and are otherwise unused.

    Returns:
        A 6-tuple ``(scores, boxes, classes, batch_splits, keeps, keeps_size)``
        with shapes, in order:
        ``[detections_per_im]``,
        ``[detections_per_im, 5 if rotated else 4]``,
        ``[detections_per_im]``,
        ``[tbatch_splits.size(0)]``,
        ``[detections_per_im]`` (int32),
        ``[tbatch_splits.size(0), tscores.size(1)]`` (int32).
        The first four inherit the dtype of ``tscores``.

    Raises:
        AssertionError: if ``detections_per_im`` is 0.
    """
    # NOTE(review): output sizes are taken directly from detections_per_im, so
    # a zero value is rejected up front — presumably because the real kernel's
    # output size would not be statically known in that case; confirm.
    assert detections_per_im != 0

    max_dets = detections_per_im
    coords_per_box = 4
    if rotated:
        coords_per_box = 5

    n_batches = tbatch_splits.size(0)
    n_classes = tscores.size(1)

    return (
        tscores.new_empty((max_dets,)),
        tscores.new_empty((max_dets, coords_per_box)),
        tscores.new_empty((max_dets,)),
        tscores.new_empty((n_batches,)),
        tscores.new_empty((max_dets,), dtype=torch.int32),
        tscores.new_empty((n_batches, n_classes), dtype=torch.int32),
    )
2787+
27372788# Validate that all meta kernels have reference implementations
27382789# This is called at module import time to catch missing implementations early
27392790_validate_ref_impl_exists ()
0 commit comments