zhifu gao
2024-12-25 3f8294b9d7deaa0cbdb0b2ef6f3802d46ae133a9
funasr/models/ct_transformer_streaming/encoder.py
@@ -20,7 +20,12 @@
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.ct_transformer_streaming.attention import MultiHeadedAttentionSANMwithMask
-from funasr.models.transformer.utils.subsampling import Conv2dSubsampling, Conv2dSubsampling2, Conv2dSubsampling6, Conv2dSubsampling8
+from funasr.models.transformer.utils.subsampling import (
+    Conv2dSubsampling,
+    Conv2dSubsampling2,
+    Conv2dSubsampling6,
+    Conv2dSubsampling8,
+)
class EncoderLayerSANM(torch.nn.Module):
@@ -82,7 +87,18 @@
            x = self.norm1(x)
        if self.concat_after:
-            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            x_concat = torch.cat(
+                (
+                    x,
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    ),
+                ),
+                dim=-1,
+            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
@@ -90,11 +106,21 @@
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    )
                )
        if not self.normalize_before:
            x = self.norm1(x)
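
For readers unfamiliar with the concat_after variant reformatted above, a minimal standalone sketch of the two residual paths follows; the toy attention function, sizes, and values are illustrative stand-ins, not FunASR code.

import torch

# Toy illustration (not FunASR code): with concat_after, the attention output is
# concatenated to the layer input and projected back to the model size; otherwise
# dropout(attention) is added to the residual directly.
size = 8
concat_linear = torch.nn.Linear(2 * size, size)
dropout = torch.nn.Dropout(0.1)

def toy_self_attn(x):
    # stand-in for self_attn(x, mask, mask_shfit_chunk=..., mask_att_chunk_encoder=...)
    return x

x = torch.randn(2, 5, size)
residual, stoch_layer_coeff, concat_after = x, 1.0, True

if concat_after:
    x_concat = torch.cat((x, toy_self_attn(x)), dim=-1)
    out = residual + stoch_layer_coeff * concat_linear(x_concat)
else:
    out = residual + stoch_layer_coeff * dropout(toy_self_attn(x))
print(out.shape)  # torch.Size([2, 5, 8])
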
@@ -171,8 +197,8 @@
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
-        kernel_size : int = 11,
-        sanm_shfit : int = 0,
+        kernel_size: int = 11,
+        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
    ):
        super().__init__()
@@ -277,7 +303,7 @@
        )
        self.encoders = repeat(
-            num_blocks-1,
+            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
@@ -321,16 +347,20 @@
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
        no_future_masks = masks & sub_masks
-        xs_pad *= self.output_size()**0.5
+        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
-        elif (isinstance(self.embed, Conv2dSubsampling) or isinstance(self.embed, Conv2dSubsampling2)
-              or isinstance(self.embed, Conv2dSubsampling6) or isinstance(self.embed, Conv2dSubsampling8)):
+        elif (
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling2)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
+        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
                    f"(it needs more than {limit_size} frames), return empty results",
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
@@ -344,25 +374,26 @@
        xs_pad, _ = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        for layer_idx, encoder_layer in enumerate(self.encoders):
-                if layer_idx + 1 == len(self.encoders):
-                    # This is last layer.
-                    coner_mask = torch.ones(masks.size(0),
-                                            masks.size(-1),
-                                            masks.size(-1),
-                                            device=xs_pad.device,
-                                            dtype=torch.bool)
-                    for word_index, length in enumerate(ilens):
-                        coner_mask[word_index, :, :] = vad_mask(masks.size(-1),
-                                                                vad_indexes[word_index],
-                                                                device=xs_pad.device)
-                    layer_mask = masks & coner_mask
-                else:
-                    layer_mask = no_future_masks
-                mask_tup1 = [masks, layer_mask]
-                encoder_outs = encoder_layer(xs_pad, mask_tup1)
-                xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
+            if layer_idx + 1 == len(self.encoders):
+                # This is last layer.
+                coner_mask = torch.ones(
+                    masks.size(0),
+                    masks.size(-1),
+                    masks.size(-1),
+                    device=xs_pad.device,
+                    dtype=torch.bool,
+                )
+                for word_index, length in enumerate(ilens):
+                    coner_mask[word_index, :, :] = vad_mask(
+                        masks.size(-1), vad_indexes[word_index], device=xs_pad.device
+                    )
+                layer_mask = masks & coner_mask
+            else:
+                layer_mask = no_future_masks
+            mask_tup1 = [masks, layer_mask]
+            encoder_outs = encoder_layer(xs_pad, mask_tup1)
+            xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
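
The re-indented loop above selects the attention mask per layer: every layer except the last uses the padding-and-no-future mask, while the final layer uses a VAD-derived mask built per utterance. A rough standalone illustration follows; the real vad_mask helper is not shown in this diff, so toy_vad_mask below is a hypothetical stand-in for its call shape, not its exact semantics, and ilens / vad_indexes are made-up values.

import torch

T = 6
ilens = torch.tensor([6, 4])        # assumed per-utterance lengths
vad_indexes = torch.tensor([3, 2])  # assumed per-utterance VAD boundaries

pad = torch.arange(T)[None, :] < ilens[:, None]                    # (B, T) valid-frame mask
masks = pad[:, None, :]                                            # (B, 1, T), as from make_pad_mask
sub_masks = torch.tril(torch.ones(T, T, dtype=torch.bool))[None]   # (1, T, T) no-future mask
no_future_masks = masks & sub_masks                                # used by all but the last layer

def toy_vad_mask(size, vad_index, device=None):
    # Hypothetical stand-in for funasr's vad_mask: causal attention that may also
    # look at everything up to the VAD boundary. The real helper may differ.
    m = torch.tril(torch.ones(size, size, dtype=torch.bool, device=device))
    m[:, :vad_index] = True
    return m

coner_mask = torch.ones(masks.size(0), T, T, dtype=torch.bool)
for word_index, _ in enumerate(ilens):
    coner_mask[word_index, :, :] = toy_vad_mask(T, int(vad_indexes[word_index]))
last_layer_mask = masks & coner_mask                               # only the final layer sees this
print(no_future_masks.shape, last_layer_mask.shape)                # torch.Size([2, 6, 6]) twice
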
@@ -401,6 +432,7 @@
        return x, mask
@tables.register("encoder_classes", "SANMVadEncoderExport")
class SANMVadEncoderExport(torch.nn.Module):
    def __init__(
@@ -408,68 +440,67 @@
        model,
        max_seq_len=512,
        feats_dim=560,
-        model_name='encoder',
+        model_name="encoder",
        onnx: bool = True,
    ):
        super().__init__()
        self.embed = model.embed
        self.model = model
        self._output_size = model._output_size
        from funasr.utils.torch_function import sequence_mask
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMExport
-        if hasattr(model, 'encoders0'):
+        if hasattr(model, "encoders0"):
            for i, d in enumerate(self.model.encoders0):
                if isinstance(d.self_attn, MultiHeadedAttentionSANMwithMask):
                    d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
                self.model.encoders0[i] = EncoderLayerSANMExport(d)
        for i, d in enumerate(self.model.encoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMwithMask):
                d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
            self.model.encoders[i] = EncoderLayerSANMExport(d)
    def prepare_mask(self, mask, sub_masks):
        mask_3d_btd = mask[:, :, None]
        mask_4d_bhlt = (1 - sub_masks) * -10000.0
        return mask_3d_btd, mask_4d_bhlt
-    def forward(self,
-                speech: torch.Tensor,
-                speech_lengths: torch.Tensor,
-                vad_masks: torch.Tensor,
-                sub_masks: torch.Tensor,
-                ):
-        speech = speech * self._output_size ** 0.5
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        vad_masks: torch.Tensor,
+        sub_masks: torch.Tensor,
+    ):
+        speech = speech * self._output_size**0.5
        mask = self.make_pad_mask(speech_lengths)
        vad_masks = self.prepare_mask(mask, vad_masks)
        mask = self.prepare_mask(mask, sub_masks)
        if self.embed is None:
            xs_pad = speech
        else:
            xs_pad = self.embed(speech)
        encoder_outs = self.model.encoders0(xs_pad, mask)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        # encoder_outs = self.model.encoders(xs_pad, mask)
        for layer_idx, encoder_layer in enumerate(self.model.encoders):
            if layer_idx == len(self.model.encoders) - 1:
                mask = vad_masks
            encoder_outs = encoder_layer(xs_pad, mask)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        xs_pad = self.model.after_norm(xs_pad)
        return xs_pad, speech_lengths
    def get_output_size(self):
        return self.model.encoders[0].size
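
The prepare_mask helper in the export wrapper above converts the boolean-style masks into the multiplicative and additive forms the exported attention consumes. A small self-contained sketch of that logic, reimplemented here as a free function with assumed toy shapes:

import torch

def prepare_mask(mask, sub_masks):
    # Mirrors SANMVadEncoderExport.prepare_mask: a (B, T) padding mask becomes a
    # (B, T, 1) multiplicative mask, and a (1, T, T) lower-triangular mask becomes
    # a (1, T, T) additive bias: 0 where attention is allowed, -10000 where masked.
    mask_3d_btd = mask[:, :, None]
    mask_4d_bhlt = (1 - sub_masks) * -10000.0
    return mask_3d_btd, mask_4d_bhlt

T = 6
pad_mask = torch.ones(1, T)                    # toy batch with no padding
causal = torch.tril(torch.ones(1, T, T))       # "subsequent" (no-future) mask
mult_mask, add_bias = prepare_mask(pad_mask, causal)
print(mult_mask.shape, add_bias.shape)         # torch.Size([1, 6, 1]) torch.Size([1, 6, 6])
print(add_bias[0, 0])                          # row 0: [0., -10000., -10000., ...]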