游雁
2024-01-11 47088b8d1ebe42b6c376236c19184ef4f440cc0d
funasr/models/paraformer_streaming/model.py
@@ -375,7 +375,7 @@
      
      return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
   
-   def calc_predictor_chunk(self, encoder_out, cache=None):
+   def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None):
      
      pre_acoustic_embeds, pre_token_length = \
         self.predictor.forward_chunk(encoder_out, cache["encoder"])
@@ -389,48 +389,72 @@
      decoder_out = torch.log_softmax(decoder_out, dim=-1)
      return decoder_out, ys_pad_lens
   
-   def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
+   def cal_decoder_with_predictor_chunk(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None):
      decoder_outs = self.decoder.forward_chunk(
         encoder_out, sematic_embeds, cache["decoder"]
      )
      decoder_out = decoder_outs
      decoder_out = torch.log_softmax(decoder_out, dim=-1)
-      return decoder_out
+      return decoder_out, ys_pad_lens
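# A sketch of why the chunked decoder call above carries a cache: the decoder's
# FSMN memory blocks need the last few frames of the previous chunk. Shapes and
# lorder are hypothetical; "decode_fsmn" mirrors the key created in init_cache
# below, but this is an illustration, not the FunASR decoder:
import torch

def fsmn_memory_chunk(x, cache_decoder, lorder=10):
   # x: (B, T, D) decoder hidden states for the current chunk
   mem = cache_decoder.get("decode_fsmn")
   x_ctx = x if mem is None else torch.cat([mem, x], dim=1)  # prepend carried frames
   cache_decoder["decode_fsmn"] = x_ctx[:, -lorder:, :]      # carry last lorder frames
   return x_ctx  # a real block would run its memory convolution over x_ctx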
-   def generate(self,
-                speech: torch.Tensor,
-                speech_lengths: torch.Tensor,
-                tokenizer=None,
-                **kwargs,
-                ):
+   def init_cache(self, cache: dict = {}, **kwargs):
+      chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+      encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
+      decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
+      batch_size = 1
+      enc_output_size = kwargs["encoder_conf"]["output_size"]
+      feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
+      cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
+                  "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
+                  "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
+                  "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
+                  "tail_chunk": False}
+      cache["encoder"] = cache_encoder
      
      is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc != None
      print(is_use_ctc)
      is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
      cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
                  "chunk_size": chunk_size}
      cache["decoder"] = cache_decoder
      cache["frontend"] = {}
      cache["prev_samples"] = []
      
-      if self.beam_search is None and (is_use_lm or is_use_ctc):
-         logging.info("enable beam_search")
-         self.init_beam_search(speech, speech_lengths, **kwargs)
-         self.nbest = kwargs.get("nbest", 1)
+      return cache
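# A usage sketch for init_cache. Here `model` is assumed to be an instantiated
# streaming Paraformer, chunk_size is read as [pad, chunk, look-ahead] in LFR
# frame groups, and the encoder_conf/frontend_conf values are illustrative, not
# taken from this commit:
cache = model.init_cache({},
                         chunk_size=[0, 10, 5],
                         encoder_chunk_look_back=4,
                         decoder_chunk_look_back=1,
                         encoder_conf={"output_size": 512},
                         frontend_conf={"n_mels": 80, "lfr_m": 7})
# The encoder feats buffer holds chunk_size[0] + chunk_size[2] = 5 LFR frames
# of 80 * 7 = 560 dims:
assert cache["encoder"]["feats"].shape == (1, 5, 560)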
+   def generate_chunk(self,
+                      speech,
+                      speech_lengths=None,
+                      key: list = None,
+                      tokenizer=None,
+                      frontend=None,
+                      **kwargs,
+                      ):
+      cache = kwargs.get("cache", {})
+      speech = speech.to(device=kwargs["device"])
+      speech_lengths = speech_lengths.to(device=kwargs["device"])
      
-      # Forward Encoder
-      encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+      # Encoder
+      encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache)
      if isinstance(encoder_out, tuple):
         encoder_out = encoder_out[0]
      
      # predictor
-      predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
+      predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache)
      pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
                                                                      predictor_outs[2], predictor_outs[3]
      pre_token_length = pre_token_length.round().long()
      if torch.max(pre_token_length) < 1:
         return []
-      decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
-                                                     pre_token_length)
+      decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
+                                                           encoder_out_lens,
+                                                           pre_acoustic_embeds,
+                                                           pre_token_length,
+                                                           cache=cache
+                                                           )
      decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
      results = []
      b, n, d = decoder_out.size()
      if isinstance(key[0], (list, tuple)):
         key = key[0]
      for i in range(b):
         x = encoder_out[i, :encoder_out_lens[i], :]
         am_scores = decoder_out[i, :pre_token_length[i], :]
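# With neither CTC nor LM rescoring enabled, the single hypothesis built in the
# next hunk reduces to a greedy argmax over the log-softmax scores sliced above.
# A stand-in sketch (random scores, vocabulary of 3000):
import torch

am_scores = torch.log_softmax(torch.randn(4, 3000), dim=-1)  # (n_tokens, vocab)
yseq = am_scores.argmax(dim=-1)            # best token id at each predicted position
score = am_scores.max(dim=-1)[0].sum()     # total log-probability of the greedy path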
@@ -451,9 +475,11 @@
               [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
            )
            nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-         for hyp in nbest_hyps:
-            assert isinstance(hyp, (Hypothesis)), type(hyp)
+         for nbest_idx, hyp in enumerate(nbest_hyps):
+            ibest_writer = None
+            if ibest_writer is None and kwargs.get("output_dir") is not None:
+               writer = DatadirWriter(kwargs.get("output_dir"))
+               ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
@@ -462,15 +488,76 @@
               token_int = hyp.yseq[1:last_pos].tolist()
            
            # remove blank symbol id, which is assumed to be 0
-            token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+            token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
            
-            # Change integer-ids to tokens
-            token = tokenizer.ids2tokens(token_int)
-            text = tokenizer.tokens2text(token)
-            timestamp = []
-            results.append((text, token, timestamp))
+            if tokenizer is not None:
+               # Change integer-ids to tokens
+               token = tokenizer.ids2tokens(token_int)
+               text = tokenizer.tokens2text(token)
+               text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+               result_i = {"key": key[i], "text": text_postprocessed}
+               if ibest_writer is not None:
+                  ibest_writer["token"][key[i]] = " ".join(token)
+                  # ibest_writer["text"][key[i]] = text
+                  ibest_writer["text"][key[i]] = text_postprocessed
+            else:
+               result_i = {"key": key[i], "token_int": token_int}
+            results.append(result_i)
      
      return results
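# The filter above drops sos/eos/blank before detokenization. A sketch with
# assumed ids (blank_id=0, sos=1, eos=2 -- typical Paraformer defaults, not
# read from this diff):
token_int = [1, 0, 57, 2, 984, 0]
token_int = [t for t in token_int if t not in (0, 1, 2)]   # -> [57, 984]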
+   def generate(self,
+                data_in,
+                data_lengths=None,
+                key: list = None,
+                tokenizer=None,
+                frontend=None,
+                **kwargs,
+                ):
+      # init beamsearch
+      is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
+      is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
+      if self.beam_search is None and (is_use_lm or is_use_ctc):
+         logging.info("enable beam_search")
+         self.init_beam_search(**kwargs)
+         self.nbest = kwargs.get("nbest", 1)
+      cache = kwargs.get("cache", {})
+      if len(cache) == 0:
+         self.init_cache(cache, **kwargs)
+      meta_data = {}
+      chunk_size = kwargs.get("chunk_size", [0, 10, 5])
+      chunk_stride_samples = chunk_size[1] * 960  # 960 samples = 60 ms at 16 kHz, so 10 -> 600 ms
+      time1 = time.perf_counter()
+      audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                      data_type=kwargs.get("data_type", "sound"),
+                                                      tokenizer=tokenizer)
+      time2 = time.perf_counter()
+      meta_data["load_data"] = f"{time2 - time1:0.3f}"
+      assert len(audio_sample_list) == 1, "batch_size must be set to 1"
+      audio_sample = cache["prev_samples"] + audio_sample_list[0]
+      n = len(audio_sample) // chunk_stride_samples
+      m = len(audio_sample) % chunk_stride_samples
+      for i in range(n):
+         audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
+         # extract fbank feats
+         speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
+                                                frontend=frontend, cache=cache["frontend"])
+         time3 = time.perf_counter()
+         meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+         meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+         result_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, frontend=frontend, **kwargs)
+      cache["prev_samples"] = audio_sample[-m:] if m > 0 else []