liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
import torch
 
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.nets_utils import make_pad_mask
 
 
class SimpleAvg(AbsEncoder):
    """Parameter-free encoder that pools a padded sequence by masked averaging.

    Dimension 1 of the input is reduced to a single vector per batch item by
    averaging only the positions that are not marked as padding.
    """

    def __init__(self, feat_dim: int):
        """Store the feature dimension reported by ``output_size``.

        Args:
            feat_dim: Size of the feature axis; it is only recorded here and
                is not used to validate inputs in ``forward``.
        """
        super().__init__()
        self.feat_dim = feat_dim

    def forward(self, x: torch.Tensor, ilens) -> torch.Tensor:
        """Average ``x`` over dimension 1, ignoring padded positions.

        Args:
            x: Padded input; assumed to be (batch, time, feat_dim) -- the
                code only requires that it broadcasts against a
                (batch, time, 1) mask.
            ilens: Per-item valid lengths, passed straight to
                ``make_pad_mask`` (presumably a 1-D tensor or list of ints --
                confirm against that helper).

        Returns:
            ``x`` with dimension 1 reduced to the mean of its valid frames.
            Items whose length is 0 yield an all-zero vector rather than NaN.
        """
        # make_pad_mask presumably marks padded positions as True, so the
        # inverse selects valid frames; move it to x's device before use.
        valid = ~make_pad_mask(ilens, maxlen=x.shape[1]).to(x.device)
        # Overwrite padded frames instead of multiplying by the mask, so that
        # NaN/inf garbage in the padding cannot leak into the sum (NaN * 0 is
        # still NaN).
        summed = x.masked_fill(~valid[:, :, None], 0).sum(1)
        # Clamp the frame count so an empty sequence divides 0 by 1 instead of
        # 0 by 0, which would produce NaN.
        n_valid = valid.sum(-1).clamp(min=1)
        return summed / n_valid[:, None]

    def output_size(self) -> int:
        """Return the feature dimension given at construction time."""
        return self.feat_dim