From 273d0d6015a4655cb34cc77cee2c3267a23d7d03 Mon Sep 17 00:00:00 2001
From: 九耳 <mengzhe.cmz@alibaba-inc.com>
Date: 星期五, 03 二月 2023 13:09:05 +0800
Subject: [PATCH] update punc and asr_inference_paraformer_vad_punc
---
funasr/bin/asr_inference_paraformer_vad_punc.py | 107 +++--------------------------------------------------
1 files changed, 7 insertions(+), 100 deletions(-)
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 619e6fd..7a289aa 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -1,9 +1,10 @@
#!/usr/bin/env python3
+
+import json
import argparse
import logging
import sys
import time
-import json
from pathlib import Path
from typing import Optional
from typing import Sequence
@@ -38,10 +39,10 @@
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.utils.timestamp_tools import time_stamp_lfr6
-from funasr.tasks.punctuation import PunctuationTask
+from funasr.bin.punctuation_infer import Text2Punc
from funasr.torch_utils.forward_adaptor import ForwardAdaptor
from funasr.datasets.preprocessor import CommonPreprocessor
-from funasr.punctuation.text_preprocessor import split_words, split_to_mini_sentence
+from funasr.punctuation.text_preprocessor import split_to_mini_sentence
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -235,9 +236,9 @@
predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], predictor_outs[2], predictor_outs[3]
+ pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
- pre_token_length = pre_token_length.round().long()
decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
@@ -481,6 +482,7 @@
punc_infer_config: Optional[str] = None,
punc_model_file: Optional[str] = None,
outputs_dict: Optional[bool] = True,
+ param_dict: dict = None,
**kwargs,
):
assert check_argument_types()
@@ -546,6 +548,7 @@
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
output_dir_v2: Optional[str] = None,
+ param_dict: dict = None,
):
# 3. Build data-iterator
if data_path_and_name_and_type is None and raw_inputs is not None:
@@ -678,102 +681,6 @@
logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor+1e-6)))
return asr_result_list
- return _forward
-
-def Text2Punc(
- train_config: Optional[str],
- model_file: Optional[str],
- device: str = "cpu",
- dtype: str = "float32",
-):
-
- # 2. Build Model
- model, train_args = PunctuationTask.build_model_from_file(
- train_config, model_file, device)
- # Wrape model to make model.nll() data-parallel
- wrapped_model = ForwardAdaptor(model, "inference")
- wrapped_model.to(dtype=getattr(torch, dtype)).to(device=device).eval()
- # logging.info(f"Model:\n{model}")
- punc_list = train_args.punc_list
- period = 0
- for i in range(len(punc_list)):
- if punc_list[i] == ",":
- punc_list[i] = "，"
- elif punc_list[i] == "?":
- punc_list[i] = "？"
- elif punc_list[i] == "。":
- period = i
- preprocessor = CommonPreprocessor(
- train=False,
- token_type="word",
- token_list=train_args.token_list,
- bpemodel=train_args.bpemodel,
- text_cleaner=train_args.cleaner,
- g2p_type=train_args.g2p,
- text_name="text",
- non_linguistic_symbols=train_args.non_linguistic_symbols,
- )
-
- print("start decoding!!!")
-
- def _forward(words, split_size = 20):
- cache_sent = []
- mini_sentences = split_to_mini_sentence(words, split_size)
- new_mini_sentence = ""
- new_mini_sentence_punc = []
- cache_pop_trigger_limit = 200
- for mini_sentence_i in range(len(mini_sentences)):
- mini_sentence = mini_sentences[mini_sentence_i]
- mini_sentence = cache_sent + mini_sentence
- data = {"text": " ".join(mini_sentence)}
- batch = preprocessor(data=data, uid="12938712838719")
- batch["text_lengths"] = torch.from_numpy(np.array([len(batch["text"])], dtype='int32'))
- batch["text"] = torch.from_numpy(batch["text"])
- # Extend one dimension to fake a batch dim.
- batch["text"] = torch.unsqueeze(batch["text"], 0)
- batch = to_device(batch, device)
- y, _ = wrapped_model(**batch)
- _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
- punctuations = indices
- if indices.size()[0] != 1:
- punctuations = torch.squeeze(indices)
- assert punctuations.size()[0] == len(mini_sentence)
-
- # Search for the last Period/QuestionMark as cache
- if mini_sentence_i < len(mini_sentences) - 1:
- sentenceEnd = -1
- last_comma_index = -1
- for i in range(len(punctuations) - 2, 1, -1):
- if punc_list[punctuations[i]] == "。" or punc_list[punctuations[i]] == "？":
- sentenceEnd = i
- break
- if last_comma_index < 0 and punc_list[punctuations[i]] == "，":
- last_comma_index = i
-
- if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
- # The sentence it too long, cut off at a comma.
- sentenceEnd = last_comma_index
- punctuations[sentenceEnd] = period
- cache_sent = mini_sentence[sentenceEnd + 1:]
- mini_sentence = mini_sentence[0:sentenceEnd + 1]
- punctuations = punctuations[0:sentenceEnd + 1]
-
- # if len(punctuations) == 0:
- # continue
-
- punctuations_np = punctuations.cpu().numpy()
- new_mini_sentence_punc += [int(x) for x in punctuations_np]
- words_with_punc = []
- for i in range(len(mini_sentence)):
- if i > 0:
- if len(mini_sentence[i][0].encode()) == 1 and len(mini_sentence[i - 1][0].encode()) == 1:
- mini_sentence[i] = " " + mini_sentence[i]
- words_with_punc.append(mini_sentence[i])
- if punc_list[punctuations[i]] != "_":
- words_with_punc.append(punc_list[punctuations[i]])
- new_mini_sentence += "".join(words_with_punc)
-
- return new_mini_sentence, new_mini_sentence_punc
return _forward
def get_parser():
--
Gitblit v1.9.1