From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/sanm/decoder.py |  104 +++++++++++++++++++++++++++++++---------------------
 1 files changed, 62 insertions(+), 42 deletions(-)

diff --git a/funasr/models/sanm/decoder.py b/funasr/models/sanm/decoder.py
index 3575282..1a4fb26 100644
--- a/funasr/models/sanm/decoder.py
+++ b/funasr/models/sanm/decoder.py
@@ -13,13 +13,17 @@
 from funasr.models.scama import utils as myutils
 from funasr.models.transformer.decoder import BaseTransformerDecoder
 
-from funasr.models.sanm.attention import MultiHeadedAttentionSANMDecoder, MultiHeadedAttentionCrossAtt
+from funasr.models.sanm.attention import (
+    MultiHeadedAttentionSANMDecoder,
+    MultiHeadedAttentionCrossAtt,
+)
 from funasr.models.transformer.embedding import PositionalEncoding
 from funasr.models.transformer.layer_norm import LayerNorm
 from funasr.models.sanm.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
 from funasr.models.transformer.utils.repeat import repeat
 
 from funasr.register import tables
+
 
 class DecoderLayerSANM(nn.Module):
     """Single decoder layer module.
@@ -151,10 +155,11 @@
 
             x = residual + self.dropout(self.src_attn(x, memory, memory_mask))
 
-
         return x, tgt_mask, memory, memory_mask, cache
 
-    def forward_chunk(self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0):
+    def forward_chunk(
+        self, tgt, memory, fsmn_cache=None, opt_cache=None, chunk_size=None, look_back=0
+    ):
         """Compute decoded features.
 
         Args:
@@ -202,7 +207,7 @@
     San-m: Memory equipped self-attention for end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
     """
-    
+
     def __init__(
         self,
         vocab_size: int,
@@ -240,7 +245,7 @@
         )
         if attention_dim is None:
             attention_dim = encoder_output_size
-        
+
         if input_layer == "embed":
             self.embed = torch.nn.Sequential(
                 torch.nn.Embedding(vocab_size, attention_dim),
@@ -255,7 +260,7 @@
             )
         else:
             raise ValueError(f"only 'embed' or 'linear' is supported: {input_layer}")
-        
+
         self.normalize_before = normalize_before
         if self.normalize_before:
             self.after_norm = LayerNorm(attention_dim)
@@ -263,7 +268,7 @@
             self.output_layer = torch.nn.Linear(attention_dim, vocab_size)
         else:
             self.output_layer = None
-        
+
         self.att_layer_num = att_layer_num
         self.num_blocks = num_blocks
         if sanm_shfit is None:
@@ -276,7 +281,10 @@
                     attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
                 ),
                 MultiHeadedAttentionCrossAtt(
-                    attention_heads, attention_dim, src_attention_dropout_rate, encoder_output_size=encoder_output_size
+                    attention_heads,
+                    attention_dim,
+                    src_attention_dropout_rate,
+                    encoder_output_size=encoder_output_size,
                 ),
                 PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
                 dropout_rate,
@@ -292,7 +300,10 @@
                 lambda lnum: DecoderLayerSANM(
                     attention_dim,
                     MultiHeadedAttentionSANMDecoder(
-                        attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                        attention_dim,
+                        self_attention_dropout_rate,
+                        kernel_size,
+                        sanm_shfit=sanm_shfit,
                     ),
                     None,
                     PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
@@ -301,7 +312,7 @@
                     concat_after,
                 ),
             )
-        
+
         self.decoders3 = repeat(
             1,
             lambda lnum: DecoderLayerSANM(
@@ -321,8 +332,12 @@
                     attention_dim + encoder_output_size,
                     None,
                     None,
-                    PositionwiseFeedForwardDecoderSANM(attention_dim + encoder_output_size, linear_units, dropout_rate,
-                                                       adim=attention_dim),
+                    PositionwiseFeedForwardDecoderSANM(
+                        attention_dim + encoder_output_size,
+                        linear_units,
+                        dropout_rate,
+                        adim=attention_dim,
+                    ),
                     dropout_rate,
                     normalize_before,
                     concat_after,
@@ -334,7 +349,7 @@
         self.tf2torch_tensor_name_prefix_torch = tf2torch_tensor_name_prefix_torch
         self.tf2torch_tensor_name_prefix_tf = tf2torch_tensor_name_prefix_tf
         self.embed_tensor_name_prefix_tf = embed_tensor_name_prefix_tf
-    
+
     def forward(
         self,
         hs_pad: torch.Tensor,
@@ -363,47 +378,54 @@
         """
         tgt = ys_in_pad
         tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
-        
+
         memory = hs_pad
         memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
         if chunk_mask is not None:
             memory_mask = memory_mask * chunk_mask
             if tgt_mask.size(1) != memory_mask.size(1):
                 memory_mask = torch.cat((memory_mask, memory_mask[:, -2:-1, :]), dim=1)
-        
+
         x = self.embed(tgt)
-        
+
         if pre_acoustic_embeds is not None and self.concat_embeds:
             x = torch.cat((x, pre_acoustic_embeds), dim=-1)
             x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
-        
-        x, tgt_mask, memory, memory_mask, _ = self.decoders(
-            x, tgt_mask, memory, memory_mask
-        )
+
+        x, tgt_mask, memory, memory_mask, _ = self.decoders(x, tgt_mask, memory, memory_mask)
         if self.decoders2 is not None:
-            x, tgt_mask, memory, memory_mask, _ = self.decoders2(
-                x, tgt_mask, memory, memory_mask
-            )
-        x, tgt_mask, memory, memory_mask, _ = self.decoders3(
-            x, tgt_mask, memory, memory_mask
-        )
+            x, tgt_mask, memory, memory_mask, _ = self.decoders2(x, tgt_mask, memory, memory_mask)
+        x, tgt_mask, memory, memory_mask, _ = self.decoders3(x, tgt_mask, memory, memory_mask)
         if self.normalize_before:
             x = self.after_norm(x)
         if self.output_layer is not None:
             x = self.output_layer(x)
-        
+
         olens = tgt_mask.sum(1)
         return x, olens
-    
-    def score(self, ys, state, x, x_mask=None, pre_acoustic_embeds: torch.Tensor = None, ):
+
+    def score(
+        self,
+        ys,
+        state,
+        x,
+        x_mask=None,
+        pre_acoustic_embeds: torch.Tensor = None,
+    ):
         """Score."""
-        ys_mask = myutils.sequence_mask(torch.tensor([len(ys)], dtype=torch.int32), device=x.device)[:, :, None]
+        ys_mask = myutils.sequence_mask(
+            torch.tensor([len(ys)], dtype=torch.int32), device=x.device
+        )[:, :, None]
         logp, state = self.forward_one_step(
-            ys.unsqueeze(0), ys_mask, x.unsqueeze(0), memory_mask=x_mask, pre_acoustic_embeds=pre_acoustic_embeds,
-            cache=state
+            ys.unsqueeze(0),
+            ys_mask,
+            x.unsqueeze(0),
+            memory_mask=x_mask,
+            pre_acoustic_embeds=pre_acoustic_embeds,
+            cache=state,
         )
         return logp.squeeze(0), state
-    
+
     def forward_one_step(
         self,
         tgt: torch.Tensor,
@@ -426,15 +448,15 @@
             y, cache: NN output value and cache per `self.decoders`.
             y.shape` is (batch, maxlen_out, token)
         """
-        
+
         x = tgt[:, -1:]
         tgt_mask = None
         x = self.embed(x)
-        
+
         if pre_acoustic_embeds is not None and self.concat_embeds:
             x = torch.cat((x, pre_acoustic_embeds), dim=-1)
             x, _, _, _, _ = self.embed_concat_ffn(x, None, None, None, None)
-        
+
         if cache is None:
             cache_layer_num = len(self.decoders)
             if self.decoders2 is not None:
@@ -449,7 +471,7 @@
                 x, tgt_mask, memory, memory_mask, cache=c
             )
             new_cache.append(c_ret)
-        
+
         if self.num_blocks - self.att_layer_num >= 1:
             for i in range(self.num_blocks - self.att_layer_num):
                 j = i + self.att_layer_num
@@ -459,12 +481,12 @@
                     x, tgt_mask, memory, memory_mask, cache=c
                 )
                 new_cache.append(c_ret)
-        
+
         for decoder in self.decoders3:
             x, tgt_mask, memory, memory_mask, _ = decoder.forward_one_step(
                 x, tgt_mask, memory, None, cache=None
             )
-        
+
         if self.normalize_before:
             y = self.after_norm(x[:, -1])
         else:
@@ -472,7 +494,5 @@
         if self.output_layer is not None:
             y = self.output_layer(y)
             y = torch.log_softmax(y, dim=-1)
-        
+
         return y, new_cache
-    
-    
\ No newline at end of file

--
Gitblit v1.9.1