From 8c7b7e5feb68fda1fc4ddd627bad0f915358149e Mon Sep 17 00:00:00 2001
From: Zhanzhao (Deo) Liang <liangzhanzhao1985@gmail.com>
Date: Wed, 25 Dec 2024 16:40:29 +0800
Subject: [PATCH] fix export_meta import of sense voice (#2334)
---
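This renames the misspelled constructor argument `sanm_shfit` to
`sanm_shift` throughout the contextual Paraformer decoder and drops a
leftover pdb breakpoint. For reference, a minimal standalone sketch of
the default-shift fallback touched by this patch (an illustration
assuming the default kernel_size=21 from the signature, not code from
this diff):

    # Sketch only: mirrors the fallback in __init__ below.
    kernel_size = 21                     # default in the decoder signature
    sanm_shift = (kernel_size - 1) // 2  # -> 10; used when sanm_shift is None
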
funasr/models/contextual_paraformer/decoder.py | 13 ++++++-------
 1 file changed, 6 insertions(+), 7 deletions(-)
diff --git a/funasr/models/contextual_paraformer/decoder.py b/funasr/models/contextual_paraformer/decoder.py
index 0b30c99..958c46b 100644
--- a/funasr/models/contextual_paraformer/decoder.py
+++ b/funasr/models/contextual_paraformer/decoder.py
@@ -137,7 +137,7 @@
         concat_after: bool = False,
         att_layer_num: int = 6,
         kernel_size: int = 21,
-        sanm_shfit: int = 0,
+        sanm_shift: int = 0,
     ):
         super().__init__(
             vocab_size=vocab_size,
@@ -179,14 +179,14 @@
         self.att_layer_num = att_layer_num
         self.num_blocks = num_blocks
 
-        if sanm_shfit is None:
-            sanm_shfit = (kernel_size - 1) // 2
+        if sanm_shift is None:
+            sanm_shift = (kernel_size - 1) // 2
         self.decoders = repeat(
             att_layer_num - 1,
             lambda lnum: DecoderLayerSANM(
                 attention_dim,
                 MultiHeadedAttentionSANMDecoder(
-                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=sanm_shift
                 ),
                 MultiHeadedAttentionCrossAtt(
                     attention_heads, attention_dim, src_attention_dropout_rate
@@ -210,7 +210,7 @@
         self.last_decoder = ContextualDecoderLayer(
             attention_dim,
             MultiHeadedAttentionSANMDecoder(
-                attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=sanm_shfit
+                attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=sanm_shift
             ),
             MultiHeadedAttentionCrossAtt(
                 attention_heads, attention_dim, src_attention_dropout_rate
@@ -228,7 +228,7 @@
             lambda lnum: DecoderLayerSANM(
                 attention_dim,
                 MultiHeadedAttentionSANMDecoder(
-                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shfit=0
+                    attention_dim, self_attention_dropout_rate, kernel_size, sanm_shift=0
                 ),
                 None,
                 PositionwiseFeedForwardDecoderSANM(attention_dim, linear_units, dropout_rate),
@@ -424,7 +424,6 @@
         # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
         contextual_mask = self.make_pad_mask(contextual_length)
         contextual_mask, _ = self.prepare_mask(contextual_mask)
-        # import pdb; pdb.set_trace()
         contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
         cx, tgt_mask, _, _, _ = self.bias_decoder(
             x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask
--
Gitblit v1.9.1