zhifu gao
2024-04-17 eaf9dda9e4d970af3d09db695e9e10c83ef94e25
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import copy
from typing import Optional, Tuple, Union
 
import torch
import torch.nn as nn
import torch.nn.functional as F
from funasr.models.transformer.utils.nets_utils import make_pad_mask
 
 
def sense_voice_encode_forward(
    self,
    x: torch.Tensor,
    ilens: torch.Tensor = None,
    **kwargs,
):
    use_padmask = self.use_padmask
    x = F.gelu(self.conv1(x))
    x = F.gelu(self.conv2(x))
    x = x.permute(0, 2, 1)
    
    n_frames = x.size(1)
    max_pos = self.positional_embedding.size(0)
    max_pos = n_frames if n_frames < max_pos else max_pos
    x = (x[:, :max_pos, :] + self.positional_embedding[None, :max_pos, :]).to(x.dtype)
    
    
    if ilens is not None:
        if self.downsample_rate == 4:
            olens = (
                1
                + (
                    ilens
                    - self.conv1.kernel_size[0]
                    + 2 * self.conv1.padding[0]
                )
                // self.conv1.stride[0]
            )
        else:
            olens = ilens
        olens = (
            1
            + (
                olens
                - self.conv2.kernel_size[0]
                + 2 * self.conv2.padding[0]
            )
            // self.conv2.stride[0]
        )
        olens = torch.clamp(olens, max=max_pos)
    else:
        olens = None
    
    if use_padmask and olens is not None:
        padding_mask = (~make_pad_mask(olens)[:, None, :]).to(x.device)
    else:
        padding_mask = None
    
    for layer, block in enumerate(self.blocks):
        x = block(x, mask=padding_mask, is_pad_mask=True)
        
 
    x = self.ln_post(x)
    
    if ilens is None:
        return x
    else:
        return x, olens