From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: 星期五, 13 三月 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/models/sanm/encoder.py |  192 +++++++++++++++++++++++++++++++++++++++++++----
 1 files changed, 175 insertions(+), 17 deletions(-)

diff --git a/funasr/models/sanm/encoder.py b/funasr/models/sanm/encoder.py
index cb4e21a..0d39ca7 100644
--- a/funasr/models/sanm/encoder.py
+++ b/funasr/models/sanm/encoder.py
@@ -1,3 +1,8 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 from typing import List
 from typing import Optional
 from typing import Sequence
@@ -12,7 +17,10 @@
 from funasr.train_utils.device_funcs import to_device
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
 from funasr.models.sanm.attention import MultiHeadedAttention, MultiHeadedAttentionSANM
-from funasr.models.transformer.embedding import SinusoidalPositionEncoder, StreamSinusoidalPositionEncoder
+from funasr.models.transformer.embedding import (
+    SinusoidalPositionEncoder,
+    StreamSinusoidalPositionEncoder,
+)
 from funasr.models.transformer.layer_norm import LayerNorm
 from funasr.models.transformer.utils.multi_layer_conv import Conv1dLinear
 from funasr.models.transformer.utils.multi_layer_conv import MultiLayeredConv1d
@@ -31,6 +39,7 @@
 from funasr.models.ctc.ctc import CTC
 
 from funasr.register import tables
+
 
 class EncoderLayerSANM(nn.Module):
     def __init__(
@@ -91,7 +100,18 @@
             x = self.norm1(x)
 
         if self.concat_after:
-            x_concat = torch.cat((x, self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)), dim=-1)
+            x_concat = torch.cat(
+                (
+                    x,
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    ),
+                ),
+                dim=-1,
+            )
             if self.in_size == self.size:
                 x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
             else:
@@ -99,11 +119,21 @@
         else:
             if self.in_size == self.size:
                 x = residual + stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    )
                 )
             else:
                 x = stoch_layer_coeff * self.dropout(
-                    self.self_attn(x, mask, mask_shfit_chunk=mask_shfit_chunk, mask_att_chunk_encoder=mask_att_chunk_encoder)
+                    self.self_attn(
+                        x,
+                        mask,
+                        mask_shfit_chunk=mask_shfit_chunk,
+                        mask_att_chunk_encoder=mask_att_chunk_encoder,
+                    )
                 )
         if not self.normalize_before:
             x = self.norm1(x)
@@ -153,13 +183,13 @@
 
         return x, cache
 
+
 @tables.register("encoder_classes", "SANMEncoder")
 class SANMEncoder(nn.Module):
     """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
+    Author: Zhifu Gao, Shiliang Zhang, Ming Lei, Ian McLoughlin
     San-m: Memory equipped self-attention for end-to-end speech recognition
     https://arxiv.org/abs/2006.01713
-
     """
 
     def __init__(
@@ -181,8 +211,8 @@
         padding_idx: int = -1,
         interctc_layer_idx: List[int] = [],
         interctc_use_conditioning: bool = False,
-        kernel_size : int = 11,
-        sanm_shfit : int = 0,
+        kernel_size: int = 11,
+        sanm_shfit: int = 0,
         lora_list: List[str] = None,
         lora_rank: int = 8,
         lora_alpha: int = 16,
@@ -302,7 +332,7 @@
         )
 
         self.encoders = repeat(
-            num_blocks-1,
+            num_blocks - 1,
             lambda lnum: EncoderLayerSANM(
                 output_size,
                 output_size,
@@ -345,7 +375,7 @@
             position embedded tensor and mask
         """
         masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        xs_pad = xs_pad * self.output_size()**0.5
+        xs_pad = xs_pad * self.output_size() ** 0.5
         if self.embed is None:
             xs_pad = xs_pad
         elif (
@@ -404,15 +434,16 @@
             return feats
         cache["feats"] = to_device(cache["feats"], device=feats.device)
         overlap_feats = torch.cat((cache["feats"], feats), dim=1)
-        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]):, :]
+        cache["feats"] = overlap_feats[:, -(cache["chunk_size"][0] + cache["chunk_size"][2]) :, :]
         return overlap_feats
 
-    def forward_chunk(self,
-                      xs_pad: torch.Tensor,
-                      ilens: torch.Tensor,
-                      cache: dict = None,
-                      ctc: CTC = None,
-                      ):
+    def forward_chunk(
+        self,
+        xs_pad: torch.Tensor,
+        ilens: torch.Tensor,
+        cache: dict = None,
+        ctc: CTC = None,
+    ):
         xs_pad *= self.output_size() ** 0.5
         if self.embed is None:
             xs_pad = xs_pad
@@ -452,3 +483,130 @@
             return (xs_pad, intermediate_outs), None, None
         return xs_pad, ilens, None
 
+
+class EncoderLayerSANMExport(nn.Module):
+    def __init__(
+        self,
+        model,
+    ):
+        """Construct an EncoderLayer object."""
+        super().__init__()
+        self.self_attn = model.self_attn
+        self.feed_forward = model.feed_forward
+        self.norm1 = model.norm1
+        self.norm2 = model.norm2
+        self.in_size = model.in_size
+        self.size = model.size
+
+    def forward(self, x, mask):
+
+        residual = x
+        x = self.norm1(x)
+        x = self.self_attn(x, mask)
+        if self.in_size == self.size:
+            x = x + residual
+        residual = x
+        x = self.norm2(x)
+        x = self.feed_forward(x)
+        x = x + residual
+
+        return x, mask
+
+
+@tables.register("encoder_classes", "SANMEncoderChunkOptExport")
+@tables.register("encoder_classes", "SANMEncoderExport")
+class SANMEncoderExport(nn.Module):
+    def __init__(
+        self,
+        model,
+        max_seq_len=512,
+        feats_dim=560,
+        model_name="encoder",
+        onnx: bool = True,
+        ctc_linear: nn.Module = None,
+    ):
+        super().__init__()
+        self.embed = model.embed
+        if isinstance(self.embed, StreamSinusoidalPositionEncoder):
+            self.embed = None
+        self.model = model
+        self.feats_dim = feats_dim
+        self._output_size = model._output_size
+
+        from funasr.utils.torch_function import sequence_mask
+
+        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+        from funasr.models.sanm.attention import MultiHeadedAttentionSANMExport
+
+        if hasattr(model, "encoders0"):
+            for i, d in enumerate(self.model.encoders0):
+                if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+                    d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
+                self.model.encoders0[i] = EncoderLayerSANMExport(d)
+
+        for i, d in enumerate(self.model.encoders):
+            if isinstance(d.self_attn, MultiHeadedAttentionSANM):
+                d.self_attn = MultiHeadedAttentionSANMExport(d.self_attn)
+            self.model.encoders[i] = EncoderLayerSANMExport(d)
+
+        self.model_name = model_name
+        self.num_heads = model.encoders[0].self_attn.h
+        self.hidden_size = model.encoders[0].self_attn.linear_out.out_features
+
+        self.ctc_linear = ctc_linear
+
+    def prepare_mask(self, mask):
+        mask_3d_btd = mask[:, :, None]
+        if len(mask.shape) == 2:
+            mask_4d_bhlt = 1 - mask[:, None, None, :]
+        elif len(mask.shape) == 3:
+            mask_4d_bhlt = 1 - mask[:, None, :]
+        mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+        return mask_3d_btd, mask_4d_bhlt
+
+    def forward(self, speech: torch.Tensor, speech_lengths: torch.Tensor, online: bool = False):
+        if not online:
+            speech = speech * self._output_size**0.5
+
+        mask = self.make_pad_mask(speech_lengths)
+        mask = self.prepare_mask(mask)
+        if self.embed is None:
+            xs_pad = speech
+        else:
+            xs_pad = self.embed(speech)
+
+        encoder_outs = self.model.encoders0(xs_pad, mask)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+        encoder_outs = self.model.encoders(xs_pad, mask)
+        xs_pad, masks = encoder_outs[0], encoder_outs[1]
+
+        xs_pad = self.model.after_norm(xs_pad)
+
+        if self.ctc_linear is not None:
+            xs_pad = self.ctc_linear(xs_pad)
+            xs_pad = F.softmax(xs_pad, dim=2)
+
+        return xs_pad, speech_lengths
+
+    def get_output_size(self):
+        return self.model.encoders[0].size
+
+    def get_dummy_inputs(self):
+        feats = torch.randn(1, 100, self.feats_dim)
+        return feats
+
+    def get_input_names(self):
+        return ["feats"]
+
+    def get_output_names(self):
+        return ["encoder_out", "encoder_out_lens", "predictor_weight"]
+
+    def get_dynamic_axes(self):
+        return {
+            "feats": {1: "feats_length"},
+            "encoder_out": {1: "enc_out_length"},
+            "predictor_weight": {1: "pre_out_length"},
+        }

--
Gitblit v1.9.1