From b75d1e89bb2f513a79bb07e9100ba1cd2bbcf40c Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期日, 09 六月 2024 00:32:57 +0800
Subject: [PATCH] fix bug
---
runtime/python/libtorch/funasr_torch/paraformer_bin.py | 120 ++++++++++++++++++++++++++++++++----------------------------
1 files changed, 64 insertions(+), 56 deletions(-)
diff --git a/runtime/python/libtorch/funasr_torch/paraformer_bin.py b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
index 9954daa..68886df 100644
--- a/runtime/python/libtorch/funasr_torch/paraformer_bin.py
+++ b/runtime/python/libtorch/funasr_torch/paraformer_bin.py
@@ -7,48 +7,46 @@
import librosa
import numpy as np
-from .utils.utils import (CharTokenizer, Hypothesis,
- TokenIDConverter, get_logger,
- read_yaml)
+from .utils.utils import CharTokenizer, Hypothesis, TokenIDConverter, get_logger, read_yaml
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
+
logging = get_logger()
import torch
-class Paraformer():
- def __init__(self, model_dir: Union[str, Path] = None,
- batch_size: int = 1,
- device_id: Union[str, int] = "-1",
- plot_timestamp_to: str = "",
- quantize: bool = False,
- intra_op_num_threads: int = 1,
- ):
+class Paraformer:
+ def __init__(
+ self,
+ model_dir: Union[str, Path] = None,
+ batch_size: int = 1,
+ device_id: Union[str, int] = "-1",
+ plot_timestamp_to: str = "",
+ quantize: bool = False,
+ intra_op_num_threads: int = 1,
+ ):
if not Path(model_dir).exists():
- raise FileNotFoundError(f'{model_dir} does not exist.')
+ raise FileNotFoundError(f"{model_dir} does not exist.")
- model_file = os.path.join(model_dir, 'model.torchscripts')
+ model_file = os.path.join(model_dir, "model.torchscripts")
if quantize:
- model_file = os.path.join(model_dir, 'model_quant.torchscripts')
- config_file = os.path.join(model_dir, 'config.yaml')
- cmvn_file = os.path.join(model_dir, 'am.mvn')
+ model_file = os.path.join(model_dir, "model_quant.torchscripts")
+ config_file = os.path.join(model_dir, "config.yaml")
+ cmvn_file = os.path.join(model_dir, "am.mvn")
config = read_yaml(config_file)
- self.converter = TokenIDConverter(config['token_list'])
+ self.converter = TokenIDConverter(config["token_list"])
self.tokenizer = CharTokenizer()
- self.frontend = WavFrontend(
- cmvn_file=cmvn_file,
- **config['frontend_conf']
- )
+ self.frontend = WavFrontend(cmvn_file=cmvn_file, **config["frontend_conf"])
self.ort_infer = torch.jit.load(model_file)
self.batch_size = batch_size
self.device_id = device_id
self.plot_timestamp_to = plot_timestamp_to
- if "predictor_bias" in config['model_conf'].keys():
- self.pred_bias = config['model_conf']['predictor_bias']
+ if "predictor_bias" in config["model_conf"].keys():
+ self.pred_bias = config["model_conf"]["predictor_bias"]
else:
self.pred_bias = 0
@@ -57,7 +55,7 @@
waveform_nums = len(waveform_list)
asr_res = []
for beg_idx in range(0, waveform_nums, self.batch_size):
-
+
end_idx = min(waveform_nums, beg_idx + self.batch_size)
feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
try:
@@ -74,51 +72,66 @@
else:
us_alphas, us_peaks = None, None
except:
- #logging.warning(traceback.format_exc())
+ # logging.warning(traceback.format_exc())
logging.warning("input wav is silence or noise")
- preds = ['']
+ preds = [""]
else:
preds = self.decode(am_scores, valid_token_lens)
if us_peaks is None:
for pred in preds:
pred = sentence_postprocess(pred)
- asr_res.append({'preds': pred})
+ asr_res.append({"preds": pred})
else:
for pred, us_peaks_ in zip(preds, us_peaks):
raw_tokens = pred
- timestamp, timestamp_raw = time_stamp_lfr6_onnx(us_peaks_, copy.copy(raw_tokens))
- text_proc, timestamp_proc, _ = sentence_postprocess(raw_tokens, timestamp_raw)
+ timestamp, timestamp_raw = time_stamp_lfr6_onnx(
+ us_peaks_, copy.copy(raw_tokens)
+ )
+ text_proc, timestamp_proc, _ = sentence_postprocess(
+ raw_tokens, timestamp_raw
+ )
# logging.warning(timestamp)
if len(self.plot_timestamp_to):
- self.plot_wave_timestamp(waveform_list[0], timestamp, self.plot_timestamp_to)
- asr_res.append({'preds': text_proc, 'timestamp': timestamp_proc, "raw_tokens": raw_tokens})
+ self.plot_wave_timestamp(
+ waveform_list[0], timestamp, self.plot_timestamp_to
+ )
+ asr_res.append(
+ {
+ "preds": text_proc,
+ "timestamp": timestamp_proc,
+ "raw_tokens": raw_tokens,
+ }
+ )
return asr_res
def plot_wave_timestamp(self, wav, text_timestamp, dest):
# TODO: Plot the wav and timestamp results with matplotlib
import matplotlib
- matplotlib.use('Agg')
- matplotlib.rc("font", family='Alibaba PuHuiTi') # set it to a font that your system supports
+
+ matplotlib.use("Agg")
+ matplotlib.rc(
+ "font", family="Alibaba PuHuiTi"
+ ) # set it to a font that your system supports
import matplotlib.pyplot as plt
+
fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320)
ax2 = ax1.twinx()
ax2.set_ylim([0, 2.0])
# plot waveform
ax1.set_ylim([-0.3, 0.3])
time = np.arange(wav.shape[0]) / 16000
- ax1.plot(time, wav/wav.max()*0.3, color='gray', alpha=0.4)
+ ax1.plot(time, wav / wav.max() * 0.3, color="gray", alpha=0.4)
# plot lines and text
- for (char, start, end) in text_timestamp:
- ax1.vlines(start, -0.3, 0.3, ls='--')
- ax1.vlines(end, -0.3, 0.3, ls='--')
- x_adj = 0.045 if char != '<sil>' else 0.12
+ for char, start, end in text_timestamp:
+ ax1.vlines(start, -0.3, 0.3, ls="--")
+ ax1.vlines(end, -0.3, 0.3, ls="--")
+ x_adj = 0.045 if char != "<sil>" else 0.12
ax1.text((start + end) * 0.5 - x_adj, 0, char)
# plt.legend()
plotname = "{}/timestamp.png".format(dest)
- plt.savefig(plotname, bbox_inches='tight')
+ plt.savefig(plotname, bbox_inches="tight")
- def load_data(self,
- wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
+ def load_data(self, wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
def load_wav(path: str) -> np.ndarray:
waveform, _ = librosa.load(path, sr=fs)
return waveform
@@ -132,12 +145,9 @@
if isinstance(wav_content, list):
return [load_wav(path) for path in wav_content]
- raise TypeError(
- f'The type of {wav_content} is not in [str, np.ndarray, list]')
+ raise TypeError(f"The type of {wav_content} is not in [str, np.ndarray, list]")
- def extract_feat(self,
- waveform_list: List[np.ndarray]
- ) -> Tuple[np.ndarray, np.ndarray]:
+ def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
feats, feats_len = [], []
for waveform in waveform_list:
speech, _ = self.frontend.fbank(waveform)
@@ -155,24 +165,23 @@
def pad_feats(feats: List[np.ndarray], max_feat_len: int) -> np.ndarray:
def pad_feat(feat: np.ndarray, cur_len: int) -> np.ndarray:
pad_width = ((0, max_feat_len - cur_len), (0, 0))
- return np.pad(feat, pad_width, 'constant', constant_values=0)
+ return np.pad(feat, pad_width, "constant", constant_values=0)
feat_res = [pad_feat(feat, feat.shape[0]) for feat in feats]
feats = np.array(feat_res).astype(np.float32)
return feats
- def infer(self, feats: np.ndarray,
- feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+ def infer(self, feats: np.ndarray, feats_len: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
outputs = self.ort_infer([feats, feats_len])
return outputs
def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
- return [self.decode_one(am_score, token_num)
- for am_score, token_num in zip(am_scores, token_nums)]
+ return [
+ self.decode_one(am_score, token_num)
+ for am_score, token_num in zip(am_scores, token_nums)
+ ]
- def decode_one(self,
- am_score: np.ndarray,
- valid_token_num: int) -> List[str]:
+ def decode_one(self, am_score: np.ndarray, valid_token_num: int) -> List[str]:
yseq = am_score.argmax(axis=-1)
score = am_score.max(axis=-1)
score = np.sum(score, axis=-1)
@@ -191,7 +200,6 @@
# Change integer-ids to tokens
token = self.converter.ids2tokens(token_int)
- token = token[:valid_token_num-self.pred_bias]
+ token = token[: valid_token_num - self.pred_bias]
# texts = sentence_postprocess(token)
return token
-
--
Gitblit v1.9.1