From 1cdb3cc28d4d89a576cc06e5cd8eb80da1f3a3aa Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 26 Apr 2024 11:27:39 +0800
Subject: [PATCH] Dev gzf exp (#1665)

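Code cleanup for funasr/models/scama/encoder.py: wrap long signatures and
call sites into black-style multi-line form, and remove the
gen_tf2torch_map_dict and convert_tf2torch TensorFlow-to-PyTorch
weight-conversion helpers.
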
---
 funasr/models/scama/encoder.py |  314 ++++++++++++++++------------------------------------
 1 file changed, 97 insertions(+), 217 deletions(-)

diff --git a/funasr/models/scama/encoder.py b/funasr/models/scama/encoder.py
index 3651e61..e1fe924 100644
--- a/funasr/models/scama/encoder.py
+++ b/funasr/models/scama/encoder.py
@@ -17,7 +17,10 @@
 from funasr.train_utils.device_funcs import to_device
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.embedding import (
+    SinusoidalPositionEncoder,
+    StreamSinusoidalPositionEncoder,
+)
 from funasr.models.transformer.layer_norm import LayerNorm
 from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
 from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
@@ -36,6 +39,7 @@
 from funasr.models.ctc.ctc import CTC
 
 from funasr.register import tables
+
 
 class EncoderLayerSANM(nn.Module):
     def __init__(
@@ -96,7 +100,18 @@
             x = self.norm1(x)
 
         if self.concat_after:
-            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            x_concat = torch.cat(
+                (
+                    x,
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    ),
+                ),
+                dim=-1,
+            )
             if self.in_size == self.size:
                 x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
             else:
@@ -104,11 +119,21 @@
         else:
             if self.in_size == self.size:
                 x = residual + stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    )
                 )
             else:
                 x = stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    )
                 )
         if not self.normalize_before:
             x = self.norm1(x)
@@ -168,34 +193,34 @@
     """
 
     def __init__(
-            self,
-            input_size: int,
-            output_size: int = 256,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            attention_dropout_rate: float = 0.0,
-            input_layer: Optional[str] = "conv2d",
-            pos_enc_class=SinusoidalPositionEncoder,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-            positionwise_layer_type: str = "linear",
-            positionwise_conv_kernel_size: int = 1,
-            padding_idx: int = -1,
-            interctc_layer_idx: List[int] = [],
-            interctc_use_conditioning: bool = False,
-            kernel_size: int = 11,
-            sanm_shfit: int = 0,
-            selfattention_layer_type: str = "sanm",
-            chunk_size: Union[int, Sequence[int]] = (16,),
-            stride: Union[int, Sequence[int]] = (10,),
-            pad_left: Union[int, Sequence[int]] = (0,),
-            encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
-            decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
-            tf2torch_tensor_name_prefix_torch: str = "encoder",
-            tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
+        self,
+        input_size: int,
+        output_size: int = 256,
+        attention_heads: int = 4,
+        linear_units: int = 2048,
+        num_blocks: int = 6,
+        dropout_rate: float = 0.1,
+        positional_dropout_rate: float = 0.1,
+        attention_dropout_rate: float = 0.0,
+        input_layer: Optional[str] = "conv2d",
+        pos_enc_class=SinusoidalPositionEncoder,
+        normalize_before: bool = True,
+        concat_after: bool = False,
+        positionwise_layer_type: str = "linear",
+        positionwise_conv_kernel_size: int = 1,
+        padding_idx: int = -1,
+        interctc_layer_idx: List[int] = [],
+        interctc_use_conditioning: bool = False,
+        kernel_size: int = 11,
+        sanm_shfit: int = 0,
+        selfattention_layer_type: str = "sanm",
+        chunk_size: Union[int, Sequence[int]] = (16,),
+        stride: Union[int, Sequence[int]] = (10,),
+        pad_left: Union[int, Sequence[int]] = (0,),
+        encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+        decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+        tf2torch_tensor_name_prefix_torch: str = "encoder",
+        tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
     ):
         super().__init__()
         self._output_size = output_size
@@ -334,12 +359,12 @@
         return self._output_size
 
     def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-            prev_states: torch.Tensor = None,
-            ctc: CTC = None,
-            ind: int = 0,
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        prev_states: torch.Tensor = None,
+        ctc: CTC = None,
+        ind: int = 0,
     ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
         """Embed positions in tensor.
 
@@ -355,10 +380,10 @@
         if self.embed is None:
             xs_pad = xs_pad
         elif (
-                isinstance(self.embed, Conv2dSubsampling)
-                or isinstance(self.embed, Conv2dSubsampling2)
-                or isinstance(self.embed, Conv2dSubsampling6)
-                or isinstance(self.embed, Conv2dSubsampling8)
+            isinstance(self.embed, Conv2dSubsampling)
+            or isinstance(self.embed, Conv2dSubsampling2)
+            or isinstance(self.embed, Conv2dSubsampling6)
+            or isinstance(self.embed, Conv2dSubsampling8)
         ):
             short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
             if short_status:
@@ -378,21 +403,26 @@
             chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
             xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
             masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
-                                                                           dtype=xs_pad.dtype)
-            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
-                                                                                       xs_pad.size(0),
-                                                                                       dtype=xs_pad.dtype)
+            mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
+                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
+            )
+            mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
+                chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
+            )
 
         encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
         xs_pad, masks = encoder_outs[0], encoder_outs[1]
         intermediate_outs = []
         if len(self.interctc_layer_idx) == 0:
-            encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+            encoder_outs = self.encoders(
+                xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
+            )
             xs_pad, masks = encoder_outs[0], encoder_outs[1]
         else:
             for layer_idx, encoder_layer in enumerate(self.encoders):
-                encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+                encoder_outs = encoder_layer(
+                    xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
+                )
                 xs_pad, masks = encoder_outs[0], encoder_outs[1]
                 if layer_idx + 1 in self.interctc_layer_idx:
                     encoder_out = xs_pad
@@ -420,15 +450,16 @@
             return feats
         cache["feats"] = to_device(cache["feats"], device=feats.device)
         overlap_feats = torch.cat((cache["feats"], feats), dim=1)
-        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]) :, :]
         return overlap_feats
 
-    def forward_chunk(self,
-                      xs_pad: torch.Tensor,
-                      ilens: torch.Tensor,
-                      cache: dict = None,
-                      **kwargs,
-                      ):
+    def forward_chunk(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        cache: dict = None,
+        **kwargs,
+    ):
         is_final = kwargs.get("is_final", False)
         xs_pad *= self.output_size() ** 0.5
         if self.embed is None:
@@ -446,12 +477,19 @@
             new_cache = cache["opt"]
 
         for layer_idx, encoder_layer in enumerate(self.encoders0):
-            encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
+            encoder_outs = encoder_layer.forward_chunk(
+                xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"]
+            )
             xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
 
         for layer_idx, encoder_layer in enumerate(self.encoders):
-            encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
-            xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
+            encoder_outs = encoder_layer.forward_chunk(
+                xs_pad,
+                new_cache[layer_idx + len(self.encoders0)],
+                cache["chunk_size"],
+                cache["encoder_chunk_look_back"],
+            )
+            xs_pad, new_cache[layer_idx + len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
 
         if self.normalize_before:
             xs_pad = self.after_norm(xs_pad)
@@ -459,161 +497,3 @@
             cache["opt"] = new_cache
 
         return xs_pad, ilens, None
-
-    def gen_tf2torch_map_dict(self):
-        tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
-        tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
-        map_dict_local = {
-            ## encoder
-            # cicd
-            "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (768,256),(1,256,768)
-            "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (768,),(768,)
-            "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 2, 0),
-                 },  # (256,1,31),(1,31,256,1)
-            "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,256),(1,256,256)
-            "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # ffn
-            "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (1024,256),(1,256,1024)
-            "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (1024,),(1024,)
-            "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
-                 "squeeze": 0,
-                 "transpose": (1, 0),
-                 },  # (256,1024),(1,1024,256)
-            "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            # out norm
-            "{}.after_norm.weight".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-            "{}.after_norm.bias".format(tensor_name_prefix_torch):
-                {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
-                 "squeeze": None,
-                 "transpose": None,
-                 },  # (256,),(256,)
-        
-        }
-    
-        return map_dict_local
-
-    def convert_tf2torch(self,
-                         var_dict_tf,
-                         var_dict_torch,
-                         ):
-    
-        map_dict = self.gen_tf2torch_map_dict()
-    
-        var_dict_torch_update = dict()
-        for name in sorted(var_dict_torch.keys(), reverse=False):
-            names = name.split('.')
-            if names[0] == self.tf2torch_tensor_name_prefix_torch:
-                if names[1] == "encoders0":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                
-                    name_q = name_q.replace("encoders0", "encoders")
-                    layeridx_bias = 0
-                    layeridx += layeridx_bias
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-                elif names[1] == "encoders":
-                    layeridx = int(names[2])
-                    name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-                    layeridx_bias = 1
-                    layeridx += layeridx_bias
-                    if name_q in map_dict.keys():
-                        name_v = map_dict[name_q]["name"]
-                        name_tf = name_v.replace("layeridx", "{}".format(layeridx))
-                        data_tf = var_dict_tf[name_tf]
-                        if map_dict[name_q]["squeeze"] is not None:
-                            data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
-                        if map_dict[name_q]["transpose"] is not None:
-                            data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
-                        data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                        assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
-                                                                                                        var_dict_torch[
-                                                                                                            name].size(),
-                                                                                                        data_tf.size())
-                        var_dict_torch_update[name] = data_tf
-                        logging.info(
-                            "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
-                                                                                          var_dict_tf[name_tf].shape))
-            
-                elif names[1] == "after_norm":
-                    name_tf = map_dict[name]["name"]
-                    data_tf = var_dict_tf[name_tf]
-                    data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
-                    var_dict_torch_update[name] = data_tf
-                    logging.info(
-                        "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
-                                                                                      var_dict_tf[name_tf].shape))
-    
-        return var_dict_torch_update
-

--
Gitblit v1.9.1