zhifu gao
2024-04-17 e8f80e96f99cb856423d030c7d055c302a6d3278
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import copy
from typing import Optional, Tuple, Union
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
 
def sense_voice_decode_forward(
    self,
    x: torch.Tensor,
    xa: torch.Tensor,
    kv_cache: Optional[dict] = None,
    **kwargs,
):
    """Forward decoder.
 
    Args:
        hs_pad: encoded memory, float32  (batch, maxlen_in, feat)
        hlens: (batch)
        ys_in_pad:
            input token ids, int64 (batch, maxlen_out)
            if input_layer == "embed"
            input tensor (batch, maxlen_out, #mels) in the other cases
        ys_in_lens: (batch)
    Returns:
        (tuple): tuple containing:
 
        x: decoded token score before softmax (batch, maxlen_out, token)
            if use_output_layer is True,
        olens: (batch, )
    """
    # import pdb;pdb.set_trace()
    use_padmask = self.use_padmask
    hlens = kwargs.get("hlens", None)
    
    ys_in_lens = kwargs.get("ys_in_lens", None)
    
    offset = next(iter(kv_cache.values())).shape[1] if kv_cache else 0
    tgt, memory = x, xa
    tgt[tgt==-1] = 0
    tgt = (
        self.token_embedding(tgt)
        + self.positional_embedding[offset : offset + tgt.size(1)]
    )
    # tgt = self.dropout(tgt)
    
    x = tgt.to(memory.dtype)
    
    if use_padmask and hlens is not None:
        memory_mask = (~make_pad_mask(hlens)[:, None, :]).to(memory.device)
    else:
        memory_mask = None
    
    for layer, block in enumerate(self.blocks):
        x = block(x, memory, mask=self.mask, memory_mask=memory_mask, is_pad_mask=False, is_pad_memory_mask=True)
 
 
    x = self.ln(x)
    x = (
        x @ torch.transpose(self.token_embedding.weight.to(x.dtype), 0, 1)
    ).float()
    
    
    return x