From 861147c7308b91068ffa02724fdf74ee623a909e Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 24 四月 2024 16:03:38 +0800
Subject: [PATCH] Dev gzf exp (#1654)

---
 funasr/datasets/sense_voice_datasets/datasets.py |   83 +++++++++++++++++++++++------------------
 1 files changed, 46 insertions(+), 37 deletions(-)

diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index 956cf79..5468ea6 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -10,28 +10,33 @@
     """
     SenseVoiceDataset
     """
-    def __init__(self,
-                 path,
-                 index_ds: str = None,
-                 frontend=None,
-                 tokenizer=None,
-                 int_pad_value: int = -1,
-                 float_pad_value: float = 0.0,
-                  **kwargs):
+
+    def __init__(
+        self,
+        path,
+        index_ds: str = None,
+        frontend=None,
+        tokenizer=None,
+        int_pad_value: int = -1,
+        float_pad_value: float = 0.0,
+        **kwargs,
+    ):
         super().__init__()
         index_ds_class = tables.index_ds_classes.get(index_ds)
         self.index_ds = index_ds_class(path, **kwargs)
         preprocessor_speech = kwargs.get("preprocessor_speech", None)
         if preprocessor_speech:
             preprocessor_speech_class = tables.preprocessor_classes.get(preprocessor_speech)
-            preprocessor_speech = preprocessor_speech_class(**kwargs.get("preprocessor_speech_conf"))
+            preprocessor_speech = preprocessor_speech_class(
+                **kwargs.get("preprocessor_speech_conf")
+            )
         self.preprocessor_speech = preprocessor_speech
         preprocessor_text = kwargs.get("preprocessor_text", None)
         if preprocessor_text:
             preprocessor_text_class = tables.preprocessor_classes.get(preprocessor_text)
             preprocessor_text = preprocessor_text_class(**kwargs.get("preprocessor_text_conf"))
         self.preprocessor_text = preprocessor_text
-        
+
         self.frontend = frontend
         self.fs = 16000 if frontend is None else frontend.fs
         self.data_type = "sound"
@@ -41,18 +46,18 @@
         self.float_pad_value = float_pad_value
         self.sos = kwargs.get("sos", "<|startoftranscript|>")
         self.eos = kwargs.get("eos", "<|endoftext|>")
-    
+
     def get_source_len(self, index):
         item = self.index_ds[index]
         return self.index_ds.get_source_len(item)
-    
+
     def get_target_len(self, index):
         item = self.index_ds[index]
         return self.index_ds.get_target_len(item)
-    
+
     def __len__(self):
         return len(self.index_ds)
-    
+
     def __getitem__(self, index):
         item = self.index_ds[index]
         # import pdb;
@@ -61,42 +66,46 @@
         data_src = load_audio_text_image_video(source, fs=self.fs)
         if self.preprocessor_speech:
             data_src = self.preprocessor_speech(data_src, fs=self.fs)
-        speech, speech_lengths = extract_fbank(data_src, data_type=self.data_type, frontend=self.frontend, is_final=True) # speech: [b, T, d]
+        speech, speech_lengths = extract_fbank(
+            data_src, data_type=self.data_type, frontend=self.frontend, is_final=True
+        )  # speech: [b, T, d]
         speech = speech.permute(0, 2, 1)
         target = item["target"]
         if self.preprocessor_text:
             target = self.preprocessor_text(target)
-        
+
         task = item.get("prompt", "<|ASR|>")
         text_language = item.get("text_language", "<|zh|>")
 
         prompt = f"{self.sos}{task}{text_language}"
         prompt_ids = self.tokenizer.encode(prompt, allowed_special="all")
-        prompt_ids_len = len(prompt_ids) - 1 # [sos, task]
+        prompt_ids_len = len(prompt_ids) - 1  # [sos, task]
 
         target_ids = self.tokenizer.encode(target, allowed_special="all")
-        target_ids_len = len(target_ids) + 1 # [lid, text]
-        
-        eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
-        
+        target_ids_len = len(target_ids) + 1  # [lid, text]
+
+        eos = self.tokenizer.encode(self.eos, allowed_special="all")  # [eos]
+
         ids = prompt_ids + target_ids + eos
         ids_lengths = len(ids)
-        
+
         text = torch.tensor(ids, dtype=torch.int64)
         text_lengths = torch.tensor([ids_lengths], dtype=torch.int32)
 
-        target_mask = [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
+        target_mask = (
+            [0] * (prompt_ids_len) + [1] * (target_ids_len) + [1]
+        )  # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1]
         target_mask = torch.tensor(target_mask, dtype=torch.float32)
 
-        return {"speech": speech[0, :, :],
-                "speech_lengths": speech_lengths,
-                "text": text,
-                "text_lengths": text_lengths,
-                "target_mask": target_mask,
-                }
-    
-    
-    def collator(self, samples: list=None):
+        return {
+            "speech": speech[0, :, :],
+            "speech_lengths": speech_lengths,
+            "text": text,
+            "text_lengths": text_lengths,
+            "target_mask": target_mask,
+        }
+
+    def collator(self, samples: list = None):
         outputs = {}
         for sample in samples:
             for key in sample.keys():
@@ -107,12 +116,12 @@
         for key, data_list in outputs.items():
             if isinstance(data_list[0], torch.Tensor):
                 if data_list[0].dtype == torch.int64 or data_list[0].dtype == torch.int32:
-    
+
                     pad_value = self.int_pad_value
                 else:
                     pad_value = self.float_pad_value
-                
-                outputs[key] = torch.nn.utils.rnn.pad_sequence(data_list, batch_first=True, padding_value=pad_value)
+
+                outputs[key] = torch.nn.utils.rnn.pad_sequence(
+                    data_list, batch_first=True, padding_value=pad_value
+                )
         return outputs
-
-

--
Gitblit v1.9.1