游雁
2024-01-11 47088b8d1ebe42b6c376236c19184ef4f440cc0d
funasr1.0 paraformer_streaming
2个文件已修改
141 ■■■■ 已修改文件
funasr/models/paraformer_streaming/model.py 139 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/python/onnxruntime/setup.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/paraformer_streaming/model.py
@@ -375,7 +375,7 @@
        
        return pre_acoustic_embeds, pre_token_length, pre_alphas, pre_peak_index
    
    def calc_predictor_chunk(self, encoder_out, cache=None):
    def calc_predictor_chunk(self, encoder_out, encoder_out_lens, cache=None):
        
        pre_acoustic_embeds, pre_token_length = \
            self.predictor.forward_chunk(encoder_out, cache["encoder"])
@@ -389,48 +389,72 @@
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
    
    # NOTE(review): this region is a rendered diff with the +/- markers stripped.
    # The 3-argument signature and the single-value `return` appear to be the
    # removed lines; the 5-argument signature and the two-value `return` replace
    # them (generate_chunk passes encoder_out_lens / pre_token_length to this
    # method and unpacks two return values).
    def cal_decoder_with_predictor_chunk(self, encoder_out, sematic_embeds, cache=None):
    def cal_decoder_with_predictor_chunk(self, encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, cache=None):
        """Run one streaming decoder step and return log-probabilities.

        Passes ``encoder_out`` and ``sematic_embeds`` to
        ``self.decoder.forward_chunk`` together with the decoder state stored in
        ``cache["decoder"]``, then applies ``log_softmax`` over the last dimension.
        ``encoder_out_lens`` is accepted but not used; ``ys_pad_lens`` is returned
        unchanged alongside the scores.

        NOTE(review): ``cache`` defaults to None but is subscripted
        unconditionally, so callers must always pass it.
        """
        decoder_outs = self.decoder.forward_chunk(
            encoder_out, sematic_embeds, cache["decoder"]
        )
        decoder_out = decoder_outs
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out
        return decoder_out, ys_pad_lens
    def generate(self,
                 speech: torch.Tensor,
                 speech_lengths: torch.Tensor,
    def init_cache(self, cache: dict = {}, **kwargs):
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        encoder_chunk_look_back = kwargs.get("encoder_chunk_look_back", 0)
        decoder_chunk_look_back = kwargs.get("decoder_chunk_look_back", 0)
        batch_size = 1
        enc_output_size = kwargs["encoder_conf"]["output_size"]
        feats_dims = kwargs["frontend_conf"]["n_mels"] * kwargs["frontend_conf"]["lfr_m"]
        cache_encoder = {"start_idx": 0, "cif_hidden": torch.zeros((batch_size, 1, enc_output_size)),
                    "cif_alphas": torch.zeros((batch_size, 1)), "chunk_size": chunk_size,
                    "encoder_chunk_look_back": encoder_chunk_look_back, "last_chunk": False, "opt": None,
                    "feats": torch.zeros((batch_size, chunk_size[0] + chunk_size[2], feats_dims)),
                    "tail_chunk": False}
        cache["encoder"] = cache_encoder
        cache_decoder = {"decode_fsmn": None, "decoder_chunk_look_back": decoder_chunk_look_back, "opt": None,
                    "chunk_size": chunk_size}
        cache["decoder"] = cache_decoder
        cache["frontend"] = {}
        cache["prev_samples"] = []
        return cache
    def generate_chunk(self,
                       speech,
                       speech_lengths=None,
                       key: list = None,
                 tokenizer=None,
                       frontend=None,
                 **kwargs,
                 ):
        cache = kwargs.get("cache", {})
        speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
        
        is_use_ctc = kwargs.get("ctc_weight", 0.0) > 0.00001 and self.ctc != None
        print(is_use_ctc)
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(speech, speech_lengths, **kwargs)
            self.nbest = kwargs.get("nbest", 1)
        # Forward Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        # Encoder
        encoder_out, encoder_out_lens = self.encode_chunk(speech, speech_lengths, cache=cache)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        
        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        predictor_outs = self.calc_predictor_chunk(encoder_out, encoder_out_lens, cache=cache)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
                                                                        predictor_outs[2], predictor_outs[3]
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, pre_acoustic_embeds,
                                                       pre_token_length)
        decoder_outs = self.cal_decoder_with_predictor_chunk(encoder_out,
                                                             encoder_out_lens,
                                                             pre_acoustic_embeds,
                                                             pre_token_length,
                                                             cache=cache
                                                             )
        decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
        
        results = []
        b, n, d = decoder_out.size()
        if isinstance(key[0], (list, tuple)):
            key = key[0]
        for i in range(b):
            x = encoder_out[i, :encoder_out_lens[i], :]
            am_scores = decoder_out[i, :pre_token_length[i], :]
@@ -451,9 +475,11 @@
                    [self.sos] + yseq.tolist() + [self.eos], device=yseq.device
                )
                nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
            for hyp in nbest_hyps:
                assert isinstance(hyp, (Hypothesis)), type(hyp)
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if ibest_writer is None and kwargs.get("output_dir") is not None:
                    writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = writer[f"{nbest_idx + 1}best_recog"]
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
@@ -462,15 +488,76 @@
                    token_int = hyp.yseq[1:last_pos].tolist()
                
                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
                
                if tokenizer is not None:
                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)
                
                timestamp = []
                    text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                
                results.append((text, token, timestamp))
                    result_i = {"key": key[i], "text": text_postprocessed}
                    if ibest_writer is not None:
                        ibest_writer["token"][key[i]] = " ".join(token)
                        # ibest_writer["text"][key[i]] = text
                        ibest_writer["text"][key[i]] = text_postprocessed
                else:
                    result_i = {"key": key[i], "token_int": token_int}
                results.append(result_i)
        
        return results
    def generate(self,
                 data_in,
                 data_lengths=None,
                 key: list = None,
                 tokenizer=None,
                 frontend=None,
                 **kwargs,
                 ):
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        cache = kwargs.get("cache", {})
        if len(cache) == 0:
            self.init_cache(cache, **kwargs)
        meta_data = {}
        chunk_size = kwargs.get("chunk_size", [0, 10, 5])
        chunk_stride_samples = chunk_size[1] * 960  # 600ms
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                        data_type=kwargs.get("data_type", "sound"),
                                                        tokenizer=tokenizer)
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        assert len(audio_sample_list) == 1, "batch_size must be set 1"
        audio_sample = cache["prev_samples"] + audio_sample_list[0]
        n = len(audio_sample) // chunk_stride_samples
        m = len(audio_sample) % chunk_stride_samples
        for i in range(n):
            audio_sample_i = audio_sample[i*chunk_stride_samples:(i+1)*chunk_stride_samples]
            # extract fbank feats
            speech, speech_lengths = extract_fbank([audio_sample_i], data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend, cache=cache["frontend"])
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
            result_i = self.generate_chunk(speech, speech_lengths, **kwargs)
        cache["prev_samples"] = audio_sample[:-m]
runtime/python/onnxruntime/setup.py
@@ -13,7 +13,7 @@
MODULE_NAME = 'funasr_onnx'
VERSION_NUM = '0.2.4'
VERSION_NUM = '0.2.5'
setuptools.setup(
    name=MODULE_NAME,