#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
# MIT License (https://opensource.org/licenses/MIT)

import os
import re
import copy
import time
import codecs
import random
import logging
import tempfile
import requests
import numpy as np
import torch
import torch.nn as nn
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict, List, Optional, Tuple, Union

# from funasr.layers.abs_normalize import AbsNormalize
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.model import Paraformer
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard

if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
    # Fallback for torch < 1.6.0: a no-op context manager with the same name.
    @contextmanager
    def autocast(enabled=True):
        yield

@tables.register("model_classes", "ParaformerStreaming")
class ParaformerStreaming(Paraformer):
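    """
    Streaming variant of Paraformer (https://arxiv.org/abs/2206.08317): the CIF predictor
    and chunk-wise (SCAMA-style) attention allow decoding audio chunk by chunk, with
    frontend/encoder/decoder streaming state carried across calls in an external cache dict.
    """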

        # 0. sampler: optionally mix target token embeddings into the CIF acoustic
        # embeddings (glancing-style sampling) before the decoder pass
        decoder_out_1st = None
        pre_loss_att = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))

            if self.use_1st_decoder_loss:
                sematic_embeds, decoder_out_1st, pre_loss_att = \
                    self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
                                           ys_pad_lens, pre_acoustic_embeds, scama_mask)
            else:
                sematic_embeds, decoder_out_1st = \
                    self.sampler(encoder_out, encoder_out_lens, ys_pad,
                                 ys_pad_lens, pre_acoustic_embeds, scama_mask)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds

        # 1. Forward decoder


        return results

    def inference(self,
                  data_in,
                  data_lengths=None,
                  key: list = None,
                  tokenizer=None,
                  frontend=None,
                  cache: dict = {},
                  **kwargs,
                  ):

        logging.info("enable beam_search")
        self.init_beam_search(**kwargs)
        self.nbest = kwargs.get("nbest", 1)

        if len(cache) == 0:
            self.init_cache(cache, **kwargs)

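        # The waveform is consumed in n chunks of chunk_stride_samples samples each;
        # streaming state (frontend and encoder feats) persists across chunks via `cache`.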
        for i in range(n):
            kwargs["is_final"] = _is_final and i == n - 1
            audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]

            if kwargs["is_final"] and len(audio_sample_i) < 960:
                # final chunk is too short (< 960 samples) to produce a new frame: reuse cached feats
                cache["encoder"]["tail_chunk"] = True
                speech = cache["encoder"]["feats"]
                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.int64).to(speech.device)
            else:
                # extract fbank feats
                speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
                                                       frontend=frontend, cache=cache["frontend"],
                                                       is_final=kwargs["is_final"])
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000

        self.init_cache(cache, **kwargs)

        if kwargs.get("output_dir"):
            # keep one DatadirWriter per model instance so repeated streaming calls
            # append to the same output files
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = self.writer[f"{1}best_recog"]
            ibest_writer["token"][key[0]] = " ".join(tokens)
            ibest_writer["text"][key[0]] = text_postprocessed

        return result, meta_data


    def export(self, **kwargs):
        from .export_meta import export_rebuild_model

        models = export_rebuild_model(model=self, **kwargs)
        return models
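

# Usage sketch (not part of this module): driving the streaming model through FunASR's
# AutoModel API, as documented in the FunASR README. The model name, chunk settings and
# the bundled example wav are assumptions for illustration.
if __name__ == "__main__":
    import soundfile
    from funasr import AutoModel

    chunk_size = [0, 10, 5]          # [0, 10, 5] -> 600 ms chunks ([0, 8, 4] -> 480 ms)
    encoder_chunk_look_back = 4      # encoder self-attention looks back this many chunks
    decoder_chunk_look_back = 1      # decoder cross-attention looks back this many encoder chunks

    model = AutoModel(model="paraformer-zh-streaming")
    wav_file = os.path.join(model.model_path, "example/asr_example.wav")
    speech, sample_rate = soundfile.read(wav_file)

    chunk_stride = chunk_size[1] * 960   # samples per 600 ms chunk at 16 kHz
    cache = {}                           # streaming state shared across calls
    total_chunk_num = int((len(speech) - 1) / chunk_stride) + 1
    for i in range(total_chunk_num):
        speech_chunk = speech[i * chunk_stride:(i + 1) * chunk_stride]
        is_final = i == total_chunk_num - 1
        res = model.generate(input=speech_chunk, cache=cache, is_final=is_final,
                             chunk_size=chunk_size,
                             encoder_chunk_look_back=encoder_chunk_look_back,
                             decoder_chunk_look_back=decoder_chunk_look_back)
        print(res)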