From 32e783664534bbb8d3b8ba64c2c2ecb42398eb00 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 06 六月 2024 09:54:35 +0800
Subject: [PATCH] update with main (#1786)

---
 funasr/models/sense_voice/model.py                |   28 +++++++--
 funasr/models/sense_voice/decoder.py              |    1 
 funasr/datasets/audio_datasets/espnet_samplers.py |    2 
 funasr/models/transformer/encoder.py              |    2 
 funasr/train_utils/trainer_ds.py                  |    1 
 funasr/auto/auto_frontend.py                      |   12 ++--
 funasr/models/sense_voice/whisper_lib/model.py    |   19 +++++
 funasr/models/llm_asr/adaptor.py                  |   63 +++++++++++++++++++++
 8 files changed, 111 insertions(+), 17 deletions(-)

diff --git a/funasr/auto/auto_frontend.py b/funasr/auto/auto_frontend.py
index 696a51e..501d1ab 100644
--- a/funasr/auto/auto_frontend.py
+++ b/funasr/auto/auto_frontend.py
@@ -60,7 +60,7 @@
 
         result_list = []
         num_samples = len(data_list)
-        pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
+        # pbar = tqdm(colour="blue", total=num_samples + 1, dynamic_ncols=True)
 
         time0 = time.perf_counter()
         for beg_idx in range(0, num_samples, batch_size):
@@ -95,15 +95,15 @@
                 "input": speech,
                 "input_len": speech_lengths,
                 "key": key_batch,
-                data_type: "fbank",
+                "data_type": "fbank",
             }
             result_list.append(batch)
 
-            pbar.update(1)
-            description = f"{meta_data}, "
-            pbar.set_description(description)
+            # pbar.update(1)
+            # description = f"{meta_data}, "
+            # pbar.set_description(description)
 
         time_end = time.perf_counter()
-        pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
+        # pbar.set_description(f"time escaped total: {time_end - time0:0.3f}")
 
         return result_list
diff --git a/funasr/datasets/audio_datasets/espnet_samplers.py b/funasr/datasets/audio_datasets/espnet_samplers.py
index b358fa3..004201e 100644
--- a/funasr/datasets/audio_datasets/espnet_samplers.py
+++ b/funasr/datasets/audio_datasets/espnet_samplers.py
@@ -147,7 +147,9 @@
         start_idx = self.rank * batches_per_rank
         end_idx = start_idx + batches_per_rank
         rank_batches = buffer_batches[start_idx + self.start_step : end_idx]
+
         self.batch_num = len(rank_batches)
+
         logging.info(
             f"rank: {self.rank}, dataloader start from step: {self.start_step}, batch_num: {end_idx-start_idx}, batch_num_after_step: {len(rank_batches)}"
         )
diff --git a/funasr/models/llm_asr/adaptor.py b/funasr/models/llm_asr/adaptor.py
index 8c2a804..9b79ed2 100644
--- a/funasr/models/llm_asr/adaptor.py
+++ b/funasr/models/llm_asr/adaptor.py
@@ -1,5 +1,7 @@
 import torch
 import torch.nn as nn
+import torch.nn.functional as F
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
 
 from funasr.register import tables
 
@@ -63,3 +65,64 @@
         query_proj = self.norm(self.linear(query_output.last_hidden_state))
 
         return query_proj
+
+
+@tables.register("adaptor_classes", "Transformer")
+class Transformer(nn.Module):
+    def __init__(
+        self, downsample_rate=2, encoder_dim=1280, llm_dim=4096, ffn_dim: int = 2048, **kwargs
+    ):
+        super().__init__()
+        self.k = downsample_rate
+        self.encoder_dim = encoder_dim
+        self.llm_dim = llm_dim
+        self.linear1 = nn.Linear(self.encoder_dim * self.k, ffn_dim)
+        self.relu = nn.ReLU()
+        self.linear2 = nn.Linear(ffn_dim, self.llm_dim)
+        from funasr.models.transformer.encoder import EncoderLayer
+        from funasr.models.transformer.attention import MultiHeadedAttention
+        from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
+
+        self.blocks = nn.ModuleList(
+            [
+                EncoderLayer(
+                    llm_dim,
+                    MultiHeadedAttention(
+                        kwargs.get("attention_heads", 8),
+                        llm_dim,
+                        kwargs.get("attention_dropout_rate", 0.0),
+                    ),
+                    PositionwiseFeedForward(
+                        llm_dim,
+                        llm_dim // 4,
+                        kwargs.get("dropout_rate", 0.0),
+                    ),
+                    kwargs.get("dropout_rate", 0.0),
+                )
+                for i in range(kwargs.get("n_layer", 2))
+            ]
+        )
+
+    def forward(self, x, ilens=None):
+
+        batch_size, seq_len, dim = x.size()
+        # num_frames_to_discard = seq_len % self.k
+        chunk_num = (seq_len - 1) // self.k + 1
+        pad_num = chunk_num * self.k - seq_len
+        x = F.pad(x, (0, 0, 0, pad_num, 0, 0), value=0.0)
+        # if num_frames_to_discard > 0:
+        #     x = x[:, :-num_frames_to_discard, :]
+        seq_len = x.size(1)
+
+        x = x.contiguous()
+        x = x.view(batch_size, chunk_num, dim * self.k)
+        x = self.linear1(x)
+        x = self.relu(x)
+        x = self.linear2(x)
+
+        olens = None
+        olens = (ilens - 1) // self.k + 1
+        masks = (~make_pad_mask(olens)[:, None, :]).to(x.device)
+        for layer, block in enumerate(self.blocks):
+            x, masks = block(x, masks)
+        return x, olens
diff --git a/funasr/models/sense_voice/decoder.py b/funasr/models/sense_voice/decoder.py
index 60af29a..ff933d7 100644
--- a/funasr/models/sense_voice/decoder.py
+++ b/funasr/models/sense_voice/decoder.py
@@ -360,6 +360,7 @@
         """Score."""
         ys_mask = subsequent_mask(len(ys), device=x.device).unsqueeze(0)
         logp = self.forward(ys.unsqueeze(0), x.unsqueeze(0), cache=state)
+        logp = torch.log_softmax(logp, dim=-1)
         return logp.squeeze(0)[-1, :], state
 
 
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index 127d5a0..22272ee 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -1264,15 +1264,29 @@
         if isinstance(task, str):
             task = [task]
         task = "".join([f"<|{x}|>" for x in task])
-        initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
+        
+        sos = kwargs.get("model_conf").get("sos")
+        if isinstance(sos, str):
+            initial_prompt = kwargs.get("initial_prompt", f"<|startoftranscript|>{task}")
 
-        language = DecodingOptions.get("language", None)
-        language = None if language == "auto" else language
+            language = DecodingOptions.get("language", None)
+            language = None if language == "auto" else language
 
-        sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
-        sos_int = tokenizer.encode(sos, allowed_special="all")
+            sos = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+            sos_int = tokenizer.encode(sos, allowed_special="all")
+        else:
+            language = DecodingOptions.get("language", None)
+            language = None if language == "auto" else language
+            initial_prompt = kwargs.get("initial_prompt", f"{task}")
+            initial_prompt_lid = f"{initial_prompt}<|{language}|>" if language is not None else initial_prompt
+            initial_prompt_lid_int = tokenizer.encode(initial_prompt_lid, allowed_special="all")
+            sos_int = [sos] + initial_prompt_lid_int
         eos = kwargs.get("model_conf").get("eos")
-        eos_int = tokenizer.encode(eos, allowed_special="all")
+        if isinstance(eos, str):
+            eos_int = tokenizer.encode(eos, allowed_special="all")
+        else:
+            eos_int = [eos]
+
         self.beam_search.sos = sos_int
         self.beam_search.eos = eos_int[0]
 
@@ -1298,7 +1312,7 @@
         self.beam_search.event_score_ga = DecodingOptions.get("gain_tokens_score", [1, 1, 1, 1])
 
         encoder_out, encoder_out_lens = self.encode(
-            speech[None, :, :].permute(0, 2, 1), speech_lengths
+            speech[None, :, :], speech_lengths
         )
 
         if text_token_int is not None:
diff --git a/funasr/models/sense_voice/whisper_lib/model.py b/funasr/models/sense_voice/whisper_lib/model.py
index 8b3d3ab..3d0d6a8 100644
--- a/funasr/models/sense_voice/whisper_lib/model.py
+++ b/funasr/models/sense_voice/whisper_lib/model.py
@@ -27,9 +27,24 @@
     n_text_layer: int
 
 
+# class LayerNorm(nn.LayerNorm):
+#     def forward(self, x: Tensor) -> Tensor:
+#         return super().forward(x.float()).type(x.dtype)
+
+
 class LayerNorm(nn.LayerNorm):
-    def forward(self, x: Tensor) -> Tensor:
-        return super().forward(x.float()).type(x.dtype)
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
+    def forward(self, input):
+        output = F.layer_norm(
+            input.float(),
+            self.normalized_shape,
+            self.weight.float() if self.weight is not None else None,
+            self.bias.float() if self.bias is not None else None,
+            self.eps,
+        )
+        return output.type_as(input)
 
 
 class Linear(nn.Linear):
diff --git a/funasr/models/transformer/encoder.py b/funasr/models/transformer/encoder.py
index a6a85ae..987924f 100644
--- a/funasr/models/transformer/encoder.py
+++ b/funasr/models/transformer/encoder.py
@@ -64,7 +64,7 @@
         stochastic_depth_rate=0.0,
     ):
         """Construct an EncoderLayer object."""
-        super(EncoderLayer, self).__init__()
+        super().__init__()
         self.self_attn = self_attn
         self.feed_forward = feed_forward
         self.norm1 = LayerNorm(size)
diff --git a/funasr/train_utils/trainer_ds.py b/funasr/train_utils/trainer_ds.py
index 1a553f8..ec887cc 100644
--- a/funasr/train_utils/trainer_ds.py
+++ b/funasr/train_utils/trainer_ds.py
@@ -621,7 +621,6 @@
             self.train_acc_avg = train_acc_avg.detach().cpu().item() / self.world_size
 
     def forward_step(self, model, batch, loss_dict={}):
-        dtype = torch.bfloat16
         with maybe_autocast(dtype=self.dtype, use_deepspeed=self.use_deepspeed):
             retval = model(**batch)
 

--
Gitblit v1.9.1