From b75d1e89bb2f513a79bb07e9100ba1cd2bbcf40c Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Sun, 09 Jun 2024 00:32:57 +0800
Subject: [PATCH] llm_asr adaptor: only build transformer blocks when n_layer > 0
---
funasr/models/llm_asr/adaptor.py | 41 ++++++++++++++++++++++-------------------
 1 file changed, 22 insertions(+), 19 deletions(-)
diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
index 9b79ed2..c939883 100644
--- a/funasr/models/llm_asr/adaptor.py
+++ b/funasr/models/llm_asr/adaptor.py
@@ -83,25 +83,27 @@
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
- self.blocks = nn.ModuleList(
- [
- EncoderLayer(
- llm_dim,
- MultiHeadedAttention(
- kwargs.get("attention_heads", 8),
+ self.blocks = None
+ if kwargs.get("n_layer", 2) > 0:
+ self.blocks = nn.ModuleList(
+ [
+ EncoderLayer(
llm_dim,
- kwargs.get("attention_dropout_rate", 0.0),
- ),
- PositionwiseFeedForward(
- llm_dim,
- llm_dim // 4,
+ MultiHeadedAttention(
+ kwargs.get("attention_heads", 8),
+ llm_dim,
+ kwargs.get("attention_dropout_rate", 0.0),
+ ),
+ PositionwiseFeedForward(
+ llm_dim,
+ llm_dim // 4,
+ kwargs.get("dropout_rate", 0.0),
+ ),
kwargs.get("dropout_rate", 0.0),
- ),
- kwargs.get("dropout_rate", 0.0),
- )
- for i in range(kwargs.get("n_layer", 2))
- ]
- )
+ )
+ for i in range(kwargs.get("n_layer", 2))
+ ]
+ )
def forward(self, x, ilens=None):
@@ -123,6 +125,7 @@
olens = None
olens = (ilens - 1) // self.k + 1
masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
- for layer, block in enumerate(self.blocks):
- x, masks = block(x, masks)
+ if self.blocks is not None:
+ for layer, block in enumerate(self.blocks):
+ x, masks = block(x, masks)
return x, olens
--
Gitblit v1.9.1