From 8c87a9d8a7c2f136053476670a9a83980f142aec Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 28 Jun 2024 17:28:09 +0800
Subject: [PATCH] Dev gzf deepspeed (#1858)

---
 funasr/models/llm_asr/model.py |   63 ++++++++++++++++++++++---------
 1 file changed, 45 insertions(+), 18 deletions(-)

diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 43c044e..b4d9e7c 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -1145,6 +1145,7 @@
             fake_token_len_i = 0
             fbank_beg_i = -1
             fbank_lens_i = []
+            speech, speech_lengths = [], []
             for k, sub_str in enumerate(splits):
                 if not sub_str.startswith("<|startofspeech|>"):
                     sub_token = tokenizer.encode(sub_str)
@@ -1155,9 +1156,12 @@
                         "<|endofspeech|>", ""
                     )
                     if sub_str.startswith("!"):
+                        sub_str = sub_str[1:]
+                        if sub_str.startswith("!"):  # !!bytes
+                            sub_str = eval(sub_str[1:])
                         try:
                             time1 = time.perf_counter()
-                            data_src = load_audio_text_image_video(sub_str[1:], fs=frontend.fs)
+                            data_src = load_audio_text_image_video(sub_str, fs=frontend.fs)
                             time2 = time.perf_counter()
                             meta_data["load_data"] = f"{time2 - time1:0.3f}"
                         except Exception as e:
@@ -1203,9 +1207,10 @@
             input_source_ids = input_ids + source_ids
             input_ids += source_ids + target_ids
             labels += source_mask + target_ids
-            fbank.append(speech[0, :, :])
             fbank_mask += fbank_mask_i
-            fbank_lens.append(speech_lengths)
+            if len(speech) > 0:
+                fbank.append(speech[0, :, :])
+                fbank_lens.append(speech_lengths)
 
         input_ids = torch.tensor(input_ids, dtype=torch.int64)  # [: self.max_token_length]
         attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
@@ -1219,10 +1224,14 @@
         source_ids = torch.tensor(input_source_ids, dtype=torch.int64)
         target_ids = torch.tensor(target_ids, dtype=torch.int64)
 
-        speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
-        speech_lengths = torch.nn.utils.rnn.pad_sequence(
-            fbank_lens, batch_first=True, padding_value=-1
-        )
+        if len(fbank) > 0:
+            speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
+            speech_lengths = torch.nn.utils.rnn.pad_sequence(
+                fbank_lens, batch_first=True, padding_value=-1
+            )
+        else:
+            speech = []
+            speech_lengths = []
         output = {
             "speech": speech,
             "speech_lengths": speech_lengths,
@@ -1238,7 +1247,8 @@
 
         return output
 
-    def inference(
+
+    def inference_prepare(
         self,
         data_in,
         data_lengths=None,
@@ -1260,17 +1270,18 @@
 
         # audio encoder
         speech = batch["speech"]
-        speech_lengths = batch["speech_lengths"][:, 0]
-        # fp16
-        if kwargs.get("fp16", False):
-            speech = speech.to(torch.float16)
-        elif kwargs.get("bf16", False):
-            speech = speech.to(torch.bfloat16)
-        # audio encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if len(speech) > 0:
+            speech_lengths = batch["speech_lengths"][:, 0]
+            # fp16
+            if kwargs.get("fp16", False):
+                speech = speech.to(torch.float16)
+            elif kwargs.get("bf16", False):
+                speech = speech.to(torch.bfloat16)
+            # audio encoder
+            encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
-        # audio_adaptor
-        encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
+            # audio_adaptor
+            encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
 
         input_ids = batch["input_ids"]
         source_ids = batch["source_ids"]
@@ -1316,6 +1327,22 @@
                         ] = speech_token
 
                     speech_idx += 1
+        return inputs_embeds, contents, batch, source_ids, meta_data
+    
+
+    def inference(
+        self,
+        data_in,
+        data_lengths=None,
+        key: list = None,
+        tokenizer=None,
+        frontend=None,
+        **kwargs,
+    ):
+
+        inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
+            data_in, data_lengths, key, tokenizer, frontend, **kwargs
+        )
 
         llm_dtype = kwargs.get("llm_dtype", "fp32")
         if llm_dtype == "fp32":

--
Gitblit v1.9.1