| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import os |
| | | import re |
| | | import time |
| | |
| | | import tempfile |
| | | import requests |
| | | import numpy as np |
| | | from typing import Dict |
| | | from typing import List |
| | | from typing import Tuple |
| | | from typing import Union |
| | | from typing import Optional |
| | | from typing import Dict, Tuple |
| | | from contextlib import contextmanager |
| | | from distutils.version import LooseVersion |
| | | |
| | | from funasr.losses.label_smoothing_loss import ( |
| | | LabelSmoothingLoss, # noqa: H301 |
| | | ) |
| | | from funasr.register import tables |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | from funasr.models.paraformer.model import Paraformer |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.models.paraformer.search import Hypothesis |
| | | from funasr.models.paraformer.cif_predictor import mae_loss |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.models.bicif_paraformer.model import BiCifParaformer |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss |
| | | from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard |
| | | from funasr.models.transformer.utils.add_sos_eos import add_sos_eos |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list |
| | | from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.models.paraformer.search import Hypothesis |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | |
| | | @contextmanager |
| | | def autocast(enabled=True): |
| | | yield |
| | | from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | |
| | | from funasr.models.paraformer.model import Paraformer |
| | | from funasr.models.bicif_paraformer.model import BiCifParaformer |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @tables.register("model_classes", "SeacoParaformer") |
| | |
| | | |
| | | # bias encoder |
| | | if self.bias_encoder_type == 'lstm': |
| | | logging.warning("enable bias encoder sampling and contextual training") |
| | | self.bias_encoder = torch.nn.LSTM(self.inner_dim, |
| | | self.inner_dim, |
| | | 2, |
| | |
| | | self.lstm_proj = None |
| | | self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim) |
| | | elif self.bias_encoder_type == 'mean': |
| | | logging.warning("enable bias encoder sampling and contextual training") |
| | | self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim) |
| | | else: |
| | | logging.error("Unsupport bias encoder type: {}".format(self.bias_encoder_type)) |
| | |
| | | seaco_decoder = kwargs.get("seaco_decoder", None) |
| | | if seaco_decoder is not None: |
| | | seaco_decoder_conf = kwargs.get("seaco_decoder_conf") |
| | | seaco_decoder_class = tables.decoder_classes.get(seaco_decoder.lower()) |
| | | seaco_decoder_class = tables.decoder_classes.get(seaco_decoder) |
| | | self.seaco_decoder = seaco_decoder_class( |
| | | vocab_size=self.vocab_size, |
| | | encoder_output_size=self.inner_dim, |
| | |
| | | ys_pad_lens, |
| | | hw_list, |
| | | nfilter=50, |
| | | seaco_weight=1.0): |
| | | seaco_weight=1.0): |
| | | # decoder forward |
| | | decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True) |
| | | decoder_pred = torch.log_softmax(decoder_out, dim=-1) |
| | |
| | | |
| | | dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation |
| | | dha_pred = torch.log_softmax(dha_output, dim=-1) |
| | | # import pdb; pdb.set_trace() |
| | | def _merge_res(dec_output, dha_output): |
| | | lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0]) |
| | | dha_ids = dha_output.max(-1)[-1][0] |
| | | dha_ids = dha_output.max(-1)[-1]# [0] |
| | | dha_mask = (dha_ids == 8377).int().unsqueeze(-1) |
| | | a = (1 - lmbd) / lmbd |
| | | b = 1 / lmbd |
| | |
| | | logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask) |
| | | return logits |
| | | merged_pred = _merge_res(decoder_pred, dha_pred) |
| | | # import pdb; pdb.set_trace() |
| | | return merged_pred |
| | | else: |
| | | return decoder_pred |
| | |
| | | return ds_alphas, ds_cif_peak, us_alphas, us_peaks |
| | | ''' |
| | | |
| | | def generate(self, |
| | | def inference(self, |
| | | data_in, |
| | | data_lengths=None, |
| | | key: list = None, |
| | |
| | | |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | audio_sample_list = load_audio_and_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), |
| | |
| | | meta_data[ |
| | | "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 |
| | | |
| | | speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"]) |
| | | speech = speech.to(device=kwargs["device"]) |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | |
| | | # hotword |
| | | self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend) |
| | |
| | | token, timestamp) |
| | | |
| | | result_i = {"key": key[i], "text": text_postprocessed, |
| | | "timestamp": time_stamp_postprocessed, |
| | | "timestamp": time_stamp_postprocessed, "raw_text": copy.copy(text_postprocessed) |
| | | } |
| | | |
| | | if ibest_writer is not None: |
| | | ibest_writer["token"][key[i]] = " ".join(token) |
| | | # ibest_writer["text"][key[i]] = text |
| | | # ibest_writer["raw_text"][key[i]] = text |
| | | ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed |
| | | ibest_writer["text"][key[i]] = text_postprocessed |
| | | else: |