| | |
| | | import os |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import time |
| | | import torch |
| | | import logging |
| | | from typing import Dict, Tuple |
| | | from contextlib import contextmanager |
| | | from distutils.version import LooseVersion |
| | | from typing import Dict |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | from typing import Union |
| | | import tempfile |
| | | import codecs |
| | | import requests |
| | | import re |
| | | import copy |
| | | import torch |
| | | import torch.nn as nn |
| | | import random |
| | | import numpy as np |
| | | import time |
| | | # from funasr.layers.abs_normalize import AbsNormalize |
| | | from funasr.losses.label_smoothing_loss import ( |
| | | LabelSmoothingLoss, # noqa: H301 |
| | | ) |
| | | |
| | | from funasr.register import tables |
| | | from funasr.models.ctc.ctc import CTC |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.models.paraformer.model import Paraformer |
| | | from funasr.models.paraformer.search import Hypothesis |
| | | from funasr.models.paraformer.cif_predictor import mae_loss |
| | | |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss |
| | | from funasr.models.transformer.utils.add_sos_eos import add_sos_eos |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | from funasr.models.paraformer.search import Hypothesis |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | |
| | | @contextmanager |
| | | def autocast(enabled=True): |
| | | yield |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard |
| | | |
| | | from funasr.models.ctc.ctc import CTC |
| | | from funasr.models.paraformer.model import Paraformer |
| | | |
| | | from funasr.register import tables |
| | | |
| | | @tables.register("model_classes", "ParaformerStreaming") |
| | | class ParaformerStreaming(Paraformer): |
| | |
| | | decoder_out_1st = None |
| | | pre_loss_att = None |
| | | if self.sampling_ratio > 0.0: |
| | | if self.step_cur < 2: |
| | | logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) |
| | | |
| | | if self.use_1st_decoder_loss: |
| | | sematic_embeds, decoder_out_1st, pre_loss_att = \ |
| | | self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad, |
| | |
| | | self.sampler(encoder_out, encoder_out_lens, ys_pad, |
| | | ys_pad_lens, pre_acoustic_embeds, scama_mask) |
| | | else: |
| | | if self.step_cur < 2: |
| | | logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio)) |
| | | sematic_embeds = pre_acoustic_embeds |
| | | |
| | | # 1. Forward decoder |
| | |
| | | |
| | | return results |
| | | |
| | | def generate(self, |
| | | def inference(self, |
| | | data_in, |
| | | data_lengths=None, |
| | | key: list = None, |
| | |
| | | logging.info("enable beam_search") |
| | | self.init_beam_search(**kwargs) |
| | | self.nbest = kwargs.get("nbest", 1) |
| | | |
| | | |
| | | |
| | | if len(cache) == 0: |
| | | self.init_cache(cache, **kwargs) |
| | | |
| | |
| | | self.init_cache(cache, **kwargs) |
| | | |
| | | if kwargs.get("output_dir"): |
| | | writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = writer[f"{1}best_recog"] |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{1}best_recog"] |
| | | ibest_writer["token"][key[0]] = " ".join(tokens) |
| | | ibest_writer["text"][key[0]] = text_postprocessed |
| | | |
| | | |
| | | return result, meta_data |
| | | |
| | | |