游雁
2023-12-21 5a8f37908469d9550f905ba0876c7c4e6f9b8026
funasr/models/bici_paraformer/model.py
@@ -29,6 +29,7 @@
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
+from funasr.utils.timestamp_tools import time_stamp_sentence
from funasr.models.paraformer.model import Paraformer
@@ -211,10 +212,11 @@
      
      loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
      return loss, stats, weight
   def generate(self,
-                data_in: list,
-                data_lengths: list = None,
+                data_in,
+                data_lengths=None,
                key: list = None,
                tokenizer=None,
                frontend=None,
@@ -230,17 +232,23 @@
         self.nbest = kwargs.get("nbest", 1)
      
      meta_data = {}
-      # extract fbank feats
-      time1 = time.perf_counter()
-      audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
-      time2 = time.perf_counter()
-      meta_data["load_data"] = f"{time2 - time1:0.3f}"
-      speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
-                                             frontend=self.frontend)
-      time3 = time.perf_counter()
-      meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-      meta_data[
-         "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+      if isinstance(data_in, torch.Tensor):  # fbank
+         speech, speech_lengths = data_in, data_lengths
+         if len(speech.shape) < 3:
+            speech = speech[None, :, :]
+         if speech_lengths is None:
+            speech_lengths = speech.shape[1]
+      else:
+         # extract fbank feats
+         time1 = time.perf_counter()
+         audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+         time2 = time.perf_counter()
+         meta_data["load_data"] = f"{time2 - time1:0.3f}"
+         speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                frontend=frontend)
+         time3 = time.perf_counter()
+         meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+         meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
      
      # move inputs to the target device (assign the results; tensor.to() is not in-place)
      speech = speech.to(device=kwargs["device"])
      if isinstance(speech_lengths, torch.Tensor):
         speech_lengths = speech_lengths.to(device=kwargs["device"])
      
@@ -261,9 +269,8 @@
      decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
      
      # BiCifParaformer, test no bias cif2
      _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
-                                                                             pre_token_length)
+                                                                pre_token_length)
      
      results = []
      b, n, d = decoder_out.size()
@@ -302,27 +309,32 @@
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
            
-            # Change integer-ids to tokens
-            token = tokenizer.ids2tokens(token_int)
-            text = tokenizer.tokens2text(token)
-            _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
-                                                       us_peaks[i][:encoder_out_lens[i] * 3],
-                                                       copy.copy(token),
-                                                       vad_offset=kwargs.get("begin_time", 0))
-            text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token, timestamp)
-            result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
-                        "time_stamp_postprocessed": time_stamp_postprocessed,
-                        "word_lists": word_lists
-                        }
-            results.append(result_i)
-            if ibest_writer is not None:
-               ibest_writer["token"][key[i]] = " ".join(token)
-               ibest_writer["text"][key[i]] = text
-               ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+            if tokenizer is not None:
+               # Change integer-ids to tokens
+               token = tokenizer.ids2tokens(token_int)
+               text = tokenizer.tokens2text(token)
+
+               _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
+                                                          us_peaks[i][:encoder_out_lens[i] * 3],
+                                                          copy.copy(token),
+                                                          vad_offset=kwargs.get("begin_time", 0))
+               text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
+                  token, timestamp)
+               sentences = time_stamp_sentence(None, time_stamp_postprocessed, text_postprocessed)
+               result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
+                           "timestamp": time_stamp_postprocessed,
+                           "word_lists": word_lists,
+                           "sentences": sentences
+                           }
+               if ibest_writer is not None:
+                  ibest_writer["token"][key[i]] = " ".join(token)
+                  ibest_writer["text"][key[i]] = text
+                  ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
+                  ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+            else:
+               result_i = {"key": key[i], "token_int": token_int}
+            results.append(result_i)
      
      return results, meta_data
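
For reference, a minimal usage sketch of the updated generate() (illustrative only: how model, tokenizer and frontend are built is assumed to happen elsewhere, and every name and path below is a placeholder rather than part of this commit):

import torch

# 1) raw-audio input: data_in is a list of wav paths; fbank is extracted inside generate()
results, meta_data = model.generate(
    ["/path/to/utt1.wav"],              # placeholder path
    key=["utt1"],
    tokenizer=tokenizer,
    frontend=frontend,
    device="cuda:0",                    # read inside generate() via kwargs["device"]
)

# 2) pre-computed fbank input: a (T, D) or (B, T, D) tensor is used as-is (the branch added here)
fbank = torch.randn(200, 560)           # placeholder feature matrix
results, meta_data = model.generate(
    fbank,
    data_lengths=torch.tensor([200]),
    key=["utt1"],
    tokenizer=tokenizer,
    frontend=frontend,
    device="cuda:0",
)

# when a tokenizer is passed, each result now also carries sentence-level timestamps
print(results[0]["text_postprocessed"])
print(results[0]["timestamp"])   # per-word timestamps from sentence_postprocess
print(results[0]["sentences"])   # output of time_stamp_sentence(...)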