zhifu gao
2024-09-25 2196844d1d6e5b8732c95896bb46f0eacdd9cf9d
funasr/models/sanm/encoder.py
@@ -523,6 +523,7 @@
        feats_dim=560,
        model_name="encoder",
        onnx: bool = True,
        ctc_linear: nn.Module = None,
    ):
        super().__init__()
        self.embed = model.embed
@@ -553,6 +554,8 @@
        self.num_heads = model.encoders[0].self_attn.h
        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
        self.ctc_linear = ctc_linear
    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
@@ -566,6 +569,7 @@
    def forward(self, speech: torch.Tensor, speech_lengths: torch.Tensor, online: bool = False):
        if not online:
            speech = speech * self._output_size**0.5
        mask = self.make_pad_mask(speech_lengths)
        mask = self.prepare_mask(mask)
        if self.embed is None:
@@ -581,6 +585,10 @@
        xs_pad = self.model.after_norm(xs_pad)
        if self.ctc_linear is not None:
            xs_pad = self.ctc_linear(xs_pad)
            xs_pad = F.softmax(xs_pad, dim=2)
        return xs_pad, speech_lengths
    def get_output_size(self):