From e9acc5db07daa51a22cd51ea9233ee09a38d726d Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Thu, 06 Jun 2024 18:36:22 +0800
Subject: [PATCH] auto frontend

---
 funasr/models/llm_asr/model.py                                              |   76 ++++++++++----------------------------
 examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear2.yaml |    2 
 funasr/datasets/openai_datasets/datasets.py                                 |   11 ++---
 3 files changed, 26 insertions(+), 63 deletions(-)

diff --git a/examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear2.yaml b/examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear2.yaml
index 59e93a6..62d77e2 100644
--- a/examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear2.yaml
+++ b/examples/industrial_data_pretraining/llm_asr/conf/whisper_qwen_linear2.yaml
@@ -35,7 +35,7 @@
 frontend_conf:
     fs: 16000
     whisper_model: large-v3
-    do_pad_trim: true
+    do_pad_trim: false
     permute: false # true: [bs, frames, dims]; false: [bs, dims, frames]
     filters_path: "/nfs/zhifu.gzf/init_model/SenseVoiceModelscope/assets/mel_filters.npz"
 
diff --git a/funasr/datasets/openai_datasets/datasets.py b/funasr/datasets/openai_datasets/datasets.py
index 9a542ad..9bd0698 100644
--- a/funasr/datasets/openai_datasets/datasets.py
+++ b/funasr/datasets/openai_datasets/datasets.py
@@ -123,21 +123,20 @@
                             )  # speech: [b, T, d]
                             if self.permute:
                                 speech = speech.permute(0, 2, 1)
-                            if speech_lengths > self.batch_size:
-                                continue
+                            # if speech_lengths > self.batch_size:
+                            #     continue
 
-                            fbank_lens = speech_lengths[0].item()
-                            olens = 1 + (fbanks_len - 3 + 2 * 1) // 2
+                            olens = 1 + (speech_lengths[0].item() - 3 + 2 * 1) // 2
                             olens = 1 + (olens - 3 + 2 * 1) // 2
                             sub_token_len = (olens - 1) // 2 + 1
-                            sub_token = [0] * sub_token_len[0]
+                            sub_token = [0] * sub_token_len
                             fbank_beg_i = [len(source_ids)]
                             source_ids += sub_token
                             fbank_mask_i += [1] * len(sub_token)
 
                 source_mask = [-100] * len(source_ids)
                 target_out = f"{target_out}<|im_end|>"
-                target_ids = tokenizer.encode(target_out)
+                target_ids = self.tokenizer.encode(target_out)
                 input_ids += source_ids + target_ids
                 labels += source_mask + target_ids
                 fbank_mask += fbank_mask_i
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 11db009..411b59d 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -385,13 +385,6 @@
 
         super().__init__()
 
-        if specaug is not None:
-            specaug_class = tables.specaug_classes.get(specaug)
-            specaug = specaug_class(**specaug_conf)
-        if normalize is not None:
-            normalize_class = tables.normalize_classes.get(normalize)
-            normalize = normalize_class(**normalize_conf)
-
         # audio encoder
         hub = audio_encoder_conf.get("hub", None)
         if hub == "ms":
@@ -422,23 +415,23 @@
         # llm
         hub = llm_conf.get("hub", "hf")
         self.llm = None
-        # if hub == "hf":
-        #     from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
-        #
-        #     init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
-        #
-        #     model = AutoModelForCausalLM.from_pretrained(
-        #         init_param_path,
-        #         load_in_8bit=None,
-        #         device_map=None,
-        #         use_cache=None,
-        #     )
-        #     freeze = llm_conf.get("freeze", True)
-        #     if freeze:
-        #         for name, param in model.named_parameters():
-        #             param.requires_grad = False
-        #         model.eval()
-        #     self.llm = model
+        if hub == "hf":
+            from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
+
+            init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+
+            model = AutoModelForCausalLM.from_pretrained(
+                init_param_path,
+                load_in_8bit=None,
+                device_map=None,
+                use_cache=None,
+            )
+            freeze = llm_conf.get("freeze", True)
+            if freeze:
+                for name, param in model.named_parameters():
+                    param.requires_grad = False
+                model.eval()
+            self.llm = model
 
         # adaptor
         adaptor_class = tables.adaptor_classes.get(audio_adaptor)
@@ -446,21 +439,6 @@
         audio_adaptor = adaptor_class(**audio_adaptor_conf)
 
         self.audio_adaptor = audio_adaptor
-
-        self.blank_id = blank_id
-        self.sos = sos if sos is not None else vocab_size - 1
-        self.eos = eos if eos is not None else vocab_size - 1
-        self.vocab_size = vocab_size
-        self.ignore_id = ignore_id
-        self.specaug = specaug
-        self.normalize = normalize
-
-        self.criterion_att = LabelSmoothingLoss(
-            size=vocab_size,
-            padding_idx=ignore_id,
-            smoothing=lsm_weight,
-            normalize_length=length_normalized_loss,
-        )
 
         self.error_calculator = None
 
@@ -493,10 +471,10 @@
         batch_size = speech.shape[0]
 
         # audio encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        encoder_out, encoder_out_lens = self.audio_encoder(speech.permute(0, 2, 1), speech_lengths)
 
         # audio_adaptor
-        encoder_out = self.audio_adaptor(encoder_out)
+        encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
 
         input_ids[input_ids == -1] = 0
         input_ids[input_ids == -100] = 0
@@ -530,23 +508,9 @@
 
         # force_gatherable: to-device and to-tensor if scalar for DataParallel
         if self.length_normalized_loss:
-            batch_size = int((text_lengths + 1).sum())
+            batch_size = int((labels_ids > 0).sum() + 1)
         loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
         return loss, stats, weight
-
-    def encode(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        **kwargs,
-    ):
-        speech = speech.permute(0, 2, 1)
-        res = self.audio_encoder(speech)
-        if isinstance(res, (list, tuple)):
-            encoder_out, encoder_out_lens = res[0], res[1]
-        else:
-            encoder_out, encoder_out_lens = res, speech_lengths
-        return encoder_out, encoder_out_lens
 
     def inference(
         self,

--
Gitblit v1.9.1