From 3f8294b9d7deaa0cbdb0b2ef6f3802d46ae133a9 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Wed, 25 Dec 2024 17:16:11 +0800
Subject: [PATCH] Revert "shfit to shift (#2266)" (#2336)
---
funasr/models/scama/encoder.py | 30 +++++++++++++++---------------
1 file changed, 15 insertions(+), 15 deletions(-)
diff --git a/funasr/models/scama/encoder.py b/funasr/models/scama/encoder.py
index 0c871e1..e1fe924 100644
--- a/funasr/models/scama/encoder.py
+++ b/funasr/models/scama/encoder.py
@@ -69,7 +69,7 @@
self.stochastic_depth_rate = stochastic_depth_rate
self.dropout_rate = dropout_rate
- def forward(self, x, mask, cache=None, mask_shift_chunk=None, mask_att_chunk_encoder=None):
+ def forward(self, x, mask, cache=None, mask_shfit_chunk=None, mask_att_chunk_encoder=None):
"""Compute encoded features.
Args:
@@ -106,7 +106,7 @@
self.self_attn(
x,
mask,
- mask_shift_chunk=mask_shift_chunk,
+ mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
),
),
@@ -122,7 +122,7 @@
self.self_attn(
x,
mask,
- mask_shift_chunk=mask_shift_chunk,
+ mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
)
)
@@ -131,7 +131,7 @@
self.self_attn(
x,
mask,
- mask_shift_chunk=mask_shift_chunk,
+ mask_shfit_chunk=mask_shfit_chunk,
mask_att_chunk_encoder=mask_att_chunk_encoder,
)
)
@@ -145,7 +145,7 @@
if not self.normalize_before:
x = self.norm2(x)
- return x, mask, cache, mask_shift_chunk, mask_att_chunk_encoder
+ return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
def forward_chunk(self, x, cache=None, chunk_size=None, look_back=0):
"""Compute encoded features.
@@ -212,7 +212,7 @@
interctc_layer_idx: List[int] = [],
interctc_use_conditioning: bool = False,
kernel_size: int = 11,
- sanm_shift: int = 0,
+ sanm_shfit: int = 0,
selfattention_layer_type: str = "sanm",
chunk_size: Union[int, Sequence[int]] = (16,),
stride: Union[int, Sequence[int]] = (10,),
@@ -299,7 +299,7 @@
output_size,
attention_dropout_rate,
kernel_size,
- sanm_shift,
+ sanm_shfit,
)
encoder_selfattn_layer_args = (
@@ -308,7 +308,7 @@
output_size,
attention_dropout_rate,
kernel_size,
- sanm_shift,
+ sanm_shfit,
)
self.encoders0 = repeat(
1,
@@ -343,12 +343,12 @@
assert 0 < min(interctc_layer_idx) and max(interctc_layer_idx) < num_blocks
self.interctc_use_conditioning = interctc_use_conditioning
self.conditioning_layer = None
- shift_fsmn = (kernel_size - 1) // 2
+ shfit_fsmn = (kernel_size - 1) // 2
self.overlap_chunk_cls = overlap_chunk(
chunk_size=chunk_size,
stride=stride,
pad_left=pad_left,
- shift_fsmn=shift_fsmn,
+ shfit_fsmn=shfit_fsmn,
encoder_att_look_back_factor=encoder_att_look_back_factor,
decoder_att_look_back_factor=decoder_att_look_back_factor,
)
@@ -397,31 +397,31 @@
else:
xs_pad = self.embed(xs_pad)
- mask_shift_chunk, mask_att_chunk_encoder = None, None
+ mask_shfit_chunk, mask_att_chunk_encoder = None, None
if self.overlap_chunk_cls is not None:
ilens = masks.squeeze(1).sum(1)
chunk_outs = self.overlap_chunk_cls.gen_chunk_mask(ilens, ind)
xs_pad, ilens = self.overlap_chunk_cls.split_chunk(xs_pad, ilens, chunk_outs=chunk_outs)
masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- mask_shift_chunk = self.overlap_chunk_cls.get_mask_shift_chunk(
+ mask_shfit_chunk = self.overlap_chunk_cls.get_mask_shfit_chunk(
chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
)
mask_att_chunk_encoder = self.overlap_chunk_cls.get_mask_att_chunk_encoder(
chunk_outs, xs_pad.device, xs_pad.size(0), dtype=xs_pad.dtype
)
- encoder_outs = self.encoders0(xs_pad, masks, None, mask_shift_chunk, mask_att_chunk_encoder)
+ encoder_outs = self.encoders0(xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
intermediate_outs = []
if len(self.interctc_layer_idx) == 0:
encoder_outs = self.encoders(
- xs_pad, masks, None, mask_shift_chunk, mask_att_chunk_encoder
+ xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
else:
for layer_idx, encoder_layer in enumerate(self.encoders):
encoder_outs = encoder_layer(
- xs_pad, masks, None, mask_shift_chunk, mask_att_chunk_encoder
+ xs_pad, masks, None, mask_shfit_chunk, mask_att_chunk_encoder
)
xs_pad, masks = encoder_outs[0], encoder_outs[1]
if layer_idx + 1 in self.interctc_layer_idx:
--
Gitblit v1.9.1