From 162efb747f21f305851c682e5f7f0f3050d545a9 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Fri, 07 Jun 2024 16:18:18 +0800
Subject: [PATCH] auto frontend

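Make the adaptor's self-attention blocks optional: self.blocks is built
only when n_layer > 0, and forward() skips it when it is None, so the
adaptor can be used without any attention layers. Also cast
fbank_fake_lens to int32 in model.py.

A minimal, self-contained sketch of the optional-blocks pattern used
here (TinyAdaptor and its arguments are illustrative only, not the real
FunASR adaptor class):

    import torch
    from torch import nn

    class TinyAdaptor(nn.Module):
        def __init__(self, dim, n_layer=2):
            super().__init__()
            # Only build the block stack when n_layer > 0, mirroring the patch.
            self.blocks = None
            if n_layer > 0:
                self.blocks = nn.ModuleList(nn.Linear(dim, dim) for _ in range(n_layer))

        def forward(self, x):
            # With n_layer <= 0 the input passes through unchanged.
            if self.blocks is not None:
                for block in self.blocks:
                    x = block(x)
            return x

    out = TinyAdaptor(dim=8, n_layer=0)(torch.randn(2, 4, 8))  # blocks skipped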
---
 funasr/models/llm_asr/adaptor.py |   36 +++++++++++++++++++-----------------
 funasr/models/llm_asr/model.py   |    2 +-
 2 files changed, 20 insertions(+), 18 deletions(-)

diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
index 9b79ed2..c939883 100644
--- a/funasr/models/llm_asr/adaptor.py
+++ b/funasr/models/llm_asr/adaptor.py
@@ -83,25 +83,27 @@
         from funasr.models.transformer.attention import MultiHeadedAttention
         from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
 
-        self.blocks = nn.ModuleList(
-            [
-                EncoderLayer(
-                    llm_dim,
-                    MultiHeadedAttention(
-                        kwargs.get("attention_heads", 8),
+        self.blocks = None
+        if kwargs.get("n_layer", 2) > 0:
+            self.blocks = nn.ModuleList(
+                [
+                    EncoderLayer(
                         llm_dim,
-                        kwargs.get("attention_dropout_rate", 0.0),
-                    ),
-                    PositionwiseFeedForward(
-                        llm_dim,
-                        llm_dim // 4,
+                        MultiHeadedAttention(
+                            kwargs.get("attention_heads", 8),
+                            llm_dim,
+                            kwargs.get("attention_dropout_rate", 0.0),
+                        ),
+                        PositionwiseFeedForward(
+                            llm_dim,
+                            llm_dim // 4,
+                            kwargs.get("dropout_rate", 0.0),
+                        ),
                         kwargs.get("dropout_rate", 0.0),
-                    ),
-                    kwargs.get("dropout_rate", 0.0),
-                )
-                for i in range(kwargs.get("n_layer", 2))
-            ]
-        )
+                    )
+                    for i in range(kwargs.get("n_layer", 2))
+                ]
+            )
 
     def forward(self, x, ilens=None):
 
@@ -123,6 +125,7 @@
         olens = None
         olens = (ilens - 1) // self.k + 1
         masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
-        for layer, block in enumerate(self.blocks):
-            x, masks = block(x, masks)
+        if self.blocks is not None:
+            for layer, block in enumerate(self.blocks):
+                x, masks = block(x, masks)
         return x, olens
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index b139123..d94058c 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -481,7 +481,7 @@
 
         batch_size, token_num, dims = inputs_embeds.shape
         fbank_mask[fbank_mask < 0] = 0
-        fbank_fake_lens = fbank_mask.sum(-1)
+        fbank_fake_lens = fbank_mask.sum(-1).to(torch.int32)
         # _, l, _ = encoder_out.shape
         for batch_idx in range(batch_size):
 

--
Gitblit v1.9.1