| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import os |
| | | import re |
| | | import time |
| | |
| | | import tempfile |
| | | import requests |
| | | import numpy as np |
| | | from typing import Dict |
| | | from typing import List |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import Optional |
| | | from typing import Dict, Tuple |
| | | from contextlib import contextmanager |
| | | from distutils.version import LooseVersion |
| | | |
| | | from funasr.losses.label_smoothing_loss import ( |
| | | LabelSmoothingLoss, # noqa: H301 |
| | | ) |
| | | from funasr.register import tables |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | from funasr.models.paraformer.model import Paraformer |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.models.paraformer.search import Hypothesis |
| | | from funasr.models.paraformer.cif_predictor import mae_loss |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.models.bicif_paraformer.model import BiCifParaformer |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss |
| | | from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard |
| | | from funasr.models.transformer.utils.add_sos_eos import add_sos_eos |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list |
| | | from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.models.paraformer.search import Hypothesis |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | |
# No-op context manager: yields exactly once and performs no setup or teardown.
# The `enabled` argument is accepted but never read.
# NOTE(review): looks like a fallback used when the torch version check above
# does not select a real autocast; the other branch is not visible here - confirm.
| | | @contextmanager |
| | | def autocast(enabled=True): |
| | | yield |
| | | from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | |
| | | from funasr.models.paraformer.model import Paraformer |
| | | from funasr.models.bicif_paraformer.model import BiCifParaformer |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @tables.register("model_classes", "SeacoParaformer") |
| | |
| | | seaco_decoder = kwargs.get("seaco_decoder", None) |
| | | if seaco_decoder is not None: |
| | | seaco_decoder_conf = kwargs.get("seaco_decoder_conf") |
| | | seaco_decoder_class = tables.decoder_classes.get(seaco_decoder.lower()) |
| | | seaco_decoder_class = tables.decoder_classes.get(seaco_decoder) |
| | | self.seaco_decoder = seaco_decoder_class( |
| | | vocab_size=self.vocab_size, |
| | | encoder_output_size=self.inner_dim, |
| | |
| | | return ds_alphas, ds_cif_peak, us_alphas, us_peaks |
| | | ''' |
| | | |
| | | def generate(self, |
| | | def inference(self, |
| | | data_in, |
| | | data_lengths=None, |
| | | key: list = None, |
| | |
| | | |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | audio_sample_list = load_audio_and_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), |
| | |
| | | meta_data[ |
| | | "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 |
| | | |
| | | speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"]) |
| | | speech = speech.to(device=kwargs["device"]) |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | |
| | | # hotword |
| | | self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend) |