From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' on line 514 to avoid a naming conflict with line 365
---
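Note: the hunks below are the accompanying black-style reformatting of encoder.py;
the 'res' rename itself lies outside the context shown here. As a minimal,
hypothetical sketch of the kind of shadowing this commit avoids (names and
structure are illustrative only, not the actual encoder code):

    def encode(xs):
        res = [x * 2 for x in xs]          # first result, still needed at the end
        # ... many lines later, reusing the name 'res' here would clobber it ...
        chunk_res = [x + 1 for x in res]   # renamed from 'res' to avoid shadowing
        return res, chunk_res

    print(encode([1, 2, 3]))               # ([2, 4, 6], [3, 5, 7])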
funasr/models/scama/encoder.py | 157 ++++++++++++++++++++++++++++++++--------------------
 1 file changed, 97 insertions(+), 60 deletions(-)
diff --git a/funasr/models/scama/encoder.py b/funasr/models/scama/encoder.py
index 2c676b2..e1fe924 100644
--- a/funasr/models/scama/encoder.py
+++ b/funasr/models/scama/encoder.py
@@ -17,7 +17,10 @@
from funasr.train_utils.device_funcs import to_device
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.embedding import (
+ SinusoidalPositionEncoder,
+ StreamSinusoidalPositionEncoder,
+)
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
@@ -36,6 +39,7 @@
from funasr.models.ctc.ctc import CTC
from funasr.register import tables
+
class EncoderLayerSANM(nn.Module):
def __init__(
@@ -96,7 +100,18 @@
x = self.norm1(x)
if self.concat_after:
- x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+ x_concat = torch.cat(
+ (
+ x,
+ self.self_attn(
+ x,
+ mask,
+ mask_shfit_chunk=mask_shfit_chunk,
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
+ ),
+ ),
+ dim=-1,
+ )
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
else:
@@ -104,11 +119,21 @@
else:
if self.in_size == self.size:
x = residual + stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ self.self_attn(
+ x,
+ mask,
+ mask_shfit_chunk=mask_shfit_chunk,
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
+ )
)
else:
x = stoch_layer_coeff * self.dropout(
- self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+ self.self_attn(
+ x,
+ mask,
+ mask_shfit_chunk=mask_shfit_chunk,
+ mask_att_chunk_encoder=mask_att_chunk_encoder,
+ )
)
if not self.normalize_before:
x = self.norm1(x)
@@ -168,34 +193,34 @@
"""
def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- input_layer: Optional[str] = "conv2d",
- pos_enc_class=SinusoidalPositionEncoder,
- normalize_before: bool = True,
- concat_after: bool = False,
- positionwise_layer_type: str = "linear",
- positionwise_conv_kernel_size: int = 1,
- padding_idx: int = -1,
- interctc_layer_idx: List[int] = [],
- interctc_use_conditioning: bool = False,
- kernel_size: int = 11,
- sanm_shfit: int = 0,
- selfattention_layer_type: str = "sanm",
- chunk_size: Union[int, Sequence[int]] = (16,),
- stride: Union[int, Sequence[int]] = (10,),
- pad_left: Union[int, Sequence[int]] = (0,),
- encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
- decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
- tf2torch_tensor_name_prefix_torch: str = "encoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ input_layer: Optional[str] = "conv2d",
+ pos_enc_class=SinusoidalPositionEncoder,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ positionwise_layer_type: str = "linear",
+ positionwise_conv_kernel_size: int = 1,
+ padding_idx: int = -1,
+ interctc_layer_idx: List[int] = [],
+ interctc_use_conditioning: bool = False,
+ kernel_size: int = 11,
+ sanm_shfit: int = 0,
+ selfattention_layer_type: str = "sanm",
+ chunk_size: Union[int, Sequence[int]] = (16,),
+ stride: Union[int, Sequence[int]] = (10,),
+ pad_left: Union[int, Sequence[int]] = (0,),
+ encoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+ decoder_att_look_back_factor: Union[int, Sequence[int]] = (1,),
+ tf2torch_tensor_name_prefix_torch: str = "encoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/encoder",
):
super().__init__()
self._output_size = output_size
@@ -334,12 +359,12 @@
return self._output_size
def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- prev_states: torch.Tensor = None,
- ctc: CTC = None,
- ind: int = 0,
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ prev_states: torch.Tensor = None,
+ ctc: CTC = None,
+ ind: int = 0,
) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
"""Embed positions in tensor.
@@ -355,10 +380,10 @@
if self.embed is None:
xs_pad = xs_pad
elif (
- isinstance(self.embed, Conv2dSubsampling)
- or isinstance(self.embed, Conv2dSubsampling2)
- or isinstance(self.embed, Conv2dSubsampling6)
- or isinstance(self.embed, Conv2dSubsampling8)
+ isinstance(self.embed, Conv2dSubsampling)
+ or isinstance(self.embed, Conv2dSubsampling2)
+ or isinstance(self.embed, Conv2dSubsampling6)
+ or isinstance(self.embed, Conv2dSubsampling8)
):
short_status, limit_size = check_short_utt(self.embed, xs_pad.size(1))
if short_status:
@@ -378,21 +403,26 @@
chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(chunk_outs, xs_pad.device, xs_pad.size(0),
- dtype=xs_pad.dtype)
- mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(chunk_outs, xs_pad.device,
- xs_pad.size(0),
- dtype=xs_pad.dtype)
+ mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
+ chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
+ )
+ mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
+ chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
+ )
encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
- encoder_outs = self.encoders(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+ encoder_outs = self.encoders(
+ xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
+ )
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
+ encoder_outs = encoder_layer(
+ xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
+ )
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
encoder_out = xs_pad
@@ -420,15 +450,16 @@
return feats
cache["feats"] = to_device(cache["feats"], device=feats.device)
overlap_feats = torch.cat((cache["feats"], feats), dim=1)
- cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+ cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]) :, :]
return overlap_feats
- def forward_chunk(self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- cache: dict = None,
- **kwargs,
- ):
+ def forward_chunk(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ cache: dict = None,
+ **kwargs,
+ ):
is_final = kwargs.get("is_final", False)
xs_pad *= self.output_size() ** 0.5
if self.embed is None:
@@ -446,12 +477,19 @@
new_cache = cache["opt"]
for layer_idx, encoder_layer in enumerate(self.encoders0):
- encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"])
+ encoder_outs = encoder_layer.forward_chunk(
+ xs_pad, new_cache[layer_idx], cache["chunk_size"], cache["encoder_chunk_look_back"]
+ )
xs_pad, new_cache[0] = encoder_outs[0], encoder_outs[1]
for layer_idx, encoder_layer in enumerate(self.encoders):
- encoder_outs = encoder_layer.forward_chunk(xs_pad, new_cache[layer_idx+len(self.encoders0)], cache["chunk_size"], cache["encoder_chunk_look_back"])
- xs_pad, new_cache[layer_idx+len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
+ encoder_outs = encoder_layer.forward_chunk(
+ xs_pad,
+ new_cache[layer_idx + len(self.encoders0)],
+ cache["chunk_size"],
+ cache["encoder_chunk_look_back"],
+ )
+ xs_pad, new_cache[layer_idx + len(self.encoders0)] = encoder_outs[0], encoder_outs[1]
if self.normalize_before:
xs_pad = self.after_norm(xs_pad)
@@ -459,4 +497,3 @@
cache["opt"] = new_cache
return xs_pad, ilens, None
-
--
Gitblit v1.9.1