import os
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import tempfile
import codecs
import requests
import re
import time
import copy
import torch
import torch.nn as nn
import random
import numpy as np
# from funasr.layers.abs_normalize import AbsNormalize

from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
# from funasr.models.ctc import CTC
# from funasr.models.decoder.abs_decoder import AbsDecoder
# from funasr.models.e2e_asr_common import ErrorCalculator
# from funasr.models.encoder.abs_encoder import AbsEncoder
# from funasr.frontends.abs_frontend import AbsFrontend
# from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.paraformer.cif_predictor import mae_loss
# from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
# from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
# from funasr.models.base_model import FunASRModel
# from funasr.models.paraformer.cif_predictor import CifPredictorV3
from funasr.models.paraformer.search import Hypothesis

from funasr.utils import postprocess_utils  # used by generate() below
from funasr.utils.datadir_writer import DatadirWriter

from funasr.models.paraformer.model import Paraformer
from funasr.models.bicif_paraformer.model import BiCifParaformer
from funasr.register import tables

@tables.register("model_classes", "SeacoParaformer")
class SeacoParaformer(BiCifParaformer, Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    SeACo-Paraformer: A Non-Autoregressive ASR System with Flexible and Effective Hotword Customization Ability
    """

    def _hotword_representation(self, hotword_list, hotword_lengths):
        # (the opening lines of this method are missing from the excerpt;
        # hw_embed is the bias-encoder output over the embedded hotwords and
        # _ind indexes the batch dimension)
        # Pick the bias-encoder state at each hotword's last valid token.
        selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
        return selected

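    # A minimal usage sketch for hotword-biased inference through funasr's
    # AutoModel wrapper; the model id, wav path, and hotword string below are
    # illustrative assumptions, not taken from this file:
    #
    #   from funasr import AutoModel
    #   model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch")
    #   res = model.generate(input="example.wav", hotword="DAMO")
    #   print(res[0]["text"], res[0]["timestamp"])
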
    '''
    def calc_predictor(self, encoder_out, encoder_out_lens):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index, pre_token_length2 = self.predictor(
            encoder_out, None, encoder_out_mask, ignore_id=self.ignore_id
        )
        return pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index

    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device
        )
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(
            encoder_out, encoder_out_mask, token_num
        )
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
    '''
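    # For orientation: make_pad_mask marks padded frames as True, so the
    # negation in the disabled methods above keeps only valid encoder frames.
    # A toy illustration (lengths and maxlen assumed for this example only):
    #
    #   lengths = torch.tensor([2, 3])
    #   make_pad_mask(lengths, maxlen=4)
    #   # tensor([[False, False,  True,  True],
    #   #         [False, False, False,  True]])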
| | | |
    def generate(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,
                 tokenizer=None,
                 **kwargs):
        # ... (frontend, encoding, predictor, and ibest_writer setup are omitted
        # in this excerpt; they produce encoder_out, encoder_out_lens,
        # pre_acoustic_embeds, pre_token_length, and ibest_writer) ...
        decoder_outs = self._seaco_decode_with_ASF(encoder_out,
                                                   encoder_out_lens,
                                                   pre_acoustic_embeds,
                                                   pre_token_length,
                                                   hw_list=self.hotword_list)
        decoder_out = decoder_outs[0]
        # decoder_out, _ = decoder_outs[0], decoder_outs[1]
        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
                                                                  pre_token_length)

        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
            # ... (per-sample scoring and token_int extraction omitted in this
            # excerpt) ...
            if tokenizer is not None:
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)

                _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
                                                           us_peaks[i][:encoder_out_lens[i] * 3],
                                                           copy.copy(token),
                                                           vad_offset=kwargs.get("begin_time", 0))

                text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
                    token, timestamp)

                result_i = {"key": key[i], "text": text_postprocessed,
                            "timestamp": time_stamp_postprocessed,
                            }
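                # For reference, result_i now resembles the following (values
                # are illustrative, not from a real run; timestamps are
                # [start_ms, end_ms] pairs per token):
                #
                #   {"key": "utt1", "text": "hello world",
                #    "timestamp": [[0, 380], [380, 720]]}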

                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    # ibest_writer["text"][key[i]] = text
                    ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
                    ibest_writer["text"][key[i]] = text_postprocessed
            else:
                result_i = {"key": key[i], "token_int": token_int}
            results.append(result_i)
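        # Note (assumption, based on DatadirWriter's espnet-style behaviour):
        # when an output_dir is configured upstream, each ibest_writer[...]
        # assignment emits one "<utt-key> <value>" line into a Kaldi-style
        # file such as "<output_dir>/1best_recog/text".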