From 8c87a9d8a7c2f136053476670a9a83980f142aec Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Fri, 28 Jun 2024 17:28:09 +0800
Subject: [PATCH] Dev gzf deepspeed (#1858)
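
Make speech input optional in the LLM-ASR model so that text-only
prompts (those without a <|startofspeech|> segment) are handled:

* Initialize speech/speech_lengths to empty lists when building model
  inputs, and only collect and pad fbank features when at least one
  speech segment was decoded.
* Accept "!!"-prefixed speech segments whose payload is a Python bytes
  literal; the literal is evaluated back into a bytes object before
  being passed to load_audio_text_image_video.
* Run the audio encoder and audio_adaptor only when speech is present.
* Split inference() into inference_prepare(), which builds
  inputs_embeds, contents, batch, source_ids, and meta_data, and a thin
  inference() wrapper that calls it before running the LLM.

A minimal sketch of the new call path, using the signature introduced in
the diff below (assuming `model` is an instance of this class):

    inputs_embeds, contents, batch, source_ids, meta_data = model.inference_prepare(
        data_in, data_lengths, key, tokenizer, frontend, **kwargs
    )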
---
funasr/models/llm_asr/model.py | 63 ++++++++++++++++++++++---------
1 file changed, 45 insertions(+), 18 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 43c044e..b4d9e7c 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -1145,6 +1145,7 @@
fake_token_len_i = 0
fbank_beg_i = -1
fbank_lens_i = []
+ speech, speech_lengths = [], []
for k, sub_str in enumerate(splits):
if not sub_str.startswith("<|startofspeech|>"):
sub_token = tokenizer.encode(sub_str)
@@ -1155,9 +1156,12 @@
"<|endofspeech|>", ""
)
if sub_str.startswith("!"):
+ sub_str = sub_str[1:]
+ if sub_str.startswith("!"): # "!!" prefix: the remainder is a Python bytes literal
+ sub_str = eval(sub_str[1:])
try:
time1 = time.perf_counter()
- data_src = load_audio_text_image_video(sub_str[1:], fs=frontend.fs)
+ data_src = load_audio_text_image_video(sub_str, fs=frontend.fs)
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
except Exception as e:
@@ -1203,9 +1207,10 @@
input_source_ids = input_ids + source_ids
input_ids += source_ids + target_ids
labels += source_mask + target_ids
- fbank.append(speech[0, :, :])
fbank_mask += fbank_mask_i
- fbank_lens.append(speech_lengths)
+ if len(speech) > 0:
+ fbank.append(speech[0, :, :])
+ fbank_lens.append(speech_lengths)
input_ids = torch.tensor(input_ids, dtype=torch.int64) # [: self.max_token_length]
attention_mask = torch.tensor([1] * len(input_ids), dtype=torch.int32)
@@ -1219,10 +1224,14 @@
source_ids = torch.tensor(input_source_ids, dtype=torch.int64)
target_ids = torch.tensor(target_ids, dtype=torch.int64)
- speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
- speech_lengths = torch.nn.utils.rnn.pad_sequence(
- fbank_lens, batch_first=True, padding_value=-1
- )
+ if len(fbank) > 0:
+ speech = torch.nn.utils.rnn.pad_sequence(fbank, batch_first=True, padding_value=0.0)
+ speech_lengths = torch.nn.utils.rnn.pad_sequence(
+ fbank_lens, batch_first=True, padding_value=-1
+ )
+ else:
+ speech = []
+ speech_lengths = []
output = {
"speech": speech,
"speech_lengths": speech_lengths,
@@ -1238,7 +1247,8 @@
return output
- def inference(
+
+ def inference_prepare(
self,
data_in,
data_lengths=None,
@@ -1260,17 +1270,18 @@
# audio encoder
speech = batch["speech"]
- speech_lengths = batch["speech_lengths"][:, 0]
- # fp16
- if kwargs.get("fp16", False):
- speech = speech.to(torch.float16)
- elif kwargs.get("bf16", False):
- speech = speech.to(torch.bfloat16)
- # audio encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if len(speech) > 0:
+ speech_lengths = batch["speech_lengths"][:, 0]
+ # fp16
+ if kwargs.get("fp16", False):
+ speech = speech.to(torch.float16)
+ elif kwargs.get("bf16", False):
+ speech = speech.to(torch.bfloat16)
+ # audio encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- # audio_adaptor
- encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
+ # audio_adaptor
+ encoder_out, encoder_out_lens = self.audio_adaptor(encoder_out, encoder_out_lens)
input_ids = batch["input_ids"]
source_ids = batch["source_ids"]
@@ -1316,6 +1327,22 @@
] = speech_token
speech_idx += 1
+ return inputs_embeds, contents, batch, source_ids, meta_data
+
+
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
+ inputs_embeds, contents, batch, source_ids, meta_data = self.inference_prepare(
+ data_in, data_lengths, key, tokenizer, frontend, **kwargs
+ )
llm_dtype = kwargs.get("llm_dtype", "fp32")
if llm_dtype == "fp32":
--
Gitblit v1.9.1