From e451eb799a5bccd53dfd4b86cf66a4668b0088b7 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期三, 06 三月 2024 15:31:47 +0800
Subject: [PATCH] infer for word punc model
---
funasr/models/llm_asr/model.py | 141 +++++++++++++++++++++++++++--------------------
1 files changed, 81 insertions(+), 60 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index a903262..4139d8c 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -20,8 +20,8 @@
from funasr.register import tables
-@tables.register("model_classes", "LLMASRNAR")
-class LLMASRNAR(nn.Module):
+@tables.register("model_classes", "LLMASR")
+class LLMASR(nn.Module):
""" """
def __init__(
@@ -73,7 +73,7 @@
hub = encoder_conf.get("hub", None)
if hub == "funasr":
from funasr import AutoModel
- init_param_path = encoder_conf.get("hub", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
+ init_param_path = encoder_conf.get("init_param_path", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
model = AutoModel(model=init_param_path, model_revision="v2.0.4")
# frontend = model.kwargs.get("frontend")
model.model.decoder = None
@@ -179,6 +179,7 @@
if input_ids is not None:
input_ids[input_ids == -1] = 0
+ input_ids[input_ids == -100] = 0
if hasattr(self.llm.model, "embed_tokens"):
inputs_embeds = self.llm.model.embed_tokens(input_ids)
elif hasattr(self.llm.model.model, "embed_tokens"):
@@ -190,7 +191,7 @@
batch_size, token_num, dims = inputs_embeds.shape
_, l, _ = encoder_out.shape
encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
- inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
+ inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
@@ -198,11 +199,10 @@
stats = {}
- if self.metric:
- with torch.no_grad():
- preds = torch.argmax(model_outputs.logits, -1)
- acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
- stats["acc"] = acc_att
+ with torch.no_grad():
+ preds = torch.argmax(model_outputs.logits, -1)
+ acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+ stats["acc"] = acc_att
stats["loss"] = torch.clone(loss.detach())
@@ -216,16 +216,17 @@
self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
) -> Tuple[torch.Tensor, torch.Tensor]:
- audio_mask = kwargs.get("audio_mask")
- audio_token_lengths = audio_mask.sum(-1)
+ audio_mask = kwargs.get("audio_mask", None)
+ audio_token_lengths = audio_mask.sum(-1) if audio_mask is not None else None
batch = {"speech": speech, "speech_lengths": speech_lengths}
enc, enc_lens = self.audio_encoder.encode(**batch)
- enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
- pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
- mask=enc_mask,
- target_label_length=audio_token_lengths,
- )
+ with autocast(False):
+ enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
+ pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
+ mask=enc_mask,
+ target_label_length=audio_token_lengths,
+ )
return pre_acoustic_embeds, pre_token_length
@@ -239,14 +240,12 @@
**kwargs,
):
+ prompt = kwargs.get("prompt", "Transcribe speech to text.")
+
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
-
- # init beamsearch
- if self.beam_search is None:
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
+
+
meta_data = {}
if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
@@ -271,50 +270,72 @@
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
+
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
+
+ # adaptor
+ encoder_out = self.adaptor(encoder_out)
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
- )
+
+ prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
+ prompt_ids = tokenizer.encode(prompt_pre)
+ prompt_length = len(prompt_ids)
+ prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+
+
+ if hasattr(self.llm.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+ elif hasattr(self.llm.model.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+ else:
+ inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+ inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio]
+ attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
- nbest_hyps = nbest_hyps[: self.nbest]
+ # model_outputs = self.llm.generate(
+ # inputs_embeds=inputs_embeds,
+ # max_length=kwargs.get("max_length", 200),
+ # max_new_tokens=kwargs.get("max_new_tokens", 200),
+ # num_beams=kwargs.get("num_beams", 4),
+ # do_sample=kwargs.get("do_sample", False),
+ # min_length=kwargs.get("min_length", 1),
+ # top_p=kwargs.get("top_p", 1.0),
+ # repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+ # length_penalty=kwargs.get("length_penalty", 1.0),
+ # temperature=kwargs.get("temperature", 1.0),
+ # attention_mask=attention_mask,
+ # bos_token_id=tokenizer.bos_token_id,
+ # eos_token_id=tokenizer.eos_token_id,
+ # pad_token_id=tokenizer.pad_token_id
+ # )
+
+
+ model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=None)
+ preds = torch.argmax(model_outputs.logits, -1)
+ text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+
+ text = text[0].split(': ')[-1]
+ text = text.strip()
+ # preds = torch.argmax(model_outputs.logits, -1)
+
+ ibest_writer = None
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"{0 + 1}best_recog"]
+
results = []
- b, n, d = encoder_out.size()
- for i in range(b):
-
- for nbest_idx, hyp in enumerate(nbest_hyps):
- ibest_writer = None
- if kwargs.get("output_dir") is not None:
- if not hasattr(self, "writer"):
- self.writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- result_i = {"key": key[i], "token": token, "text": text_postprocessed}
- results.append(result_i)
-
- if ibest_writer is not None:
- ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text_postprocessed
+ result_i = {"key": key[0], "text": text}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["text"][key[0]] = text
+
+
+
return results, meta_data
--
Gitblit v1.9.1