@@ -1564,103 +1564,6 @@ def inference(
         assert len(codes) == self.num_quantizers
         return torch.stack(codes, dim=-1)

-    def continual(
-        self,
-        x: torch.Tensor,
-        x_lens: torch.Tensor,
-        y: torch.Tensor,
-    ) -> torch.Tensor:
-        """
-        Args:
-          x:
-            A 2-D tensor of shape (1, S).
-          x_lens:
-            A 1-D tensor of shape (1,). It contains the number of tokens in `x`
-            before padding.
-          y:
-            A 3-D tensor of shape (1, T, 8).
-        Returns:
-          Return the predicted audio code matrix of shape (1, T - prefix_len, 8),
-          covering the continuation after the prompt.
-        """
-        assert x.ndim == 2, x.shape
-        assert x_lens.ndim == 1, x_lens.shape
-        assert y.ndim == 3, y.shape
-        assert y.shape[0] == 1, y.shape
-
-        assert torch.all(x_lens > 0)
-        assert self.num_quantizers == 8
-
-        # NOTE: x has been padded in TextTokenCollater
-        text = x
-        x = self.ar_text_embedding(text)
-        x = self.ar_text_prenet(x)
-        x = self.ar_text_position(x)
-
-        text_len = x_lens.max()
-
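-        # Prompt length: up to half of y, capped at 3 s of codec frames
-        # (75 frames/s, assuming 24 kHz EnCodec).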
-        prefix_len = min(int(y.shape[1] * 0.5), 3 * 75)
-
-        # AR Decoder
-        prompts = y[:, :prefix_len]
-
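-        # In continuation mode the first-quantizer codes are copied from y
-        # rather than sampled by the AR decoder.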
-        codes = [y[:, prefix_len:, 0]]
-        # Non-AR Decoders
-        x = self.nar_text_embedding(text)
-        x = self.nar_text_prenet(x)
-        x = self.nar_text_position(x)
-
-        y_emb = self.nar_audio_embeddings[0](y[..., 0])
-
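-        # prefix_mode == 0 folds the prompt's quantizers into y_emb stage by
-        # stage; other modes fold all of them in up front.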
-        if self.prefix_mode == 0:
-            for i, (predict_layer, embedding_layer) in enumerate(
-                zip(
-                    self.nar_predict_layers,
-                    self.nar_audio_embeddings[1:],
-                )
-            ):
-                y_pos = self.nar_audio_prenet(y_emb)
-                y_pos = self.nar_audio_position(y_pos)
-                xy_pos = torch.concat([x, y_pos], dim=1)
-
-                xy_dec, _ = self.nar_decoder(
-                    (xy_pos, self.nar_stage_embeddings[i].weight)
-                )
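-                # Predict codes only for the continuation region; the text and
-                # prompt positions serve as conditioning.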
-                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
-
-                samples = torch.argmax(logits, dim=-1)
-                codes.append(samples)
-
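-                # y_emb only needs updating while later stages remain (i < 6).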
-                if i < 6:
-                    y_emb[:, :prefix_len] += embedding_layer(prompts[..., i + 1])
-                    y_emb[:, prefix_len:] += embedding_layer(samples)
-        else:
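-            # Fold the remaining seven prompt quantizers into the prompt
-            # embedding before running the NAR stages.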
-            for j in range(1, 8):
-                y_emb[:, :prefix_len] += self.nar_audio_embeddings[j](prompts[..., j])
-
-            for i, (predict_layer, embedding_layer) in enumerate(
-                zip(
-                    self.nar_predict_layers,
-                    self.nar_audio_embeddings[1:],
-                )
-            ):
-                y_pos = self.nar_audio_prenet(y_emb)
-                y_pos = self.nar_audio_position(y_pos)
-                xy_pos = torch.concat([x, y_pos], dim=1)
-
-                xy_dec, _ = self.nar_decoder(
-                    (xy_pos, self.nar_stage_embeddings[i].weight)
-                )
-                logits = predict_layer(xy_dec[:, text_len + prefix_len :])
-
-                samples = torch.argmax(logits, dim=-1)
-                codes.append(samples)
-
-                if i < 6:
-                    y_emb[:, prefix_len:] += embedding_layer(samples)
-
-        assert len(codes) == 8
-        return torch.stack(codes, dim=-1)
-
     def visualize(
         self,
         predicts: Tuple[torch.Tensor],