雾聪
2024-03-29 9ba0dbd98bf69c830dfcfde8f109a400cb65e4e5
funasr/models/paraformer_streaming/model.py
@@ -1,35 +1,29 @@
import os
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import time
import torch
import logging
from typing import Dict, Tuple
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import tempfile
import codecs
import requests
import re
import copy
import torch
import torch.nn as nn
import random
import numpy as np
import time
# from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.model import Paraformer
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.models.paraformer.search import Hypothesis
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
@@ -38,15 +32,7 @@
    @contextmanager
    def autocast(enabled=True):
        yield
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.models.ctc.ctc import CTC
from funasr.models.paraformer.model import Paraformer
from funasr.register import tables
@tables.register("model_classes", "ParaformerStreaming")
class ParaformerStreaming(Paraformer):
@@ -249,8 +235,7 @@
        decoder_out_1st = None
        pre_loss_att = None
        if self.sampling_ratio > 0.0:
            if self.step_cur < 2:
                logging.info("enable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            if self.use_1st_decoder_loss:
                sematic_embeds, decoder_out_1st, pre_loss_att = \
                    self.sampler_with_grad(encoder_out, encoder_out_lens, ys_pad,
@@ -260,8 +245,6 @@
                    self.sampler(encoder_out, encoder_out_lens, ys_pad,
                                 ys_pad_lens, pre_acoustic_embeds, scama_mask)
        else:
            if self.step_cur < 2:
                logging.info("disable sampler in paraformer, sampling_ratio: {}".format(self.sampling_ratio))
            sematic_embeds = pre_acoustic_embeds
        
        # 1. Forward decoder
@@ -499,7 +482,7 @@
        
        return results
    
    def generate(self,
    def inference(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,
@@ -516,8 +499,7 @@
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        if len(cache) == 0:
            self.init_cache(cache, **kwargs)
        
@@ -571,11 +553,137 @@
            self.init_cache(cache, **kwargs)
        
        if kwargs.get("output_dir"):
            writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = writer[f"{1}best_recog"]
            if not hasattr(self, "writer"):
                self.writer = DatadirWriter(kwargs.get("output_dir"))
            ibest_writer = self.writer[f"{1}best_recog"]
            ibest_writer["token"][key[0]] = " ".join(tokens)
            ibest_writer["text"][key[0]] = text_postprocessed
        return result, meta_data
    def export(
        self,
        max_seq_len=512,
        **kwargs,
    ):
        self.device = kwargs.get("device")
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
        predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
        if kwargs["decoder"] == "ParaformerSANMDecoder":
            kwargs["decoder"] = "ParaformerSANMDecoderOnline"
        decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
        from funasr.utils.torch_function import sequence_mask
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        import copy
        import types
        encoder_model = copy.copy(self)
        decoder_model = copy.copy(self)
        # encoder
        encoder_model.forward = types.MethodType(ParaformerStreaming.export_encoder_forward, encoder_model)
        encoder_model.export_dummy_inputs = types.MethodType(ParaformerStreaming.export_encoder_dummy_inputs, encoder_model)
        encoder_model.export_input_names = types.MethodType(ParaformerStreaming.export_encoder_input_names, encoder_model)
        encoder_model.export_output_names = types.MethodType(ParaformerStreaming.export_encoder_output_names, encoder_model)
        encoder_model.export_dynamic_axes = types.MethodType(ParaformerStreaming.export_encoder_dynamic_axes, encoder_model)
        encoder_model.export_name = types.MethodType(ParaformerStreaming.export_encoder_name, encoder_model)
        # decoder
        decoder_model.forward = types.MethodType(ParaformerStreaming.export_decoder_forward, decoder_model)
        decoder_model.export_dummy_inputs = types.MethodType(ParaformerStreaming.export_decoder_dummy_inputs, decoder_model)
        decoder_model.export_input_names = types.MethodType(ParaformerStreaming.export_decoder_input_names, decoder_model)
        decoder_model.export_output_names = types.MethodType(ParaformerStreaming.export_decoder_output_names, decoder_model)
        decoder_model.export_dynamic_axes = types.MethodType(ParaformerStreaming.export_decoder_dynamic_axes, decoder_model)
        decoder_model.export_name = types.MethodType(ParaformerStreaming.export_decoder_name, decoder_model)
        return encoder_model, decoder_model
    def export_encoder_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths, "online": True}
        # batch = to_device(batch, device=self.device)
        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        alphas, _ = self.predictor.forward_cnn(enc, mask)
        return enc, enc_len, alphas
    def export_encoder_dummy_inputs(self):
        speech = torch.randn(2, 30, 560)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)
    def export_encoder_input_names(self):
        return ['speech', 'speech_lengths']
    def export_encoder_output_names(self):
        return ['enc', 'enc_len', 'alphas']
    def export_encoder_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'enc': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'enc_len': {
                0: 'batch_size',
            },
            'alphas': {
                0: 'batch_size',
                1: 'feats_length'
            },
        }
    def export_encoder_name(self):
        return "model.onnx"
    def export_decoder_forward(
        self,
        enc: torch.Tensor,
        enc_len: torch.Tensor,
        acoustic_embeds: torch.Tensor,
        acoustic_embeds_len: torch.Tensor,
        *args,
    ):
        decoder_out, out_caches = self.decoder(enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args)
        sample_ids = decoder_out.argmax(dim=-1)
        return decoder_out, sample_ids, out_caches
    def export_decoder_dummy_inputs(self):
        dummy_inputs = self.decoder.get_dummy_inputs(enc_size=self.encoder._output_size)
        return dummy_inputs
    def export_decoder_input_names(self):
        return self.decoder.get_input_names()
    def export_decoder_output_names(self):
        return self.decoder.get_output_names()
    def export_decoder_dynamic_axes(self):
        return self.decoder.get_dynamic_axes()
    def export_decoder_name(self):
        return "decoder.onnx"