# decode.py — executable file, 76 lines (71 loc), 3.19 KB.
# (Extracted from a GitHub file view; the page chrome and the line-number
# gutter from the scrape were removed. Source begins below.)
import os
import pdb
import time
import torch
import ctcdecode
import numpy as np
from itertools import groupby
import torch.nn.functional as F
class Decode(object):
    """Turn frame-wise network scores into gloss sequences.

    Two strategies are supported, selected by ``search_mode``:
    greedy argmax decoding ("max") and CTC beam search (anything else),
    the latter backed by ``ctcdecode.CTCBeamDecoder``.
    """

    def __init__(self, gloss_dict, num_classes, search_mode, blank_id=0):
        # Build gloss <-> index lookup tables, skipping the blank (index 0).
        self.g2i_dict = {gloss: idx for gloss, idx in gloss_dict.items() if idx != 0}
        self.i2g_dict = {idx: gloss for gloss, idx in self.g2i_dict.items()}
        self.num_classes = num_classes
        self.search_mode = search_mode
        self.blank_id = blank_id
        # ctcdecode wants a character vocabulary; any run of distinct unicode
        # code points works because only the indices are used downstream.
        vocab = [chr(20000 + offset) for offset in range(num_classes)]
        self.ctc_decoder = ctcdecode.CTCBeamDecoder(vocab, beam_width=10, blank_id=blank_id,
                                                    num_processes=10)

    def decode(self, nn_output, vid_lgt, batch_first=True, probs=False):
        """Dispatch to the configured decoding strategy.

        nn_output is (B, T, N) when batch_first, else (T, B, N);
        vid_lgt holds the valid length of each sequence in the batch.
        """
        if not batch_first:
            nn_output = nn_output.permute(1, 0, 2)
        if self.search_mode == "max":
            return self.MaxDecode(nn_output, vid_lgt)
        return self.BeamSearch(nn_output, vid_lgt, probs)

    def BeamSearch(self, nn_output, vid_lgt, probs=False):
        '''
        CTCBeamDecoder Shape:
            - Input:  nn_output (B, T, N), which should be passed through a softmax layer
            - Output: beam_resuls (B, N_beams, T), int, need to be decoded by i2g_dict
                      beam_scores (B, N_beams), p=1/np.exp(beam_score)
                      timesteps (B, N_beams)
                      out_lens (B, N_beams)
        '''
        # Unless the caller already supplies probabilities, softmax the
        # raw scores and move everything to CPU for the decoder.
        if not probs:
            nn_output = nn_output.softmax(-1).cpu()
        vid_lgt = vid_lgt.cpu()
        beam_result, beam_scores, timesteps, out_seq_len = self.ctc_decoder.decode(nn_output, vid_lgt)
        decoded = []
        for sample_idx in range(len(nn_output)):
            # Keep only the best beam, truncated to its reported length.
            hyp_len = out_seq_len[sample_idx][0]
            best = beam_result[sample_idx][0][:hyp_len]
            if len(best) != 0:
                # Merge runs of identical ids before mapping back to glosses.
                best = torch.stack([run[0] for run in groupby(best)])
            decoded.append([(self.i2g_dict[int(gloss_id)], pos)
                            for pos, gloss_id in enumerate(best)])
        return decoded

    def MaxDecode(self, nn_output, vid_lgt):
        """Greedy decoding: per-frame argmax, collapse repeats, drop blanks."""
        pred_ids = torch.argmax(nn_output, axis=2)
        batch_size, _ = pred_ids.shape
        decoded = []
        for sample_idx in range(batch_size):
            frame_ids = pred_ids[sample_idx][:vid_lgt[sample_idx]]
            # Collapse consecutive duplicates, then remove blank frames.
            collapsed = [run[0] for run in groupby(frame_ids)]
            non_blank = [tok for tok in collapsed if tok != self.blank_id]
            if len(non_blank) > 0:
                # Second merge pass: tokens that became adjacent after blank
                # removal are fused as well (matches the original behavior).
                non_blank = [run[0] for run in groupby(torch.stack(non_blank))]
            decoded.append([(self.i2g_dict[int(gloss_id)], pos)
                            for pos, gloss_id in enumerate(non_blank)])
        return decoded