语帆
2024-02-29 96eaabca5b2e9c93b40c9840e2ae0003a618bb6e
funasr/models/seaco_paraformer/model.py
@@ -1,3 +1,8 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import os
import re
import time
@@ -8,26 +13,26 @@
import tempfile
import requests
import numpy as np
from typing import Dict
from typing import List
from typing import Tuple
from typing import Union
from typing import Optional
from typing import Dict, Tuple
from contextlib import contextmanager
from distutils.version import LooseVersion
from funasr.losses.label_smoothing_loss import (
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.register import tables
from funasr.utils import postprocess_utils
from funasr.metrics.compute_acc import th_accuracy
from funasr.models.paraformer.model import Paraformer
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.search import Hypothesis
from funasr.models.paraformer.cif_predictor import mae_loss
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.bicif_paraformer.model import BiCifParaformer
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.metrics.compute_acc import th_accuracy
from funasr.train_utils.device_funcs import force_gatherable
from funasr.models.paraformer.search import Hypothesis
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
import pdb
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
@@ -35,13 +40,6 @@
    @contextmanager
    def autocast(enabled=True):
        """No-op stand-in for ``torch.cuda.amp.autocast``.

        NOTE(review): this appears to sit in the ``else`` branch of the
        torch version check (``>= 1.6.0``) above, i.e. it is presumably
        only defined when the real ``autocast`` is not imported -- confirm
        against the full file.

        Args:
            enabled: Accepted so call sites can pass the same argument as
                the real ``autocast``; it is not read by this function.

        Yields:
            None. The body of the ``with`` block runs unchanged.
        """
        # Nothing to set up or tear down: yield once and return.
        yield
from funasr.utils.load_utils import load_audio_and_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.models.paraformer.model import Paraformer
from funasr.models.bicif_paraformer.model import BiCifParaformer
from funasr.register import tables
@tables.register("model_classes", "SeacoParaformer")
@@ -68,7 +66,6 @@
  
        # bias encoder
        if self.bias_encoder_type == 'lstm':
            logging.warning("enable bias encoder sampling and contextual training")
            self.bias_encoder = torch.nn.LSTM(self.inner_dim, 
                                              self.inner_dim, 
                                              2, 
@@ -81,7 +78,6 @@
                self.lstm_proj = None
            self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim)
        elif self.bias_encoder_type == 'mean':
            logging.warning("enable bias encoder sampling and contextual training")
            self.bias_embed = torch.nn.Embedding(self.vocab_size, self.inner_dim)
        else:
            logging.error("Unsupport bias encoder type: {}".format(self.bias_encoder_type))
@@ -90,7 +86,7 @@
        seaco_decoder = kwargs.get("seaco_decoder", None)
        if seaco_decoder is not None:
            seaco_decoder_conf = kwargs.get("seaco_decoder_conf")
            seaco_decoder_class = tables.decoder_classes.get(seaco_decoder.lower())
            seaco_decoder_class = tables.decoder_classes.get(seaco_decoder)
            self.seaco_decoder = seaco_decoder_class(
                vocab_size=self.vocab_size,
                encoder_output_size=self.inner_dim,
@@ -134,7 +130,7 @@
        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")
        batch_size = speech.shape[0]
        self.step_cur += 1
        # for data-parallel
@@ -214,26 +210,31 @@
                               ys_pad_lens, 
                               hw_list,
                               nfilter=50,
                                 seaco_weight=1.0):
                               seaco_weight=1.0):
        # decoder forward
        decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
        decoder_pred = torch.log_softmax(decoder_out, dim=-1)
        if hw_list is not None:
            hw_lengths = [len(i) for i in hw_list]
            hw_list_ = [torch.Tensor(i).long() for i in hw_list]
            hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device)
            selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device))
            contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
            num_hot_word = contextual_info.shape[1]
            _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
            # ASF Core
            if nfilter > 0 and nfilter < num_hot_word:
                for dec in self.seaco_decoder.decoders:
                    dec.reserve_attn = True
                # cif_attended, _ = self.decoder2(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
                dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
                # cif_filter = torch.topk(self.decoder2.decoders[-1].attn_mat[0][0].sum(0).sum(0)[:-1], min(nfilter, num_hot_word-1))[1].tolist()
                hotword_scores = self.seaco_decoder.decoders[-1].attn_mat[0][0].sum(0).sum(0)[:-1]
                # hotword_scores /= torch.sqrt(torch.tensor(hw_lengths)[:-1].float()).to(hotword_scores.device)
                dec_filter = torch.topk(hotword_scores, min(nfilter, num_hot_word-1))[1].tolist()
@@ -248,18 +249,16 @@
                for dec in self.seaco_decoder.decoders:
                    dec.attn_mat = []
                    dec.reserve_attn = False
            # SeACo Core
            cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
            dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
            merged = self._merge(cif_attended, dec_attended)
            dha_output = self.hotword_output_layer(merged)  # remove the last token in loss calculation
            dha_pred = torch.log_softmax(dha_output, dim=-1)
            # import pdb; pdb.set_trace()
            def _merge_res(dec_output, dha_output):
                lmbd = torch.Tensor([seaco_weight] * dha_output.shape[0])
                dha_ids = dha_output.max(-1)[-1][0]
                dha_ids = dha_output.max(-1)[-1]# [0]
                dha_mask = (dha_ids == 8377).int().unsqueeze(-1)
                a = (1 - lmbd) / lmbd
                b = 1 / lmbd
@@ -268,6 +267,7 @@
                # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask)
                logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
                return logits
            merged_pred = _merge_res(decoder_pred, dha_pred)
            return merged_pred
        else:
@@ -306,7 +306,7 @@
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
    '''        
  
    def generate(self,
    def inference(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,
@@ -322,12 +322,11 @@
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        meta_data = {}
        
        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_and_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
        audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
@@ -337,8 +336,9 @@
        meta_data[
            "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        
        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # hotword
        self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
        
@@ -347,6 +347,7 @@
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        
        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
@@ -355,15 +356,14 @@
        if torch.max(pre_token_length) < 1:
            return []
        decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
                                                   pre_acoustic_embeds,
                                                   pre_token_length,
                                                   hw_list=self.hotword_list)
        # decoder_out, _ = decoder_outs[0], decoder_outs[1]
        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
                                                                  pre_token_length)
        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
@@ -388,9 +388,11 @@
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if ibest_writer is None and kwargs.get("output_dir") is not None:
                    writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
@@ -416,12 +418,11 @@
                        token, timestamp)
                    result_i = {"key": key[i], "text": text_postprocessed,
                                "timestamp": time_stamp_postprocessed,
                                "timestamp": time_stamp_postprocessed
                                }
                    
                    if ibest_writer is not None:
                        ibest_writer["token"][key[i]] = " ".join(token)
                        # ibest_writer["text"][key[i]] = text
                        ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
                        ibest_writer["text"][key[i]] = text_postprocessed
                else: