游雁
2024-01-12 0143122a4e2ee86cc27ba137b2bb0530577cbf12
funasr/models/paraformer_streaming/model.py
@@ -64,8 +64,8 @@
      
      super().__init__(*args, **kwargs)
      
      import pdb;
      pdb.set_trace()
      # import pdb;
      # pdb.set_trace()
      self.sampling_ratio = kwargs.get("sampling_ratio", 0.2)
@@ -375,11 +375,10 @@
      
      return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
   
   def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None):
      pre_acoustic_embeds, pre_token_length = \
         self.predictor.forward_chunk(encoder_out, cache["encoder"])
      return pre_acoustic_embeds, pre_token_length
   def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None, **kwargs):
      is_final = kwargs.get("is_final", False)
      return self.predictor.forward_chunk(encoder_out, cache["encoder"], is_final=is_final)
   
   def cal_decoder_with_predictor(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens):
      decoder_outs = self.decoder(
@@ -416,7 +415,7 @@
                  "chunk_size": chunk_size}
      cache["decoder"] = cache_decoder
      cache["frontend"] = {}
      cache["prev_samples"] = []
      cache["prev_samples"] = torch.empty(0)
      
      return cache
   
@@ -432,12 +431,12 @@
      speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
      
      # Encoder
      encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache)
      encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache, is_final=kwargs.get("is_final", False))
      if isinstance(encoder_out, tuple):
         encoder_out = encoder_out[0]
      
      # predictor
      predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache)
      predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache, is_final=kwargs.get("is_final", False))
      pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
                                                                      predictor_outs[2], predictor_outs[3]
      pre_token_length = pre_token_length.round().long()
@@ -476,10 +475,7 @@
            )
            nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
         for nbest_idx, hyp in enumerate(nbest_hyps):
            ibest_writer = None
            if ibest_writer is None and kwargs.get("output_dir") is not None:
               writer = DatadirWriter(kwargs.get("output_dir"))
               ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
            # remove sos/eos and get results
            last_pos = -1
            if isinstance(hyp.yseq, list):
@@ -490,22 +486,15 @@
            # remove blank symbol id, which is assumed to be 0
            token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
            
            if tokenizer is not None:
               # Change integer-ids to tokens
               token = tokenizer.ids2tokens(token_int)
               text = tokenizer.tokens2text(token)
               text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
               result_i = {"key": key[i], "text": text_postprocessed}
               if ibest_writer is not None:
                  ibest_writer["token"][key[i]] = " ".join(token)
                  # ibest_writer["text"][key[i]] = text
                  ibest_writer["text"][key[i]] = text_postprocessed
            else:
               result_i = {"key": key[i], "token_int": token_int}
            results.append(result_i)
            # Change integer-ids to tokens
            token = tokenizer.ids2tokens(token_int)
            # text = tokenizer.tokens2text(token)
            result_i = token
            results.extend(result_i)
      
      return results
   
@@ -515,6 +504,7 @@
                key: list = None,
                tokenizer=None,
                frontend=None,
                cache: dict={},
                **kwargs,
                ):
@@ -526,38 +516,65 @@
         self.init_beam_search(**kwargs)
         self.nbest = kwargs.get("nbest", 1)
      
      cache = kwargs.get("cache", {})
      if len(cache) == 0:
         self.init_cache(cache, **kwargs)
      
      meta_data = {}
      chunk_size = kwargs.get("chunk_size", [0, 10, 5])
      chunk_stride_samples = chunk_size[1] * 960  # 600ms
      
      time1 = time.perf_counter()
      audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                      data_type=kwargs.get("data_type", "sound"),
                                                      tokenizer=tokenizer)
      cfg = {"is_final": kwargs.get("is_final", False)}
      audio_sample_list = load_audio_text_image_video(data_in,
                                          fs=frontend.fs,
                                          audio_fs=kwargs.get("fs", 16000),
                                          data_type=kwargs.get("data_type", "sound"),
                                          tokenizer=tokenizer,
                                          **cfg,
                                          )
      _is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
      time2 = time.perf_counter()
      meta_data["load_data"] = f"{time2 - time1:0.3f}"
      assert len(audio_sample_list) == 1, "batch_size must be set 1"
      
      audio_sample = cache["prev_samples"] + audio_sample_list[0]
      audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
      
      n = len(audio_sample) // chunk_stride_samples
      m = len(audio_sample) % chunk_stride_samples
      n = len(audio_sample) // chunk_stride_samples + int(_is_final)
      m = len(audio_sample) % chunk_stride_samples * (1-int(_is_final))
      tokens = []
      for i in range(n):
         kwargs["is_final"] = _is_final and i == n -1
         audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
         # extract fbank feats
         speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
                                                frontend=frontend, cache=cache["frontend"])
                                                frontend=frontend, cache=cache["frontend"], is_final=kwargs["is_final"])
         time3 = time.perf_counter()
         meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
         meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
         
         result_i = self.generate_chunk(speech, speech_lengths, **kwargs)
         tokens_i = self.generate_chunk(speech, speech_lengths, key=key, tokenizer=tokenizer, cache=cache, frontend=frontend, **kwargs)
         tokens.extend(tokens_i)
      text_postprocessed, _ = postprocess_utils.sentence_postprocess(tokens)
      result_i = {"key": key[0], "text": text_postprocessed}
      result = [result_i]
      
      cache["prev_samples"] = audio_sample[:-m]
      if _is_final:
         self.init_cache(cache, **kwargs)
      if kwargs.get("output_dir"):
         writer = DatadirWriter(kwargs.get("output_dir"))
         ibest_writer = writer[f"{1}best_recog"]
         ibest_writer["token"][key[0]] = " ".join(tokens)
         ibest_writer["text"][key[0]] = text_postprocessed
      return result, meta_data