From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' on line 514 to avoid a naming conflict with line 365
---
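Note: the hunks below are the accompanying black-style reformatting of encoder.py;
the 'res' rename itself lies outside the context shown here. As a minimal,
hypothetical sketch of the kind of shadowing this commit avoids (names and
structure are illustrative only, not the actual encoder code):

    def encode(xs):
        res = [x * 2 for x in xs]          # first result, still needed at the end
        # ... many lines later, reusing the name 'res' here would clobber it ...
        chunk_res = [x + 1 for x in res]   # renamed from 'res' to avoid shadowing
        return res, chunk_res

    print(encode([1, 2, 3]))               # ([2, 4, 6], [3, 5, 7])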
funasr/models/scama/encoder.py | 157 ++++++++++++++++++++++++++++++++--------------------
 1 file changed, 97 insertions(+), 60 deletions(-)
diff --git a/funasr/models/scama/encoder.py b/funasr/models/scama/encoder.py
index 2c676b2..e1fe924 100644
--- a/funasr/models/scama/encoder.py
+++ b/funasr/models/scama/encoder.py
@@ -17,7 +17,10 @@
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.embedding import (
+ SinusoidalPositionEncoder,
+ StreamSinusoidalPositionEncoder,
+)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
@@ -36,6 +39,7 @@
from funasr.models.ctc.ctc import CTC
from funasr.register import tables
+
class EncoderLayerSANM(nn.Module):
def __init__(
@@ -96,7 +100,18 @@
x = self.norm1(x)
if self.concat_after:
- x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+ x_concat = torch.cat(
+ (
+ x,
+ self.self_attn(
+ x,
+ mask,
+ mask_shfit_chunk=mask_shfit_chunk,
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
+ ),
+ ),
+ dim=-1,
+ )
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
@@ -104,11 +119,21 @@
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ self.self_attn(
+ x,
+ mask,
+ mask_shfit_chunk=mask_shfit_chunk,
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
+ )
)
else:
x = stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ self.self_attn(
+ x,
+ mask,
+ mask_shfit_chunk=mask_shfit_chunk,
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
+ )
)
if not self.normalize_before:
x = self.norm1(x)
@@ -168,34 +193,34 @@
"""
def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size: int = 11,
- sanm_shfit: int = 0,
- selfattention_layer_type: str = "sanm",
- chunk_size: Union[int, Sequence[int]] = (16,),
- stride: Union[int, Sequence[int]] = (10,),
- pad_left: Union[int, Sequence[int]] = (0,),
- encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
- decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
- tf2torch_tensor_name_prefix_torch: str = "encoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ pos_enc_class=SinusoidalPositionEncoder,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 1,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ kernel_size: int = 11,
+ sanm_shfit: int = 0,
+ selfattention_layer_type: str = "sanm",
+ chunk_size: Union[int, Sequence[int]] = (16,),
+ stride: Union[int, Sequence[int]] = (10,),
+ pad_left: Union[int, Sequence[int]] = (0,),
+ encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+ decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+ tf2torch_tensor_name_prefix_torch: str = "encoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
super().__init__()
self._output_size = output_size
@@ -334,12 +359,12 @@
return self._output_size
def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ind: int = 0,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
+ ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
@@ -355,10 +380,10 @@
if self.embed is None:
xs_pad = xs_pad
elif (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -378,21 +403,26 @@
chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
- dtype=xs_pad.dtype)
- mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
- xs_pad.size(0),
- dtype=xs_pad.dtype)
+ mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
+ chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
+ )
+ mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
+ chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
+ )
encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
- encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+ encoder_outs = self.encoders(
+ xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
+ )
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+ encoder_outs = encoder_layer(
+ xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
+ )
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
@@ -420,15 +450,16 @@
return feats
cache["feats"] = to_device(cache["feats"], device=feats.device)
overlap_feats = torch.cat((cache["feats"], feats), dim=1)
- cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]) :, :]
return overlap_feats
- def forward_chunk(self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- cache: dict = None,
- **kwargs,
- ):
+ def forward_chunk(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ cache: dict = None,
+ **kwargs,
+ ):
is_final = kwargs.get("is_final", False)
xs_pad *= self.output_size() ** 0.5
if self.embed is None:
@@ -446,12 +477,19 @@
new_cache = cache["opt"]
for layer_idx, encoder_layer in enumerate(self.encoders0):
- encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
+ encoder_outs = encoder_layer.forward_chunk(
+ xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"]
+ )
xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
- xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
+ encoder_outs = encoder_layer.forward_chunk(
+ xs_pad,
+ new_cache[layer_idx + len(self.encoders0)],
+ cache["chunk_size"],
+ cache["encoder_chunk_look_back"],
+ )
+ xs_pad, new_cache[layer_idx + len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
@@ -459,4 +497,3 @@
cache["opt"] = new_cache
return xs_pad, ilens, None
-
--
Gitblit v1.9.1