diff --git a/groundingdino/models/GroundingDINO/transformer.py b/groundingdino/models/GroundingDINO/transformer.py index fcb8742d..cf846b78 100644 --- a/groundingdino/models/GroundingDINO/transformer.py +++ b/groundingdino/models/GroundingDINO/transformer.py @@ -297,6 +297,16 @@ def forward(self, srcs, masks, refpoint_embed, pos_embeds, tgt, attn_mask=None, self.enc_out_bbox_embed(output_memory) + output_proposals ) # (bs, \sum{hw}, 4) unsigmoid topk = self.num_queries + if topk > topk_logits.shape[-1]: + missing = topk-topk_logits.shape[-1] + repeat_pad = topk_logits.clone().detach()[:,-missing:] + topk_logits = torch.cat([topk_logits, repeat_pad], dim=-1) + repeat_pad = output_proposals.clone().detach()[:,-missing:, :] + output_proposals = torch.cat([output_proposals, repeat_pad], dim=-2) + repeat_pad = enc_outputs_coord_unselected.clone().detach()[:,-missing:, :] + enc_outputs_coord_unselected = torch.cat([enc_outputs_coord_unselected, repeat_pad], dim=-2) + repeat_pad = output_memory.clone().detach()[:,-missing:, :] + output_memory = torch.cat([output_memory, repeat_pad], dim=-2) topk_proposals = torch.topk(topk_logits, topk, dim=1)[1] # bs, nq