From e9d2cfc3a134b00f4e98271fbee3838d1ccecbcc Mon Sep 17 00:00:00 2001
From: VirtuosoQ <2416050435@qq.com>
Date: 星期五, 26 四月 2024 14:59:30 +0800
Subject: [PATCH] FunASR java http client
---
funasr/models/llm_asr/model.py | 230 ++++++++++++++++++++++++++++++---------------------------
1 files changed, 121 insertions(+), 109 deletions(-)
diff --git a/funasr/models/llm_asr/model.py b/funasr/models/llm_asr/model.py
index a903262..90cbd94 100644
--- a/funasr/models/llm_asr/model.py
+++ b/funasr/models/llm_asr/model.py
@@ -12,7 +12,7 @@
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy, compute_accuracy
-# from funasr.models.e2e_asr_common import ErrorCalculator
+from funasr.metrics.common import ErrorCalculator
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
@@ -20,8 +20,8 @@
from funasr.register import tables
-@tables.register("model_classes", "LLMASRNAR")
-class LLMASRNAR(nn.Module):
+@tables.register("model_classes", "LLMASR")
+class LLMASR(nn.Module):
""" """
def __init__(
@@ -30,8 +30,10 @@
specaug_conf: dict = None,
normalize: str = None,
normalize_conf: dict = None,
- encoder: str = None,
- encoder_conf: dict = None,
+ audio_encoder: str = None,
+ audio_encoder_conf: dict = None,
+ audio_adaptor: str = None,
+ audio_adaptor_conf: dict = None,
decoder: str = None,
decoder_conf: dict = None,
ctc: str = None,
@@ -39,8 +41,6 @@
ctc_weight: float = 0.5,
llm: str = None,
llm_conf: dict = None,
- adaptor: str = None,
- adaptor_conf: dict = None,
input_size: int = 80,
vocab_size: int = -1,
ignore_id: int = -1,
@@ -70,23 +70,30 @@
normalize = normalize_class(**normalize_conf)
# audio encoder
- hub = encoder_conf.get("hub", None)
- if hub == "funasr":
+ hub = audio_encoder_conf.get("hub", None)
+ if hub == "ms":
from funasr import AutoModel
- init_param_path = encoder_conf.get("hub", "iic/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
- model = AutoModel(model=init_param_path, model_revision="v2.0.4")
+ model = AutoModel(model=audio_encoder, model_revision="master")
# frontend = model.kwargs.get("frontend")
- model.model.decoder = None
+ audio_encoder_output_size = model.model.encoder_output_size
+
+ audio_encoder = model.model.model.encoder
- self.audio_encoder = model.model
# self.frontend = frontend
elif hub == "hf":
pass
else:
- encoder_class = tables.encoder_classes.get(encoder)
- encoder = encoder_class(input_size=input_size, **encoder_conf)
- encoder_output_size = encoder.output_size()
+ encoder_class = tables.encoder_classes.get(audio_encoder)
+ audio_encoder = encoder_class(input_size=input_size, **audio_encoder_conf)
+ audio_encoder_output_size = audio_encoder.output_size()
+ freeze = audio_encoder_conf.get("freeze", True)
+ if freeze:
+ for name, param in audio_encoder.named_parameters():
+ param.requires_grad = False
+ audio_encoder.eval()
+
+ self.audio_encoder = audio_encoder
# llm
hub = llm_conf.get("hub", "hf")
@@ -95,6 +102,7 @@
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
init_param_path = llm_conf.get("init_param_path", "vicuna-7b-v1.5")
+
model = AutoModelForCausalLM.from_pretrained(
init_param_path,
load_in_8bit=None,
@@ -109,10 +117,11 @@
self.llm = model
# adaptor
- adaptor_class = tables.adaptor_classes.get(adaptor)
- adaptor = adaptor_class(**adaptor_conf)
+ adaptor_class = tables.adaptor_classes.get(audio_adaptor)
+ audio_adaptor_conf["encoder_dim"] = audio_encoder_output_size
+ audio_adaptor = adaptor_class(**audio_adaptor_conf)
- self.adaptor = adaptor
+ self.audio_adaptor = audio_adaptor
self.blank_id = blank_id
@@ -122,8 +131,6 @@
self.ignore_id = ignore_id
self.specaug = specaug
self.normalize = normalize
- self.encoder = encoder
-
self.criterion_att = LabelSmoothingLoss(
size=vocab_size,
@@ -131,12 +138,7 @@
smoothing=lsm_weight,
normalize_length=length_normalized_loss,
)
- #
- # if report_cer or report_wer:
- # self.error_calculator = ErrorCalculator(
- # token_list, sym_space, sym_blank, report_cer, report_wer
- # )
- #
+
self.error_calculator = None
self.length_normalized_loss = length_normalized_loss
@@ -172,37 +174,36 @@
batch_size = speech.shape[0]
# audio encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths, audio_mask=audio_mask)
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- # adaptor
- encoder_out = self.adaptor(encoder_out)
+ # audio_adaptor
+ encoder_out = self.audio_adaptor(encoder_out)
+
+ input_ids[input_ids == -1] = 0
+ input_ids[input_ids == -100] = 0
+ if hasattr(self.llm.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.embed_tokens(input_ids)
+ elif hasattr(self.llm.model.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
+ else:
+ inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
- if input_ids is not None:
- input_ids[input_ids == -1] = 0
- if hasattr(self.llm.model, "embed_tokens"):
- inputs_embeds = self.llm.model.embed_tokens(input_ids)
- elif hasattr(self.llm.model.model, "embed_tokens"):
- inputs_embeds = self.llm.model.model.embed_tokens(input_ids)
- else:
- inputs_embeds = self.llm.model.model.model.embed_tokens(input_ids)
-
- if audio_mask is not None:
- batch_size, token_num, dims = inputs_embeds.shape
- _, l, _ = encoder_out.shape
- encoder_outs_pad = F.pad(encoder_out, (0, 0, token_num-l-1, 1, 0, 0), value=0.0)
- inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (~audio_mask[:, :, None])
- inputs_embeds = F.pad(inputs_embeds[:, 1:, :], (0, 0, 0, 1, 0, 0), value=0.0)
+ if audio_mask is not None:
+ batch_size, token_num, dims = inputs_embeds.shape
+ _, l, _ = encoder_out.shape
+ # [audio, bos, prompt, input, pad]
+ encoder_outs_pad = F.pad(encoder_out, (0, 0, 0, token_num - l, 0, 0), value=0.0)
+ inputs_embeds = encoder_outs_pad * audio_mask[:, :, None] + inputs_embeds * (1.0-audio_mask[:, :, None])
model_outputs = self.llm(inputs_embeds=inputs_embeds, attention_mask=attention_mask, labels=labels_ids)
loss = model_outputs.loss
stats = {}
- if self.metric:
- with torch.no_grad():
- preds = torch.argmax(model_outputs.logits, -1)
- acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
- stats["acc"] = acc_att
+ with torch.no_grad():
+ preds = torch.argmax(model_outputs.logits, -1)
+ acc_att = compute_accuracy(preds[:, :-1], labels_ids[:, 1:], ignore_label=-100)
+ stats["acc"] = acc_att
stats["loss"] = torch.clone(loss.detach())
@@ -211,25 +212,18 @@
batch_size = int((text_lengths + 1).sum())
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
-
+
def encode(
self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
+ ):
+ speech = speech.permute(0, 2, 1)
+ res = self.audio_encoder(speech)
+ if isinstance(res, (list, tuple)):
+ encoder_out, encoder_out_lens = res[0], res[1]
+ else:
+ encoder_out, encoder_out_lens = res, speech_lengths
+ return encoder_out, encoder_out_lens
- audio_mask = kwargs.get("audio_mask")
- audio_token_lengths = audio_mask.sum(-1)
-
- batch = {"speech": speech, "speech_lengths": speech_lengths}
- enc, enc_lens = self.audio_encoder.encode(**batch)
- enc_mask = sequence_mask(enc_lens, enc.size(1), device=enc.device)[:, None, :]
- pre_acoustic_embeds, pre_token_length, _, _ = self.audio_encoder.predictor(enc,
- mask=enc_mask,
- target_label_length=audio_token_lengths,
- )
-
- return pre_acoustic_embeds, pre_token_length
-
-
def inference(self,
data_in,
data_lengths=None,
@@ -239,14 +233,12 @@
**kwargs,
):
+ prompt = kwargs.get("prompt", "Transcribe speech to text.")
+
if kwargs.get("batch_size", 1) > 1:
raise NotImplementedError("batch decoding is not implemented")
-
- # init beamsearch
- if self.beam_search is None:
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
+
+
meta_data = {}
if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
@@ -271,50 +263,70 @@
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
+
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
+
+ # adaptor
+ encoder_out = self.audio_adaptor(encoder_out)
- # c. Passed the encoder result and the beam search
- nbest_hyps = self.beam_search(
- x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+
+ prompt_pre = "USER: \nINSTRUCTION: {}\nINPUT: ".format(prompt)
+ prompt_ids = tokenizer.encode(prompt_pre)
+ prompt_length = len(prompt_ids)
+ prompt_ids = torch.tensor(prompt_ids, dtype=torch.int64).to(kwargs["device"])
+
+
+ if hasattr(self.llm.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.embed_tokens(prompt_ids)
+ elif hasattr(self.llm.model.model, "embed_tokens"):
+ inputs_embeds = self.llm.model.model.embed_tokens(prompt_ids)
+ else:
+ inputs_embeds = self.llm.model.model.model.embed_tokens(prompt_ids)
+
+ inputs_embeds = torch.cat((inputs_embeds[None, :, :], encoder_out), dim=1) # [prompt, audio]
+ attention_mask = torch.ones(inputs_embeds.size()[:-1], dtype=torch.long).to(kwargs["device"])
+
+ preds = self.llm.generate(
+ inputs_embeds=inputs_embeds,
+ max_length=kwargs.get("max_length", 200),
+ max_new_tokens=kwargs.get("max_new_tokens", 200),
+ num_beams=kwargs.get("num_beams", 4),
+ do_sample=kwargs.get("do_sample", False),
+ min_length=kwargs.get("min_length", 1),
+ top_p=kwargs.get("top_p", 1.0),
+ repetition_penalty=kwargs.get("repetition_penalty", 1.0),
+ length_penalty=kwargs.get("length_penalty", 1.0),
+ temperature=kwargs.get("temperature", 1.0),
+ attention_mask=attention_mask,
+ bos_token_id=tokenizer.bos_token_id,
+ eos_token_id=tokenizer.eos_token_id,
+ pad_token_id=tokenizer.pad_token_id
)
+
+
+ text = tokenizer.batch_decode(preds, add_special_tokens=False, skip_special_tokens=True)
+
+ text = text[0].split(': ')[-1]
+ text = text.strip()
- nbest_hyps = nbest_hyps[: self.nbest]
+ # preds = torch.argmax(model_outputs.logits, -1)
+ ibest_writer = None
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"{0 + 1}best_recog"]
+
results = []
- b, n, d = encoder_out.size()
- for i in range(b):
-
- for nbest_idx, hyp in enumerate(nbest_hyps):
- ibest_writer = None
- if kwargs.get("output_dir") is not None:
- if not hasattr(self, "writer"):
- self.writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- result_i = {"key": key[i], "token": token, "text": text_postprocessed}
- results.append(result_i)
-
- if ibest_writer is not None:
- ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text_postprocessed
+ result_i = {"key": key[0], "text": text}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["text"][key[0]] = text
+
+
+
return results, meta_data
--
Gitblit v1.9.1