From 2acd24f0158b2c86d2fb4e6f1134b67a1150500e Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期四, 29 二月 2024 17:14:59 +0800
Subject: [PATCH] update whisper lid (#1407)

---
 funasr/datasets/llm_datasets/datasets.py |   15 ++++++++-------
 1 files changed, 8 insertions(+), 7 deletions(-)

diff --git a/funasr/datasets/llm_datasets/datasets.py b/funasr/datasets/llm_datasets/datasets.py
index 20eb8aa..9673d76 100644
--- a/funasr/datasets/llm_datasets/datasets.py
+++ b/funasr/datasets/llm_datasets/datasets.py
@@ -24,12 +24,12 @@
         preprocessor_speech = kwargs.get("preprocessor_speech", None)
         if preprocessor_speech:
             preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
-            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf", {}))
         self.preprocessor_speech = preprocessor_speech
         preprocessor_text = kwargs.get("preprocessor_text", None)
         if preprocessor_text:
             preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
-            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
+            preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf", {}))
         self.preprocessor_text = preprocessor_text
         
         self.frontend = frontend
@@ -37,12 +37,13 @@
         self.data_type = "sound"
         self.tokenizer = tokenizer
 
-        self.int_pad_value = int_pad_value
         self.float_pad_value = float_pad_value
         self.prompt = kwargs.get("prompt", "Transcribe speech to text.")
         self.prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(
             self.prompt)  # "USER: \nINSTRUCTION: {}\nnINPUT: {}\nASSISTANT: "
         self.prompt_af = ""
+        self.IGNORE_INDEX = kwargs.get("IGNORE_INDEX", -100)
+        self.int_pad_value = self.IGNORE_INDEX
     
     def get_source_len(self, index):
         item = self.index_ds[index]
@@ -64,7 +65,7 @@
         if self.preprocessor_speech:
             data_src = self.preprocessor_speech(data_src, fs=self.fs)
         speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
-        speech = speech.sequeeze(0)
+        speech = speech.squeeze(0)
 
         target = item["target"]
         if self.preprocessor_text:
@@ -91,10 +92,10 @@
         label_mask = labels_ids.ge(0)  # [False,False,True,True]
         labels_ids[~label_mask] = self.IGNORE_INDEX  # [-100,-100,input,eos]
         
-        audio_mask = [0] * prompt_pre_length + [1] * audio_length
-        torch.tensor(audio_mask, dtype=torch.float32)
+        audio_mask = [0] * prompt_pre_length + [1] * audio_length + [0]
+        audio_mask = torch.tensor(audio_mask, dtype=torch.float32)
         
-        ids = self.tokenizer.encode(target)
+        ids = self.tokenizer.encode(target) # token ids is different from labels_ids
         text = torch.tensor(ids, dtype=torch.int64)
         text_lengths = torch.tensor([len(ids)], dtype=torch.int32)
         

--
Gitblit v1.9.1