| | |
| | | # -*- encoding: utf-8 -*- |
| | | # @Author: SWHL |
| | | # @Contact: liekkaskono@163.com |
| | | from cgitb import text |
| | | import os.path |
| | | from pathlib import Path |
| | | from typing import List, Union, Tuple |
| | |
| | | def __init__(self, model_dir: Union[str, Path] = None, |
| | | batch_size: int = 1, |
| | | device_id: Union[str, int] = "-1", |
| | | plot_timestamp: bool = False, |
| | | ): |
| | | |
| | | if not Path(model_dir).exists(): |
| | |
| | | ) |
| | | self.ort_infer = OrtInferSession(model_file, device_id) |
| | | self.batch_size = batch_size |
| | | self.plot = True |
| | | self.plot = plot_timestamp |
| | | |
| | | def __call__(self, wav_content: Union[str, np.ndarray, List[str]], **kwargs) -> List: |
| | | waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq) |
| | |
| | | |
def plot_wave_timestamp(self, wav: np.ndarray,
                        text_timestamp: List[Tuple[str, float, float]],
                        sample_rate: int = 16000,
                        plotname: Union[str, Path] = "funasr/runtime/python/onnxruntime/debug.png"):
    """Draw the waveform with per-token timestamp markers and save it as an image.

    The waveform is drawn in gray, scaled to fit the fixed ``[-0.3, 0.3]``
    y-range. For every ``(char, start, end)`` entry a dashed vertical line is
    drawn at ``start`` and at ``end``, and ``char`` is written between them.

    Args:
        wav: Audio samples. Only ``wav.shape[0]`` is used for the time axis,
            so this presumably expects a 1-D array -- confirm against caller.
        text_timestamp: Iterable of ``(char, start, end)`` triples. ``start``
            and ``end`` are placed on the same x-axis as the waveform, i.e.
            ``sample index / sample_rate``.
        sample_rate: Samples per second used to build the time axis.
            Defaults to 16000, the value that was previously hard-coded.
        plotname: Destination image path. Defaults to the previously
            hard-coded debug location; missing parent directories are created.

    Returns:
        None. The figure is written to ``plotname``.
    """
    import matplotlib
    matplotlib.use('Agg')  # non-interactive backend: render to file only
    matplotlib.rc("font", family='Alibaba PuHuiTi')  # set it to a font that your system supports
    import matplotlib.pyplot as plt

    fig, ax1 = plt.subplots(figsize=(11, 3.5), dpi=320)
    try:
        ax2 = ax1.twinx()
        ax2.set_ylim([0, 2.0])

        # plot waveform
        ax1.set_ylim([-0.3, 0.3])
        if wav.size:
            time_axis = np.arange(wav.shape[0]) / sample_rate
            # Normalize by the absolute peak so both polarities fit the
            # y-limits. Computed on Python floats to avoid integer overflow
            # (e.g. abs(int16 min)).
            peak = max(abs(float(wav.max())), abs(float(wav.min())))
            # All-zero (silent) audio is plotted as-is instead of dividing by 0.
            scaled = wav / peak * 0.3 if peak > 0 else wav
            ax1.plot(time_axis, scaled, color='gray', alpha=0.4)

        # plot lines and text
        for char, start, end in text_timestamp:
            ax1.vlines(start, -0.3, 0.3, ls='--')
            ax1.vlines(end, -0.3, 0.3, ls='--')
            # The silence token is wider, so shift its label further left
            # to keep it visually centered between the two lines.
            x_adj = 0.045 if char != '<sil>' else 0.12
            ax1.text((start + end) * 0.5 - x_adj, 0, char)

        out_path = Path(plotname)
        out_path.parent.mkdir(parents=True, exist_ok=True)
        fig.savefig(str(out_path), bbox_inches='tight')
    finally:
        # pyplot keeps a global reference to every figure; release it so
        # repeated calls do not accumulate memory.
        plt.close(fig)
| | | |
| | | def load_data(self, |
| | | wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List: |