
Commit 79074ef

removed the erroneous 'continual' implementation (#1865)
1 parent 8ab0352 commit 79074ef

File tree

2 files changed (+14, -126 lines)


egs/wenetspeech4tts/TTS/valle/infer.py (+14, -29)
@@ -118,13 +118,6 @@ def get_args():
         help="The temperature of AR Decoder top_k sampling.",
     )
 
-    parser.add_argument(
-        "--continual",
-        type=str2bool,
-        default=False,
-        help="Do continual task.",
-    )
-
     parser.add_argument(
         "--repetition-aware-sampling",
         type=str2bool,
@@ -262,29 +255,21 @@ def main():
     )
 
     # synthesis
-    if args.continual:
-        assert text == ""
-        encoded_frames = model.continual(
-            text_tokens.to(device),
-            text_tokens_lens.to(device),
-            audio_prompts,
-        )
-    else:
-        enroll_x_lens = None
-        if text_prompts:
-            _, enroll_x_lens = text_collater(
-                [tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
-            )
-        encoded_frames = model.inference(
-            text_tokens.to(device),
-            text_tokens_lens.to(device),
-            audio_prompts,
-            enroll_x_lens=enroll_x_lens,
-            top_k=args.top_k,
-            temperature=args.temperature,
-            top_p=args.top_p,
-            ras=args.repetition_aware_sampling,
+    enroll_x_lens = None
+    if text_prompts:
+        _, enroll_x_lens = text_collater(
+            [tokenize_text(text_tokenizer, text=f"{text_prompts}".strip())]
         )
+    encoded_frames = model.inference(
+        text_tokens.to(device),
+        text_tokens_lens.to(device),
+        audio_prompts,
+        enroll_x_lens=enroll_x_lens,
+        top_k=args.top_k,
+        temperature=args.temperature,
+        top_p=args.top_p,
+        ras=args.repetition_aware_sampling,
+    )
 
     if audio_prompts != []:
         samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
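
With the flag gone, synthesis in infer.py always goes through model.inference, and scripts that still pass --continual will now exit with an argparse "unrecognized arguments" error. For context, the decoded frames are then written out as audio. A minimal sketch of that downstream step, assuming torchaudio for saving and a 24 kHz codec sample rate (the output path, indexing, and rate are assumptions; only the decode call itself appears in the diff):

import torchaudio

# encoded_frames: (1, T, num_quantizers) codes returned by model.inference
if audio_prompts != []:
    samples = audio_tokenizer.decode([(encoded_frames.transpose(2, 1), None)])
    # samples[0]: (1, num_samples) waveform; "output.wav" and 24000 are assumed
    torchaudio.save("output.wav", samples[0].cpu(), 24000)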

egs/wenetspeech4tts/TTS/valle/valle.py (-97)
@@ -1564,103 +1564,6 @@ def inference(
         assert len(codes) == self.num_quantizers
         return torch.stack(codes, dim=-1)
 
-    def continual(
-        self,
-        x: torch.Tensor,
-        x_lens: torch.Tensor,
-        y: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Args:
-          x:
-            A 2-D tensor of shape (1, S).
-          x_lens:
-            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
-            before padding.
-          y:
-            A 3-D tensor of shape (1, T, 8).
-        Returns:
-          Return the predicted audio code matrix.
-        """
-        assert x.ndim == 2, x.shape
-        assert x_lens.ndim == 1, x_lens.shape
-        assert y.ndim == 3, y.shape
-        assert y.shape[0] == 1, y.shape
-
-        assert torch.all(x_lens > 0)
-        assert self.num_quantizers == 8
-
-        # NOTE: x has been padded in TextTokenCollater
-        text = x
-        x = self.ar_text_embedding(text)
-        x = self.ar_text_prenet(x)
-        x = self.ar_text_position(x)
-
-        text_len = x_lens.max()
-
-        prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
-
-        # AR Decoder
-        prompts = y[:, :prefix_len]
-
-        codes = [y[:, prefix_len:, 0]]
-        # Non-AR Decoders
-        x = self.nar_text_embedding(text)
-        x = self.nar_text_prenet(x)
-        x = self.nar_text_position(x)
-
-        y_emb = self.nar_audio_embeddings[0](y[..., 0])
-
-        if self.prefix_mode == 0:
-            for i, (predict_layer, embedding_layer) in enumerate(
-                zip(
-                    self.nar_predict_layers,
-                    self.nar_audio_embeddings[1:],
-                )
-            ):
-                y_pos = self.nar_audio_position(y_emb)
-                y_pos = self.nar_audio_prenet(y_pos)
-                xy_pos = torch.concat([x, y_pos], dim=1)
-
-                xy_dec, _ = self.nar_decoder(
-                    (xy_pos, self.nar_stage_embeddings[i].weight)
-                )
-                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
-
-                samples = torch.argmax(logits, dim=-1)
-                codes.append(samples)
-
-                if i < 6:
-                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
-                    y_emb[:, prefix_len:] += embedding_layer(samples)
-        else:
-            for j in range(1, 8):
-                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
-
-            for i, (predict_layer, embedding_layer) in enumerate(
-                zip(
-                    self.nar_predict_layers,
-                    self.nar_audio_embeddings[1:],
-                )
-            ):
-                y_pos = self.nar_audio_prenet(y_emb)
-                y_pos = self.nar_audio_position(y_pos)
-                xy_pos = torch.concat([x, y_pos], dim=1)
-
-                xy_dec, _ = self.nar_decoder(
-                    (xy_pos, self.nar_stage_embeddings[i].weight)
-                )
-                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
-
-                samples = torch.argmax(logits, dim=-1)
-                codes.append(samples)
-
-                if i < 6:
-                    y_emb[:, prefix_len:] += embedding_layer(samples)
-
-        assert len(codes) == 8
-        return torch.stack(codes, dim=-1)
-
     def visualize(
         self,
         predicts: Tuple[torch.Tensor],
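
For reference, the deleted continual method conditioned the non-autoregressive stages on a prefix of the prompt codes: half the prompt length, capped at 3 seconds of frames at 75 frames per second. A standalone sketch of just that prefix rule, with the constants taken from the removed code (the helper name is ours, for illustration only):

import torch

def continual_prefix_len(y: torch.Tensor, frames_per_second: int = 75) -> int:
    # Half the prompt, capped at 3 s of codec frames, as in the removed method
    return min(int(y.shape[1] * 0.5), 3 * frames_per_second)

# A 10 s prompt at 75 fps (750 frames) hits the 3 s cap: 225 frames
assert continual_prefix_len(torch.zeros(1, 750, 8)) == 225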
