zhifu gao
2024-03-11 15c4709beb4b588db2135fc1133cd6955b5ef819
funasr/models/bicif_paraformer/model.py
@@ -1,37 +1,38 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import logging
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import tempfile
import codecs
import requests
import re
import copy
import torch
import torch.nn as nn
import random
import numpy as np
import time
import torch
import logging
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict, List, Optional, Tuple
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.paraformer.search import Hypothesis
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.model import Paraformer
from funasr.models.paraformer.search import Hypothesis
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.train_utils.device_funcs import to_device
# `torch.cuda.amp.autocast` is only imported when the installed torch reports
# version 1.6.0 or newer. On older versions a do-nothing context manager with
# the same name stands in, so `with autocast(...):` works either way.
if LooseVersion(torch.__version__) < LooseVersion("1.6.0"):

    @contextmanager
    def autocast(enabled=True):
        """No-op stand-in: accepts ``enabled`` and ignores it."""
        yield

else:
    from torch.cuda.amp import autocast
@tables.register("model_classes", "BiCifParaformer")
class BiCifParaformer(Paraformer):
@@ -216,7 +217,7 @@
        return loss, stats, weight
    def generate(self,
    def inference(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,
@@ -234,25 +235,26 @@
            self.nbest = kwargs.get("nbest", 1)
        
        meta_data = {}
        if isinstance(data_in, torch.Tensor):  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            audio_sample_list = load_audio_and_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        # if isinstance(data_in, torch.Tensor):  # fbank
        #     speech, speech_lengths = data_in, data_lengths
        #     if len(speech.shape) < 3:
        #         speech = speech[None, :, :]
        #     if speech_lengths is None:
        #         speech_lengths = speech.shape[1]
        # else:
        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                frontend=frontend)
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        
        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
@@ -298,9 +300,11 @@
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if ibest_writer is None and kwargs.get("output_dir") is not None:
                    writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx+1}best_recog"]
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
@@ -337,4 +341,88 @@
                    result_i = {"key": key[i], "token_int": token_int}
                results.append(result_i)
        
        return results, meta_data
        return results, meta_data
    def export(
        self,
        max_seq_len=512,
        **kwargs,
    ):
        self.device = kwargs.get("device")
        is_onnx = kwargs.get("type", "onnx") == "onnx"
        encoder_class = tables.encoder_classes.get(kwargs["encoder"] + "Export")
        self.encoder = encoder_class(self.encoder, onnx=is_onnx)
        predictor_class = tables.predictor_classes.get(kwargs["predictor"] + "Export")
        self.predictor = predictor_class(self.predictor, onnx=is_onnx)
        decoder_class = tables.decoder_classes.get(kwargs["decoder"] + "Export")
        self.decoder = decoder_class(self.decoder, onnx=is_onnx)
        from funasr.utils.torch_function import sequence_mask
        self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        self.forward = self.export_forward
        return self
    def export_forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        batch = to_device(batch, device=self.device)
        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.round().type(torch.int32)
        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # get predicted timestamps
        us_alphas, us_cif_peak = self.predictor.get_upsample_timestmap(enc, mask, pre_token_length)
        return decoder_out, pre_token_length, us_alphas, us_cif_peak
    def export_dummy_inputs(self):
        speech = torch.randn(2, 30, 560)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        return (speech, speech_lengths)
    def export_input_names(self):
        return ['speech', 'speech_lengths']
    def export_output_names(self):
        return ['logits', 'token_num', 'us_alphas', 'us_cif_peak']
    def export_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
            'us_alphas': {
                0: 'batch_size',
                1: 'alphas_length'
            },
            'us_cif_peak': {
                0: 'batch_size',
                1: 'alphas_length'
            },
        }
    def export_name(self, ):
        return "model.onnx"