From 149063ced4d2d5269f0472677228eadfcb4a4d8a Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: 星期三, 17 四月 2024 14:33:24 +0800
Subject: [PATCH] update seaco finetune
---
funasr/models/llm_asr_nar/model.py | 88 ++++++++++++++++++++++++++++++++++++--------
1 files changed, 72 insertions(+), 16 deletions(-)
diff --git a/funasr/models/llm_asr_nar/model.py b/funasr/models/llm_asr_nar/model.py
index a6096b2..994259a 100644
--- a/funasr/models/llm_asr_nar/model.py
+++ b/funasr/models/llm_asr_nar/model.py
@@ -75,7 +75,7 @@
if hub == "funasr":
from funasr import AutoModel
init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
- model = AutoModel(model=init_param_path, model_revision="v2.0.4")
+ model = AutoModel(model=init_param_path, model_revision="master")
# frontend = model.kwargs.get("frontend")
model.model.decoder = None
@@ -264,7 +264,7 @@
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=None)
- if len(kwargs.get("data_type")) > 1:
+ if len(kwargs.get("data_type", [])) > 1:
audio_sample_list, text_token_int_list = audio_sample_list
text_token_int = text_token_int_list[0].replace(" ", "")
text_token_int = tokenizer.encode(text_token_int)
@@ -366,7 +366,7 @@
decoder_conf: dict = None,
ctc: str = None,
ctc_conf: dict = None,
- ctc_weight: float = 0.5,
+ ctc_weight: float = 0.0,
llm: str = None,
llm_conf: dict = None,
adaptor: str = None,
@@ -406,7 +406,7 @@
from funasr import AutoModel
init_param_path = encoder_conf.get("init_param_path",
"iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
- model = AutoModel(model=init_param_path, model_revision="v2.0.4")
+ model = AutoModel(model=init_param_path, model_revision="master")
# frontend = model.kwargs.get("frontend")
model.model.decoder = None
@@ -473,6 +473,15 @@
self.length_normalized_loss = length_normalized_loss
self.beam_search = None
+ if ctc_weight > 0.0:
+ if ctc_conf is None:
+ ctc_conf = {}
+
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=adaptor_conf["encoder_dim"], **ctc_conf
+ )
+ self.ctc_weight = ctc_weight
+ self.ctc = ctc
def forward(
self,
@@ -502,9 +511,23 @@
speech_lengths = speech_lengths[:, 0]
batch_size = speech.shape[0]
-
+
+ stats = {}
# audio encoder
- encoder_out, encoder_out_lens, loss_pre = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+ outs = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+ enc, enc_lens = outs[0], outs[1]
+ encoder_out, encoder_out_lens, loss_pre = outs[2], outs[3], outs[4]
+
+
+ # decoder: CTC branch
+
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ enc, enc_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = torch.clone(loss_ctc.detach()) if loss_ctc is not None else None
# adaptor
encoder_out = self.adaptor(encoder_out)
@@ -536,17 +559,19 @@
# labels_ids[1:] -> [prompt, input, target, eos] -> [-1, input, target, eos];
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
loss_llm = model_outputs.loss
+ stats["loss_llm"] = torch.clone(loss_llm.detach())
+ if self.ctc_weight > 0.0:
+ loss_llm = self.ctc_weight * loss_ctc + loss_llm
loss = loss_llm + loss_pre * self.predictor_weight
- stats = {}
+
with torch.no_grad():
preds = torch.argmax(model_outputs.logits, -1)
acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
stats["acc"] = acc_att
-
stats["loss_pre"] = torch.clone(loss_pre.detach())
- stats["loss_llm"] = torch.clone(loss_llm.detach())
stats["loss"] = torch.clone(loss.detach())
+ stats["batch_size"] = batch_size
# force_gatherable: to-device and to-tensor if scalar for DataParallel
if self.length_normalized_loss:
@@ -561,7 +586,7 @@
audio_mask = kwargs.get("audio_mask", None)
audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
text_token_int = kwargs.get("text_token_int", None)
- if audio_token_lengths is None:
+ if audio_token_lengths is None and text_token_int is not None:
audio_token_lengths = torch.tensor([len(text_token_int)], dtype=torch.int64)
batch = {"speech": speech, "speech_lengths": speech_lengths}
@@ -572,9 +597,28 @@
mask=enc_mask,
target_label_length=audio_token_lengths,
)
- loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
+ loss_pre = 0.0
+ if audio_token_lengths is not None:
+ loss_pre = self.criterion_pre(audio_token_lengths.type_as(pre_token_length), pre_token_length)
- return pre_acoustic_embeds, pre_token_length, loss_pre
+ return enc, enc_lens, pre_acoustic_embeds, pre_token_length, loss_pre
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
def inference(self,
data_in,
@@ -603,10 +647,12 @@
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=None)
- if len(kwargs.get("data_type")) > 1:
+ if len(kwargs.get("data_type", [])) > 1:
audio_sample_list, text_token_int_list = audio_sample_list
- text_token_int = text_token_int_list[0].replace(" ", "")
+ text_token_int = text_token_int_list[0]
text_token_int = tokenizer.encode(text_token_int)
+ if text_token_int[0] == tokenizer.bos_token_id:
+ text_token_int = text_token_int[1:]
else:
text_token_int = None
time2 = time.perf_counter()
@@ -621,23 +667,30 @@
speech_lengths = speech_lengths.to(device=kwargs["device"])
# Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, text_token_int=text_token_int)
+ res = self.encode(speech, speech_lengths, text_token_int=text_token_int)
+ encoder_out = res[0]
# adaptor
encoder_out = self.adaptor(encoder_out)
prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
prompt_ids = tokenizer.encode(prompt_pre)
+ if prompt_ids[0] == tokenizer.bos_token_id:
+ prompt_ids = prompt_ids[1:]
+ # prompt_ids = prompt_ids + [tokenizer.pad_token_id]
prompt_length = len(prompt_ids)
prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+ pad = torch.tensor([tokenizer.pad_token_id], dtype=torch.int64).to(kwargs["device"])
if hasattr(self.llm.model, "embed_tokens"):
inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+ pad = self.llm.model.embed_tokens(pad)
elif hasattr(self.llm.model.model, "embed_tokens"):
inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
else:
inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+ # inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out, pad[None, :, :]), dim=1) # [prompt, audio, pad]
inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio]
attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
@@ -662,8 +715,11 @@
preds = torch.argmax(model_outputs.logits, -1)
text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
- text = text[0].split(': ')[-1]
+ text = text[0].split(':')[-1]
text = text.strip()
+ if text.startswith("Please\n "):
+ text = text.replace("Please\n ", "")
+ text = text.strip()
# preds = torch.argmax(model_outputs.logits, -1)
--
Gitblit v1.9.1