From 8c7b7e5feb68fda1fc4ddd627bad0f915358149e Mon Sep 17 00:00:00 2001
From: Zhanzhao (Deo) Liang <liangzhanzhao1985@gmail.com>
Date: Wed, 25 Dec 2024 16:40:29 +0800
Subject: [PATCH] fix export_meta import of sense voice (#2334)
---
funasr/models/scama/encoder.py | 173 +++++++++++++++++++++++++++++++++++----------------------
1 file changed, 105 insertions(+), 68 deletions(-)
diff --git a/funasr/models/scama/encoder.py b/funasr/models/scama/encoder.py
index 2c676b2..0c871e1 100644
--- a/funasr/models/scama/encoder.py
+++ b/funasr/models/scama/encoder.py
@@ -17,7 +17,10 @@
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.embedding import (
+ SinusoidalPositionEncoder,
+ StreamSinusoidalPositionEncoder,
+)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
@@ -36,6 +39,7 @@
from funasr.models.ctc.ctc import CTC
from funasr.register import tables
+
class EncoderLayerSANM(nn.Module):
def __init__(
@@ -65,7 +69,7 @@
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
- def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
+ def forward(self, x, mask, cache=None, mask_shift_chunk=None, mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
@@ -96,7 +100,18 @@
x = self.norm1(x)
if self.concat_after:
- x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+ x_concat = torch.cat(
+ (
+ x,
+ self.self_attn(
+ x,
+ mask,
+ mask_shift_chunk=mask_shift_chunk,
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
+ ),
+ ),
+ dim=-1,
+ )
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
@@ -104,11 +119,21 @@
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ self.self_attn(
+ x,
+ mask,
+ mask_shift_chunk=mask_shift_chunk,
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
+ )
)
else:
x = stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ self.self_attn(
+ x,
+ mask,
+ mask_shift_chunk=mask_shift_chunk,
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
+ )
)
if not self.normalize_before:
x = self.norm1(x)
@@ -120,7 +145,7 @@
if not self.normalize_before:
x = self.norm2(x)
- return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
+ return x, mask, cache, mask_shift_chunk, mask_att_chunk_encoder
def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
"""Compute encoded features.
@@ -168,34 +193,34 @@
"""
def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size: int = 11,
- sanm_shfit: int = 0,
- selfattention_layer_type: str = "sanm",
- chunk_size: Union[int, Sequence[int]] = (16,),
- stride: Union[int, Sequence[int]] = (10,),
- pad_left: Union[int, Sequence[int]] = (0,),
- encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
- decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
- tf2torch_tensor_name_prefix_torch: str = "encoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ pos_enc_class=SinusoidalPositionEncoder,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 1,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ kernel_size: int = 11,
+ sanm_shift: int = 0,
+ selfattention_layer_type: str = "sanm",
+ chunk_size: Union[int, Sequence[int]] = (16,),
+ stride: Union[int, Sequence[int]] = (10,),
+ pad_left: Union[int, Sequence[int]] = (0,),
+ encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+ decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+ tf2torch_tensor_name_prefix_torch: str = "encoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
super().__init__()
self._output_size = output_size
@@ -274,7 +299,7 @@
output_size,
attention_dropout_rate,
kernel_size,
- sanm_shfit,
+ sanm_shift,
)
encoder_selfattn_layer_args = (
@@ -283,7 +308,7 @@
output_size,
attention_dropout_rate,
kernel_size,
- sanm_shfit,
+ sanm_shift,
)
self.encoders0 = repeat(
1,
@@ -318,12 +343,12 @@
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
- shfit_fsmn = (kernel_size - 1) // 2
+ shift_fsmn = (kernel_size - 1) // 2
self.overlap_chunk_cls = overlap_chunk(
chunk_size=chunk_size,
stride=stride,
pad_left=pad_left,
- shfit_fsmn=shfit_fsmn,
+ shift_fsmn=shift_fsmn,
encoder_att_look_back_factor=encoder_att_look_back_factor,
decoder_att_look_back_factor=decoder_att_look_back_factor,
)
@@ -334,12 +359,12 @@
return self._output_size
def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ind: int = 0,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
+ ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
@@ -355,10 +380,10 @@
if self.embed is None:
xs_pad = xs_pad
elif (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -372,27 +397,32 @@
else:
xs_pad = self.embed(xs_pad)
- mask_shfit_chunk, mask_att_chunk_encoder = None, None
+ mask_shift_chunk, mask_att_chunk_encoder = None, None
if self.overlap_chunk_cls is not None:
ilens = masks.squeeze(1).sum(1)
chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
- dtype=xs_pad.dtype)
- mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
- xs_pad.size(0),
- dtype=xs_pad.dtype)
+ mask_shift_chunk = self.overlap_chunk_cls.get_mask_shift_chunk(
+ chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
+ )
+ mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
+ chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
+ )
- encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+ encoder_outs = self.encoders0(xs_pad, masks, None, mask_shift_chunk, mask_att_chunk_encoder)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
- encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+ encoder_outs = self.encoders(
+ xs_pad, masks, None, mask_shift_chunk, mask_att_chunk_encoder
+ )
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+ encoder_outs = encoder_layer(
+ xs_pad, masks, None, mask_shift_chunk, mask_att_chunk_encoder
+ )
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
@@ -420,15 +450,16 @@
return feats
cache["feats"] = to_device(cache["feats"], device=feats.device)
overlap_feats = torch.cat((cache["feats"], feats), dim=1)
- cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]) :, :]
return overlap_feats
- def forward_chunk(self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- cache: dict = None,
- **kwargs,
- ):
+ def forward_chunk(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ cache: dict = None,
+ **kwargs,
+ ):
is_final = kwargs.get("is_final", False)
xs_pad *= self.output_size() ** 0.5
if self.embed is None:
@@ -446,12 +477,19 @@
new_cache = cache["opt"]
for layer_idx, encoder_layer in enumerate(self.encoders0):
- encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
+ encoder_outs = encoder_layer.forward_chunk(
+ xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"]
+ )
xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
- xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
+ encoder_outs = encoder_layer.forward_chunk(
+ xs_pad,
+ new_cache[layer_idx + len(self.encoders0)],
+ cache["chunk_size"],
+ cache["encoder_chunk_look_back"],
+ )
+ xs_pad, new_cache[layer_idx + len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
@@ -459,4 +497,3 @@
cache["opt"] = new_cache
return xs_pad, ilens, None
-
--
Gitblit v1.9.1