From 1cdb3cc28d4d89a576cc06e5cd8eb80da1f3a3aa Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 26 四月 2024 11:27:39 +0800
Subject: [PATCH] Dev gzf exp (#1665)
---
funasr/models/scama/encoder.py | 314 ++++++++++++++++------------------------------------
1 files changed, 97 insertions(+), 217 deletions(-)
diff --git a/funasr/models/scama/encoder.py b/funasr/models/scama/encoder.py
index 3651e61..e1fe924 100644
--- a/funasr/models/scama/encoder.py
+++ b/funasr/models/scama/encoder.py
@@ -17,7 +17,10 @@
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.embedding import (
+ SinusoidalPositionEncoder,
+ StreamSinusoidalPositionEncoder,
+)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
@@ -36,6 +39,7 @@
from funasr.models.ctc.ctc import CTC
from funasr.register import tables
+
class EncoderLayerSANM(nn.Module):
def __init__(
@@ -96,7 +100,18 @@
x = self.norm1(x)
if self.concat_after:
- x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+ x_concat = torch.cat(
+ (
+ x,
+ self.self_attn(
+ x,
+ mask,
+ mask_shfit_chunk=mask_shfit_chunk,
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
+ ),
+ ),
+ dim=-1,
+ )
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
@@ -104,11 +119,21 @@
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ self.self_attn(
+ x,
+ mask,
+ mask_shfit_chunk=mask_shfit_chunk,
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
+ )
)
else:
x = stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ self.self_attn(
+ x,
+ mask,
+ mask_shfit_chunk=mask_shfit_chunk,
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
+ )
)
if not self.normalize_before:
x = self.norm1(x)
@@ -168,34 +193,34 @@
"""
def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size: int = 11,
- sanm_shfit: int = 0,
- selfattention_layer_type: str = "sanm",
- chunk_size: Union[int, Sequence[int]] = (16,),
- stride: Union[int, Sequence[int]] = (10,),
- pad_left: Union[int, Sequence[int]] = (0,),
- encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
- decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
- tf2torch_tensor_name_prefix_torch: str = "encoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ pos_enc_class=SinusoidalPositionEncoder,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 1,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ kernel_size: int = 11,
+ sanm_shfit: int = 0,
+ selfattention_layer_type: str = "sanm",
+ chunk_size: Union[int, Sequence[int]] = (16,),
+ stride: Union[int, Sequence[int]] = (10,),
+ pad_left: Union[int, Sequence[int]] = (0,),
+ encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+ decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+ tf2torch_tensor_name_prefix_torch: str = "encoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
super().__init__()
self._output_size = output_size
@@ -334,12 +359,12 @@
return self._output_size
def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ind: int = 0,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
+ ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
@@ -355,10 +380,10 @@
if self.embed is None:
xs_pad = xs_pad
elif (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -378,21 +403,26 @@
chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
- dtype=xs_pad.dtype)
- mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
- xs_pad.size(0),
- dtype=xs_pad.dtype)
+ mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
+ chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
+ )
+ mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
+ chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
+ )
encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
- encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+ encoder_outs = self.encoders(
+ xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
+ )
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+ encoder_outs = encoder_layer(
+ xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
+ )
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
@@ -420,15 +450,16 @@
return feats
cache["feats"] = to_device(cache["feats"], device=feats.device)
overlap_feats = torch.cat((cache["feats"], feats), dim=1)
- cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]) :, :]
return overlap_feats
- def forward_chunk(self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- cache: dict = None,
- **kwargs,
- ):
+ def forward_chunk(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ cache: dict = None,
+ **kwargs,
+ ):
is_final = kwargs.get("is_final", False)
xs_pad *= self.output_size() ** 0.5
if self.embed is None:
@@ -446,12 +477,19 @@
new_cache = cache["opt"]
for layer_idx, encoder_layer in enumerate(self.encoders0):
- encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
+ encoder_outs = encoder_layer.forward_chunk(
+ xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"]
+ )
xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
- xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
+ encoder_outs = encoder_layer.forward_chunk(
+ xs_pad,
+ new_cache[layer_idx + len(self.encoders0)],
+ cache["chunk_size"],
+ cache["encoder_chunk_look_back"],
+ )
+ xs_pad, new_cache[layer_idx + len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
@@ -459,161 +497,3 @@
cache["opt"] = new_cache
return xs_pad, ilens, None
-
- def gen_tf2torch_map_dict(self):
- tensor_name_prefix_torch = self.tf2torch_tensor_name_prefix_torch
- tensor_name_prefix_tf = self.tf2torch_tensor_name_prefix_tf
- map_dict_local = {
- ## encoder
- # cicd
- "{}.encoders.layeridx.norm1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (768,256),(1,256,768)
- "{}.encoders.layeridx.self_attn.linear_q_k_v.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (768,),(768,)
- "{}.encoders.layeridx.self_attn.fsmn_block.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/depth_conv_w".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 2, 0),
- }, # (256,1,31),(1,31,256,1)
- "{}.encoders.layeridx.self_attn.linear_out.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,256),(1,256,256)
- "{}.encoders.layeridx.self_attn.linear_out.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/multi_head/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # ffn
- "{}.encoders.layeridx.norm2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.norm2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.encoders.layeridx.feed_forward.w_1.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (1024,256),(1,256,1024)
- "{}.encoders.layeridx.feed_forward.w_1.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (1024,),(1024,)
- "{}.encoders.layeridx.feed_forward.w_2.weight".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/kernel".format(tensor_name_prefix_tf),
- "squeeze": 0,
- "transpose": (1, 0),
- }, # (256,1024),(1,1024,256)
- "{}.encoders.layeridx.feed_forward.w_2.bias".format(tensor_name_prefix_torch):
- {"name": "{}/layer_layeridx/ffn/conv1d_1/bias".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- # out norm
- "{}.after_norm.weight".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/gamma".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
- "{}.after_norm.bias".format(tensor_name_prefix_torch):
- {"name": "{}/LayerNorm/beta".format(tensor_name_prefix_tf),
- "squeeze": None,
- "transpose": None,
- }, # (256,),(256,)
-
- }
-
- return map_dict_local
-
- def convert_tf2torch(self,
- var_dict_tf,
- var_dict_torch,
- ):
-
- map_dict = self.gen_tf2torch_map_dict()
-
- var_dict_torch_update = dict()
- for name in sorted(var_dict_torch.keys(), reverse=False):
- names = name.split('.')
- if names[0] == self.tf2torch_tensor_name_prefix_torch:
- if names[1] == "encoders0":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
-
- name_q = name_q.replace("encoders0", "encoders")
- layeridx_bias = 0
- layeridx += layeridx_bias
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
- elif names[1] == "encoders":
- layeridx = int(names[2])
- name_q = name.replace(".{}.".format(layeridx), ".layeridx.")
- layeridx_bias = 1
- layeridx += layeridx_bias
- if name_q in map_dict.keys():
- name_v = map_dict[name_q]["name"]
- name_tf = name_v.replace("layeridx", "{}".format(layeridx))
- data_tf = var_dict_tf[name_tf]
- if map_dict[name_q]["squeeze"] is not None:
- data_tf = np.squeeze(data_tf, axis=map_dict[name_q]["squeeze"])
- if map_dict[name_q]["transpose"] is not None:
- data_tf = np.transpose(data_tf, map_dict[name_q]["transpose"])
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- assert var_dict_torch[name].size() == data_tf.size(), "{}, {}, {} != {}".format(name, name_tf,
- var_dict_torch[
- name].size(),
- data_tf.size())
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_v,
- var_dict_tf[name_tf].shape))
-
- elif names[1] == "after_norm":
- name_tf = map_dict[name]["name"]
- data_tf = var_dict_tf[name_tf]
- data_tf = torch.from_numpy(data_tf).type(torch.float32).to("cpu")
- var_dict_torch_update[name] = data_tf
- logging.info(
- "torch tensor: {}, {}, loading from tf tensor: {}, {}".format(name, data_tf.size(), name_tf,
- var_dict_tf[name_tf].shape))
-
- return var_dict_torch_update
-
--
Gitblit v1.9.1