From 162efb747f21f305851c682e5f7f0f3050d545a9 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 07 Jun 2024 16:18:18 +0800
Subject: [PATCH] auto frontend
---
funasr/models/llm_asr/model.py | 2 +-
funasr/models/llm_asr/adaptor.py | 41 ++++++++++++++++++++++-------------------
2 files changed, 23 insertions(+), 20 deletions(-)
diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
index 9b79ed2..c939883 100644
--- a/funasr/models/llm_asr/adaptor.py
+++ b/funasr/models/llm_asr/adaptor.py
@@ -83,25 +83,27 @@
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
- self.blocks = nn.ModuleList(
- [
- EncoderLayer(
- llm_dim,
- MultiHeadedAttention(
- kwargs.get("attention_heads", 8),
+ self.blocks = None
+ if kwargs.get("n_layer", 2) > 0:
+ self.blocks = nn.ModuleList(
+ [
+ EncoderLayer(
llm_dim,
- kwargs.get("attention_dropout_rate", 0.0),
- ),
- PositionwiseFeedForward(
- llm_dim,
- llm_dim // 4,
+ MultiHeadedAttention(
+ kwargs.get("attention_heads", 8),
+ llm_dim,
+ kwargs.get("attention_dropout_rate", 0.0),
+ ),
+ PositionwiseFeedForward(
+ llm_dim,
+ llm_dim // 4,
+ kwargs.get("dropout_rate", 0.0),
+ ),
kwargs.get("dropout_rate", 0.0),
- ),
- kwargs.get("dropout_rate", 0.0),
- )
- for i in range(kwargs.get("n_layer", 2))
- ]
- )
+ )
+ for i in range(kwargs.get("n_layer", 2))
+ ]
+ )
def forward(self, x, ilens=None):
@@ -123,6 +125,7 @@
olens = None
olens = (ilens - 1) // self.k + 1
masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
- for layer, block in enumerate(self.blocks):
- x, masks = block(x, masks)
+ if self.blocks is not None:
+ for layer, block in enumerate(self.blocks):
+ x, masks = block(x, masks)
return x, olens
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index b139123..d94058c 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -481,7 +481,7 @@
batch_size, token_num, dims = inputs_embeds.shape
fbank_mask[fbank_mask < 0] = 0
- fbank_fake_lens = fbank_mask.sum(-1)
+ fbank_fake_lens = fbank_mask.sum(-1).to(torch.int32)
# _, l, _ = encoder_out.shape
for batch_idx in range(batch_size):
--
Gitblit v1.9.1