From 2ac79cd3f312e485f3fc4f0e63313cc8a3e0bfc6 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 12 Jun 2024 19:27:35 +0800
Subject: [PATCH] decoding
---
funasr/models/llm_asr/adaptor.py | 42 +++++++++++++++++++++++-------------------
 1 file changed, 23 insertions(+), 19 deletions(-)
diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
index 9b79ed2..93534fe 100644
--- a/funasr/models/llm_asr/adaptor.py
+++ b/funasr/models/llm_asr/adaptor.py
@@ -83,25 +83,27 @@
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
- self.blocks = nn.ModuleList(
- [
- EncoderLayer(
- llm_dim,
- MultiHeadedAttention(
- kwargs.get("attention_heads", 8),
+ self.blocks = None
+ if kwargs.get("n_layer", 2) > 0:
+ self.blocks = nn.ModuleList(
+ [
+ EncoderLayer(
llm_dim,
- kwargs.get("attention_dropout_rate", 0.0),
- ),
- PositionwiseFeedForward(
- llm_dim,
- llm_dim // 4,
+ MultiHeadedAttention(
+ kwargs.get("attention_heads", 8),
+ llm_dim,
+ kwargs.get("attention_dropout_rate", 0.0),
+ ),
+ PositionwiseFeedForward(
+ llm_dim,
+ llm_dim // 4,
+ kwargs.get("dropout_rate", 0.0),
+ ),
kwargs.get("dropout_rate", 0.0),
- ),
- kwargs.get("dropout_rate", 0.0),
- )
- for i in range(kwargs.get("n_layer", 2))
- ]
- )
+ )
+ for i in range(kwargs.get("n_layer", 2))
+ ]
+ )
def forward(self, x, ilens=None):
@@ -123,6 +125,8 @@
olens = None
olens = (ilens - 1) // self.k + 1
masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
- for layer, block in enumerate(self.blocks):
- x, masks = block(x, masks)
+
+ if self.blocks is not None:
+ for layer, block in enumerate(self.blocks):
+ x, masks = block(x, masks)
return x, olens
--
Gitblit v1.9.1