liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
funasr/models/scama/encoder.py
@@ -17,7 +17,10 @@
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
from funasr.models.transformer.embedding import (
    SinusoidalPositionEncoder,
    StreamSinusoidalPositionEncoder,
)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
@@ -36,6 +39,7 @@
from funasr.models.ctc.ctc import CTC
from funasr.register import tables
class EncoderLayerSANM(nn.Module):
    def __init__(
@@ -96,7 +100,18 @@
            x = self.norm1(x)
        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
            x_concat = torch.cat(
                (
                    x,
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    ),
                ),
                dim=-1,
            )
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
            else:
@@ -104,11 +119,21 @@
        else:
            if self.in_size == self.size:
                x = residual + stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
            else:
                x = stoch_layer_coeff * self.dropout(
                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
                    self.self_attn(
                        x,
                        mask,
                        mask_shfit_chunk=mask_shfit_chunk,
                        mask_att_chunk_encoder=mask_att_chunk_encoder,
                    )
                )
        if not self.normalize_before:
            x = self.norm1(x)
@@ -168,34 +193,34 @@
    """
    def __init__(
            self,
            input_size: int,
            output_size: int = 256,
            attention_heads: int = 4,
            linear_units: int = 2048,
            num_blocks: int = 6,
            dropout_rate: float = 0.1,
            positional_dropout_rate: float = 0.1,
            attention_dropout_rate: float = 0.0,
            input_layer: Optional[str] = "conv2d",
            pos_enc_class=SinusoidalPositionEncoder,
            normalize_before: bool = True,
            concat_after: bool = False,
            positionwise_layer_type: str = "linear",
            positionwise_conv_kernel_size: int = 1,
            padding_idx: int = -1,
            interctc_layer_idx: List[int] = [],
            interctc_use_conditioning: bool = False,
            kernel_size: int = 11,
            sanm_shfit: int = 0,
            selfattention_layer_type: str = "sanm",
            chunk_size: Union[int, Sequence[int]] = (16,),
            stride: Union[int, Sequence[int]] = (10,),
            pad_left: Union[int, Sequence[int]] = (0,),
            encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
            decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
            tf2torch_tensor_name_prefix_torch: str = "encoder",
            tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
        self,
        input_size: int,
        output_size: int = 256,
        attention_heads: int = 4,
        linear_units: int = 2048,
        num_blocks: int = 6,
        dropout_rate: float = 0.1,
        positional_dropout_rate: float = 0.1,
        attention_dropout_rate: float = 0.0,
        input_layer: Optional[str] = "conv2d",
        pos_enc_class=SinusoidalPositionEncoder,
        normalize_before: bool = True,
        concat_after: bool = False,
        positionwise_layer_type: str = "linear",
        positionwise_conv_kernel_size: int = 1,
        padding_idx: int = -1,
        interctc_layer_idx: List[int] = [],
        interctc_use_conditioning: bool = False,
        kernel_size: int = 11,
        sanm_shfit: int = 0,
        selfattention_layer_type: str = "sanm",
        chunk_size: Union[int, Sequence[int]] = (16,),
        stride: Union[int, Sequence[int]] = (10,),
        pad_left: Union[int, Sequence[int]] = (0,),
        encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
        tf2torch_tensor_name_prefix_torch: str = "encoder",
        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
    ):
        super().__init__()
        self._output_size = output_size
@@ -334,12 +359,12 @@
        return self._output_size
    def forward(
            self,
            xs_pad: torch.Tensor,
            ilens: torch.Tensor,
            prev_states: torch.Tensor = None,
            ctc: CTC = None,
            ind: int = 0,
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        prev_states: torch.Tensor = None,
        ctc: CTC = None,
        ind: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.
@@ -355,10 +380,10 @@
        if self.embed is None:
            xs_pad = xs_pad
        elif (
                isinstance(self.embed, Conv2dSubsampling)
                or isinstance(self.embed, Conv2dSubsampling2)
                or isinstance(self.embed, Conv2dSubsampling6)
                or isinstance(self.embed, Conv2dSubsampling8)
            isinstance(self.embed, Conv2dSubsampling)
            or isinstance(self.embed, Conv2dSubsampling2)
            or isinstance(self.embed, Conv2dSubsampling6)
            or isinstance(self.embed, Conv2dSubsampling8)
        ):
            short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
            if short_status:
@@ -378,21 +403,26 @@
            chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
            xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
            masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
                                                                           dtype=xs_pad.dtype)
            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
                                                                                       xs_pad.size(0),
                                                                                       dtype=xs_pad.dtype)
            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
            )
            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
            )
        encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
        xs_pad, masks = encoder_outs[0], encoder_outs[1]
        intermediate_outs = []
        if len(self.interctc_layer_idx) == 0:
            encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
            encoder_outs = self.encoders(
                xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
            )
            xs_pad, masks = encoder_outs[0], encoder_outs[1]
        else:
            for layer_idx, encoder_layer in enumerate(self.encoders):
                encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
                encoder_outs = encoder_layer(
                    xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
                )
                xs_pad, masks = encoder_outs[0], encoder_outs[1]
                if layer_idx + 1 in self.interctc_layer_idx:
                    encoder_out = xs_pad
@@ -420,15 +450,16 @@
            return feats
        cache["feats"] = to_device(cache["feats"], device=feats.device)
        overlap_feats = torch.cat((cache["feats"], feats), dim=1)
        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]) :, :]
        return overlap_feats
    def forward_chunk(self,
                      xs_pad: torch.Tensor,
                      ilens: torch.Tensor,
                      cache: dict = None,
                      **kwargs,
                      ):
    def forward_chunk(
        self,
        xs_pad: torch.Tensor,
        ilens: torch.Tensor,
        cache: dict = None,
        **kwargs,
    ):
        is_final = kwargs.get("is_final", False)
        xs_pad *= self.output_size() ** 0.5
        if self.embed is None:
@@ -446,12 +477,19 @@
            new_cache = cache["opt"]
        for layer_idx, encoder_layer in enumerate(self.encoders0):
            encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
            encoder_outs = encoder_layer.forward_chunk(
                xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"]
            )
            xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
        for layer_idx, encoder_layer in enumerate(self.encoders):
            encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
            xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
            encoder_outs = encoder_layer.forward_chunk(
                xs_pad,
                new_cache[layer_idx + len(self.encoders0)],
                cache["chunk_size"],
                cache["encoder_chunk_look_back"],
            )
            xs_pad, new_cache[layer_idx + len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
@@ -459,4 +497,3 @@
            cache["opt"] = new_cache
        return xs_pad, ilens, None