zhifu gao
2024-12-25 3f8294b9d7deaa0cbdb0b2ef6f3802d46ae133a9
funasr/models/ct_transformer_streaming/encoder.py
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)

import torch
from typing import List, Optional, Tuple

from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.repeat import repeat
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.sanm.attention import MultiHeadedAttention
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.utils.subsampling import check_short_utt
from funasr.models.transformer.utils.subsampling import TooShortUttError
from funasr.models.transformer.embedding import SinusoidalPositionEncoder
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.mask import subsequent_mask, vad_mask
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.ct_transformer_streaming.attention import MultiHeadedAttentionSANMwithMask
from funasr.models.transformer.utils.subsampling import (
    Conv2dSubsampling,
    Conv2dSubsampling2,
    Conv2dSubsampling6,
    Conv2dSubsampling8,
)
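
# This module implements the streaming CT-Transformer (controllable time-delay
# Transformer) punctuation encoder used by FunASR: a stack of SANM blocks
# (self-attention with an FSMN-style memory kernel). SANMVadEncoder additionally
# constrains the last block's attention with VAD-derived masks so punctuation
# decisions cannot look across a detected segment boundary.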


class EncoderLayerSANM(torch.nn.Module):
    def __init__(
        self,
        in_size,
        size,
        self_attn,
        feed_forward,
        dropout_rate,
        normalize_before=True,
        concat_after=False,
        stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayerSANM object."""
        super().__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(in_size)
        self.norm2 = LayerNorm(size)
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.in_size = in_size
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = torch.nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
        self.dropout_rate = dropout_rate

    def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
        skip_layer = False
        # With stochastic depth, the residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time, so the surviving
        # branch keeps the expected output magnitude.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)

        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask

        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if self.concat_after:
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
                x = stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
        if not self.normalize_before:
            x = self.norm1(x)

        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)

        return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder


@tables.register("encoder_classes", "SANMVadEncoder")
class SANMVadEncoder(torch.nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SAN-M: Memory equipped self-attention for end-to-end speech recognition
    https://arxiv.org/abs/2006.01713
    """

    def __init__(
        self,
        input_size: int,
        output_size: int = 256,
        # ... (attention heads, linear units, num_blocks, dropout rates,
        # input-layer and positional-encoding options elided)
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
    ):
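        # `kernel_size` sets the width of the SANM/FSMN memory block, and
        # `sanm_shfit` (upstream spelling) shifts that block to trade
        # look-ahead for latency; selfattention_layer_type="sanm" selects the
        # mask-aware MultiHeadedAttentionSANMwithMask for the blocks.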
        super().__init__()
        # ... (embedding / positional-encoding setup, normalize_before
        # bookkeeping, and the first EncoderLayerSANM block `self.encoders0`
        # elided)
        self.encoders = repeat(
            num_blocks - 1,
            lambda lnum: EncoderLayerSANM(
                output_size,
                output_size,
                # ... (remaining EncoderLayerSANM constructor arguments elided)
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)

        self.interctc_layer_idx = interctc_layer_idx
        if len(interctc_layer_idx) > 0:
            assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
        self.interctc_use_conditioning = interctc_use_conditioning
        self.conditioning_layer = None
        self.dropout = torch.nn.Dropout(dropout_rate)

    def output_size(self) -> int:
        return self._output_size

    def forward(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        vad_indexes: torch.Tensor,
        ctc: CTC = None,
    ):
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        sub_masks = subsequent_mask(masks.size(-1), device=xs_pad.device).unsqueeze(0)
        no_future_masks = masks & sub_masks
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
            xs_pad = xs_pad
        elif (
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
                raise TooShortUttError(
                    f"has {xs_pad.size(1)} frames and is too short for subsampling " +
                    f"(it needs more than {limit_size} frames), return empty results",
                    f"has {xs_pad.size(1)} frames and is too short for subsampling "
                    + f"(it needs more than {limit_size} frames), return empty results",
                    xs_pad.size(1),
                    limit_size,
                )
            xs_pad, masks = self.embed(xs_pad, masks)
        else:
            xs_pad = self.embed(xs_pad)

        # ... (the first block runs here: encoder_outs = self.encoders0(...))
        xs_pad, _ = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
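
        # Every block except the last attends under a strictly causal mask
        # (`no_future_masks`); the last block instead attends within the
        # current VAD segment via `vad_mask`, so end-of-segment punctuation can
        # use the whole segment without crossing its boundary.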
        for layer_idx, encoder_layer in enumerate(self.encoders):
            if layer_idx + 1 == len(self.encoders):
                # Last layer: restrict attention to the current VAD segment.
                corner_mask = torch.ones(
                    masks.size(0),
                    masks.size(-1),
                    masks.size(-1),
                    device=xs_pad.device,
                    dtype=torch.bool,
                )
                for word_index, _ in enumerate(ilens):
                    corner_mask[word_index, :, :] = vad_mask(
                        masks.size(-1), vad_indexes[word_index], device=xs_pad.device
                    )
                layer_mask = masks & corner_mask
            else:
                layer_mask = no_future_masks
            mask_tup1 = [masks, layer_mask]
            encoder_outs = encoder_layer(xs_pad, mask_tup1)
            xs_pad, layer_mask = encoder_outs[0], encoder_outs[1]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)

        olens = masks.squeeze(1).sum(1)
        if len(intermediate_outs) > 0:
            return (xs_pad, intermediate_outs), olens, None
        return xs_pad, olens, None
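

# The classes below are inference-only mirrors of the modules above: dropout,
# stochastic depth, and the `concat_after` branch are stripped so that the
# graph traces cleanly for ONNX export (see the `onnx` flag on
# SANMVadEncoderExport).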
class EncoderLayerSANMExport(torch.nn.Module):
    def __init__(
        self,
        model,
    ):
        """Construct an EncoderLayer object."""
        super().__init__()
        self.self_attn = model.self_attn
        self.feed_forward = model.feed_forward
        self.norm1 = model.norm1
        self.norm2 = model.norm2
        self.in_size = model.in_size
        self.size = model.size
    def forward(self, x, mask):
        residual = x
        x = self.norm1(x)
        x = self.self_attn(x, mask)
        if self.in_size == self.size:
            x = x + residual
        residual = x
        x = self.norm2(x)
        x = self.feed_forward(x)
        x = x + residual
        return x, mask
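
    # Note: this export layer assumes the pre-norm layout
    # (normalize_before=True); the post-norm branch of EncoderLayerSANM is
    # deliberately not reproduced.
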
@tables.register("encoder_classes", "SANMVadEncoderExport")
class SANMVadEncoderExport(torch.nn.Module):
    def __init__(
        self,
        model,
        max_seq_len=512,
        feats_dim=560,
        model_name="encoder",
        onnx: bool = True,
    ):
        super().__init__()
        self.embed = model.embed
        self.model = model
        self._output_size = model._output_size
        from funasr.utils.torch_function import sequence_mask
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        from funasr.models.sanm.attention import MultiHeadedAttentionSANMExport
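
        # Rewrap both encoder stacks in place: swap each mask-aware SANM
        # attention module for its export-friendly counterpart, then wrap the
        # layer itself so forward() takes a plain (x, mask) pair.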
        if hasattr(model, "encoders0"):
            for i, d in enumerate(self.model.encoders0):
                if isinstance(d.self_attn, MultiHeadedAttentionSANMwithMask):
                    d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
                self.model.encoders0[i] = EncoderLayerSANMExport(d)
        for i, d in enumerate(self.model.encoders):
            if isinstance(d.self_attn, MultiHeadedAttentionSANMwithMask):
                d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
            self.model.encoders[i] = EncoderLayerSANMExport(d)
    def prepare_mask(self, mask, sub_masks):
        mask_3d_btd = mask[:, :, None]
        mask_4d_bhlt = (1 - sub_masks) * -10000.0
        return mask_3d_btd, mask_4d_bhlt
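
    # prepare_mask builds two views of the constraints: a (B, T, 1)
    # multiplicative padding mask for the values, and an additive attention
    # bias derived from `sub_masks` in which disallowed positions receive
    # -10000.0 before the softmax.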
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        vad_masks: torch.Tensor,
        sub_masks: torch.Tensor,
    ):
        speech = speech * self._output_size**0.5
        mask = self.make_pad_mask(speech_lengths)
        vad_masks = self.prepare_mask(mask, vad_masks)
        mask = self.prepare_mask(mask, sub_masks)
        if self.embed is None:
            xs_pad = speech
        else:
            xs_pad = self.embed(speech)
        encoder_outs = self.model.encoders0(xs_pad, mask)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        for layer_idx, encoder_layer in enumerate(self.model.encoders):
            if layer_idx == len(self.model.encoders) - 1:
                mask = vad_masks
            encoder_outs = encoder_layer(xs_pad, mask)
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        xs_pad = self.model.after_norm(xs_pad)
        return xs_pad, speech_lengths
    def get_output_size(self):
        return self.model.encoders[0].size
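

# A minimal export sketch (illustrative only: `enc` stands for a trained
# SANMVadEncoder instance, and the shapes below are made up; FunASR's own
# export pipeline normally drives this class):
#
#     enc_export = SANMVadEncoderExport(enc, max_seq_len=512, feats_dim=560)
#     speech = torch.randn(1, 100, 560)
#     speech_lengths = torch.tensor([100], dtype=torch.int32)
#     sub_masks = torch.tril(torch.ones(1, 100, 100))  # causal attention mask
#     vad_masks = torch.ones(1, 100, 100)              # last-layer VAD mask
#     torch.onnx.export(
#         enc_export,
#         (speech, speech_lengths, vad_masks, sub_masks),
#         "sanm_vad_encoder.onnx",
#         input_names=["speech", "speech_lengths", "vad_masks", "sub_masks"],
#     )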