From f57b68121a526baea43b2e93f4540d8a2995f633 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 15:15:24 +0800
Subject: [PATCH] batch
---
funasr/models/llm_asr/model.py | 135 ++++++++++++++++++++++++--------------------
1 files changed, 73 insertions(+), 62 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index 3223190..4345f69 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -23,7 +23,7 @@
@tables.register("model_classes", "LLMASR")
class LLMASR(nn.Module):
""" """
-
+
def __init__(
self,
specaug: str = None,
@@ -59,28 +59,29 @@
# postencoder: Optional[AbsPostEncoder] = None,
**kwargs,
):
-
+
super().__init__()
-
+
if specaug is not None:
specaug_class = tables.specaug_classes.get(specaug)
specaug = specaug_class(**specaug_conf)
if normalize is not None:
normalize_class = tables.normalize_classes.get(normalize)
normalize = normalize_class(**normalize_conf)
-
+
# audio encoder
hub = audio_encoder_conf.get("hub", None)
if hub == "ms":
from funasr import AutoModel
- model = AutoModel(model=audio_encoder, model_revision="v2.0.4")
+
+ model = AutoModel(model=audio_encoder, model_revision="master")
# frontend = model.kwargs.get("frontend")
audio_encoder_output_size = model.model.encoder_output_size
audio_encoder = model.model.model.encoder
-
+
# self.frontend = frontend
-
+
elif hub == "hf":
pass
else:
@@ -92,7 +93,7 @@
for name, param in audio_encoder.named_parameters():
param.requires_grad = False
audio_encoder.eval()
-
+
self.audio_encoder = audio_encoder
# llm
@@ -102,7 +103,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
-
+
model = AutoModelForCausalLM.from_pretrained(
init_param_path,
load_in_8bit=None,
@@ -115,15 +116,14 @@
param.requires_grad = False
model.eval()
self.llm = model
-
+
# adaptor
adaptor_class = tables.adaptor_classes.get(audio_adaptor)
audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
audio_adaptor = adaptor_class(**audio_adaptor_conf)
-
+
self.audio_adaptor = audio_adaptor
-
-
+
self.blank_id = blank_id
self.sos = sos if sos is not None else vocab_size - 1
self.eos = eos if eos is not None else vocab_size - 1
@@ -143,7 +143,7 @@
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
-
+
def forward(
self,
speech: torch.Tensor,
@@ -151,7 +151,7 @@
text: torch.Tensor,
text_lengths: torch.Tensor,
input_ids: torch.Tensor,
- attention_mask:torch.Tensor,
+ attention_mask: torch.Tensor,
labels_ids: torch.Tensor,
label_mask: torch.Tensor,
audio_mask: torch.Tensor,
@@ -170,15 +170,15 @@
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
-
+
batch_size = speech.shape[0]
-
+
# audio encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
+
# audio_adaptor
encoder_out = self.audio_adaptor(encoder_out)
-
+
input_ids[input_ids == -1] = 0
input_ids[input_ids == -100] = 0
if hasattr(self.llm.model, "embed_tokens"):
@@ -193,11 +193,14 @@
_, l, _ = encoder_out.shape
# [audio, bos, prompt, input, pad]
encoder_outs_pad = F.pad(encoder_out, (0, 0, 0, token_num - l, 0, 0), value=0.0)
- inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
+ inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (
+ 1.0 - audio_mask[:, :, None]
+ )
- model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
+ model_outputs = self.llm(
+ inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids
+ )
loss = model_outputs.loss
-
stats = {}
with torch.no_grad():
@@ -214,34 +217,38 @@
return loss, stats, weight
def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ **kwargs,
):
speech = speech.permute(0, 2, 1)
res = self.audio_encoder(speech)
- if len(res) > 1:
+ if isinstance(res, (list, tuple)):
encoder_out, encoder_out_lens = res[0], res[1]
else:
encoder_out, encoder_out_lens = res, speech_lengths
return encoder_out, encoder_out_lens
-
- def inference(self,
- data_in,
- data_lengths=None,
- key: list = None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
-
+
+ def inference(
+ self,
+ data_in,
+ data_lengths=None,
+ key: list = None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
prompt = kwargs.get("prompt", "Transcribe speech to text.")
-
+
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
-
-
meta_data = {}
- if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
+ if (
+ isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank"
+ ): # fbank
speech, speech_lengths = data_in, data_lengths
if len(speech.shape) < 3:
speech = speech[None, :, :]
@@ -250,32 +257,37 @@
else:
# extract fbank feats
time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer)
+ audio_sample_list = load_audio_text_image_video(
+ data_in,
+ fs=frontend.fs,
+ audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer,
+ )
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
- frontend=frontend)
+ speech, speech_lengths = extract_fbank(
+ audio_sample_list, data_type=kwargs.get("data_type", "sound"), frontend=frontend
+ )
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-
+ meta_data["batch_data_time"] = (
+ speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+ )
+
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
-
+
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
# adaptor
encoder_out = self.audio_adaptor(encoder_out)
-
-
+
prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
prompt_ids = tokenizer.encode(prompt_pre)
prompt_length = len(prompt_ids)
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
-
if hasattr(self.llm.model, "embed_tokens"):
inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
@@ -284,9 +296,13 @@
else:
inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
- inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio]
- attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
-
+ inputs_embeds = torch.cat(
+ (inputs_embeds[None, :, :], encoder_out), dim=1
+ ) # [prompt, audio]
+ attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(
+ kwargs["device"]
+ )
+
preds = self.llm.generate(
inputs_embeds=inputs_embeds,
max_length=kwargs.get("max_length", 200),
@@ -301,17 +317,16 @@
attention_mask=attention_mask,
bos_token_id=tokenizer.bos_token_id,
eos_token_id=tokenizer.eos_token_id,
- pad_token_id=tokenizer.pad_token_id
+ pad_token_id=tokenizer.pad_token_id,
)
-
text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
- text = text[0].split(': ')[-1]
+ text = text[0].split(": ")[-1]
text = text.strip()
-
+
# preds = torch.argmax(model_outputs.logits, -1)
-
+
ibest_writer = None
if kwargs.get("output_dir") is not None:
if not hasattr(self, "writer"):
@@ -324,9 +339,5 @@
if ibest_writer is not None:
ibest_writer["text"][key[0]] = text
-
-
-
-
- return results, meta_data
+ return results, meta_data
--
Gitblit v1.9.1