From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] Reformat funasr/models/scama/decoder.py (line wrapping only, no functional changes)
---
funasr/models/scama/decoder.py | 140 +++++++++++++++++++++++++++-------------------
1 file changed, 81 insertions(+), 59 deletions(-)
diff --git a/funasr/models/scama/decoder.py b/funasr/models/scama/decoder.py
index 8257f59..31b2357 100644
--- a/funasr/models/scama/decoder.py
+++ b/funasr/models/scama/decoder.py
@@ -13,13 +13,17 @@
from funasr.models.scama import utils as myutils
from funasr.models.transformer.decoder import BaseTransformerDecoder
-from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import (
+ MultiHeadedAttentionSANMDecoder,
+ MultiHeadedAttentionCrossAtt,
+)
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables
+
class DecoderLayerSANM(nn.Module):
"""Single decoder layer module.
@@ -151,10 +155,11 @@
x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
-
return x, tgt_mask, memory, memory_mask, cache
- def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+ def forward_chunk(
+ self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
+ ):
"""Compute decoded features.
Args:
@@ -194,6 +199,7 @@
return x, memory, fsmn_cache, opt_cache
+
@tables.register("decoder_classes", "FsmnDecoderSCAMAOpt")
class FsmnDecoderSCAMAOpt(BaseTransformerDecoder):
"""
@@ -201,31 +207,31 @@
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
https://arxiv.org/abs/2006.01712
"""
-
+
def __init__(
- self,
- vocab_size: int,
- encoder_output_size: int,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- self_attention_dropout_rate: float = 0.0,
- src_attention_dropout_rate: float = 0.0,
- input_layer: str = "embed",
- use_output_layer: bool = True,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- att_layer_num: int = 6,
- kernel_size: int = 21,
- sanm_shfit: int = None,
- concat_embeds: bool = False,
- attention_dim: int = None,
- tf2torch_tensor_name_prefix_torch: str = "decoder",
- tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
- embed_tensor_name_prefix_tf: str = None,
+ self,
+ vocab_size: int,
+ encoder_output_size: int,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ self_attention_dropout_rate: float = 0.0,
+ src_attention_dropout_rate: float = 0.0,
+ input_layer: str = "embed",
+ use_output_layer: bool = True,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ att_layer_num: int = 6,
+ kernel_size: int = 21,
+ sanm_shfit: int = None,
+ concat_embeds: bool = False,
+ attention_dim: int = None,
+ tf2torch_tensor_name_prefix_torch: str = "decoder",
+ tf2torch_tensor_name_prefix_tf: str = "seq2seq/decoder",
+ embed_tensor_name_prefix_tf: str = None,
):
super().__init__(
vocab_size=vocab_size,
@@ -275,7 +281,10 @@
attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
),
MultiHeadedAttentionCrossAtt(
- attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
+ attention_heads,
+ attention_dim,
+ src_attention_dropout_rate,
+ encoder_output_size=encoder_output_size,
),
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
dropout_rate,
@@ -291,7 +300,10 @@
lambda lnum: DecoderLayerSANM(
attention_dim,
MultiHeadedAttentionSANMDecoder(
- attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+ attention_dim,
+ self_attention_dropout_rate,
+ kernel_size,
+ sanm_shfit=sanm_shfit,
),
None,
PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
@@ -320,8 +332,12 @@
attention_dim + encoder_output_size,
None,
None,
- PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
- adim=attention_dim),
+ PositionwiseFeedForwardDecoderSANM(
+ attention_dim + encoder_output_size,
+ linear_units,
+ dropout_rate,
+ adim=attention_dim,
+ ),
dropout_rate,
normalize_before,
concat_after,
@@ -335,13 +351,13 @@
self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
def forward(
- self,
- hs_pad: torch.Tensor,
- hlens: torch.Tensor,
- ys_in_pad: torch.Tensor,
- ys_in_lens: torch.Tensor,
- chunk_mask: torch.Tensor = None,
- pre_acoustic_embeds: torch.Tensor = None,
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ chunk_mask: torch.Tensor = None,
+ pre_acoustic_embeds: torch.Tensor = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
"""Forward decoder.
@@ -376,16 +392,10 @@
x = torch.cat((x, pre_acoustic_embeds), dim=-1)
x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
- x, tgt_mask, memory, memory_mask, _ = self.decoders(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
if self.decoders2 is not None:
- x, tgt_mask, memory, memory_mask, _ = self.decoders2(
- x, tgt_mask, memory, memory_mask
- )
- x, tgt_mask, memory, memory_mask, _ = self.decoders3(
- x, tgt_mask, memory, memory_mask
- )
+ x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
+ x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
if self.normalize_before:
x = self.after_norm(x)
if self.output_layer is not None:
@@ -394,23 +404,36 @@
olens = tgt_mask.sum(1)
return x, olens
- def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
+ def score(
+ self,
+ ys,
+ state,
+ x,
+ x_mask=None,
+ pre_acoustic_embeds: torch.Tensor = None,
+ ):
"""Score."""
- ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+ ys_mask = myutils.sequence_mask(
+ torch.tensor([len(ys)], dtype=torch.int32), device=x.device
+ )[:, :, None]
logp, state = self.forward_one_step(
- ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
- cache=state
+ ys.unsqueeze(0),
+ ys_mask,
+ x.unsqueeze(0),
+ memory_mask=x_mask,
+ pre_acoustic_embeds=pre_acoustic_embeds,
+ cache=state,
)
return logp.squeeze(0), state
def forward_one_step(
- self,
- tgt: torch.Tensor,
- tgt_mask: torch.Tensor,
- memory: torch.Tensor,
- memory_mask: torch.Tensor = None,
- pre_acoustic_embeds: torch.Tensor = None,
- cache: List[torch.Tensor] = None,
+ self,
+ tgt: torch.Tensor,
+ tgt_mask: torch.Tensor,
+ memory: torch.Tensor,
+ memory_mask: torch.Tensor = None,
+ pre_acoustic_embeds: torch.Tensor = None,
+ cache: List[torch.Tensor] = None,
) -> Tuple[torch.Tensor, List[torch.Tensor]]:
"""Forward one step.
@@ -473,4 +496,3 @@
y = torch.log_softmax(y, dim=-1)
return y, new_cache
-
--
Gitblit v1.9.1