From 79bd015ab0ded4e5aed1b1ecf32fcbc84eefde68 Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 03 二月 2023 17:42:47 +0800
Subject: [PATCH] Merge pull request #58 from alibaba-damo-academy/dev

---
 funasr/bin/asr_inference_paraformer.py |  523 ++++++++++++++++++++++++++++++++++++++++-----------------
 1 files changed, 363 insertions(+), 160 deletions(-)

diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
old mode 100755
new mode 100644
index 09c61bc..0929436
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -36,10 +36,6 @@
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
 
-from modelscope.utils.logger import get_logger
-
-logger = get_logger()
-
 header_colors = '\033[95m'
 end_colors = '\033[0m'
 
@@ -48,6 +44,7 @@
     'audio_fs': 16000,
     'model_fs': 16000
 }
+
 
 class Speech2Text:
     """Speech2Text class
@@ -65,6 +62,7 @@
             self,
             asr_train_config: Union[Path, str] = None,
             asr_model_file: Union[Path, str] = None,
+            cmvn_file: Union[Path, str] = None,
             lm_train_config: Union[Path, str] = None,
             lm_file: Union[Path, str] = None,
             token_type: str = None,
@@ -87,19 +85,23 @@
         # 1. Build ASR model
         scorers = {}
         asr_model, asr_train_args = ASRTask.build_model_from_file(
-            asr_train_config, asr_model_file, device
+            asr_train_config, asr_model_file, cmvn_file, device
         )
-        if asr_model.frontend is None and frontend_conf is not None:
-            frontend = WavFrontend(**frontend_conf)
-            asr_model.frontend = frontend
-        # logging.info("asr_model: {}".format(asr_model))
-        # logging.info("asr_train_args: {}".format(asr_train_args))
+        frontend = None
+        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+
+        logging.info("asr_model: {}".format(asr_model))
+        logging.info("asr_train_args: {}".format(asr_train_args))
         asr_model.to(dtype=getattr(torch, dtype)).eval()
 
-        ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+        if asr_model.ctc != None:
+            ctc = CTCPrefixScorer(ctc=asr_model.ctc, eos=asr_model.eos)
+            scorers.update(
+                ctc=ctc
+            )
         token_list = asr_model.token_list
         scorers.update(
-            ctc=ctc,
             length_bonus=LengthBonus(len(token_list)),
         )
 
@@ -141,8 +143,8 @@
         for scorer in scorers.values():
             if isinstance(scorer, torch.nn.Module):
                 scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-        # logging.info(f"Beam_search: {beam_search}")
-        # logging.info(f"Decoding device={device}, dtype={dtype}")
+
+        logging.info(f"Decoding device={device}, dtype={dtype}")
 
         # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
         if token_type is None:
@@ -160,22 +162,27 @@
         else:
             tokenizer = build_tokenizer(token_type=token_type)
         converter = TokenIDConverter(token_list=token_list)
-        # logging.info(f"Text tokenizer: {tokenizer}")
+        logging.info(f"Text tokenizer: {tokenizer}")
 
         self.asr_model = asr_model
         self.asr_train_args = asr_train_args
         self.converter = converter
         self.tokenizer = tokenizer
-        has_lm = lm_weight == 0.0 or lm_file is None
-        if ctc_weight == 0.0 and has_lm:
+        is_use_lm = lm_weight != 0.0 and lm_file is not None
+        if (ctc_weight == 0.0 or asr_model.ctc == None) and not is_use_lm:
             beam_search = None
         self.beam_search = beam_search
+        logging.info(f"Beam_search: {self.beam_search}")
         self.beam_search_transducer = beam_search_transducer
         self.maxlenratio = maxlenratio
         self.minlenratio = minlenratio
         self.device = device
         self.dtype = dtype
         self.nbest = nbest
+        self.frontend = frontend
+        self.encoder_downsampling_factor = 1
+        if asr_train_args.encoder_conf["input_layer"] == "conv2d":
+            self.encoder_downsampling_factor = 4
 
     @torch.no_grad()
     def __call__(
@@ -195,14 +202,16 @@
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
 
-        # data: (Nsamples,) -> (1, Nsamples)
-        # lengths: (1,)
-        # if len(speech.size()) < 3:
-        #     speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
-        #     speech_lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
-        lfr_factor = max(1, (speech.size()[-1]//80)-1)
-        
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        if self.frontend is not None:
+            feats, feats_len = self.frontend.forward(speech, speech_lengths)
+            feats = to_device(feats, device=self.device)
+            feats_len = feats_len.int()
+            self.asr_model.frontend = None
+        else:
+            feats = speech
+            feats_len = speech_lengths
+        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+        batch = {"speech": feats, "speech_lengths": feats_len}
 
         # a. To device
         batch = to_device(batch, device=self.device)
@@ -212,11 +221,14 @@
         if isinstance(enc, tuple):
             enc = enc[0]
         # assert len(enc) == 1, len(enc)
-        enc_len_batch_total = torch.sum(enc_len).item()
+        enc_len_batch_total = torch.sum(enc_len).item() * self.encoder_downsampling_factor
 
         predictor_outs = self.asr_model.calc_predictor(enc, enc_len)
-        pre_acoustic_embeds, pre_token_length = predictor_outs[0], predictor_outs[1]
+        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
+                                                                        predictor_outs[2], predictor_outs[3]
         pre_token_length = pre_token_length.round().long()
+        if torch.max(pre_token_length) < 1:
+            return []
         decoder_outs = self.asr_model.cal_decoder_with_predictor(enc, enc_len, pre_acoustic_embeds, pre_token_length)
         decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
 
@@ -229,7 +241,7 @@
                 nbest_hyps = self.beam_search(
                     x=x, am_scores=am_scores, maxlenratio=self.maxlenratio, minlenratio=self.minlenratio
                 )
-    
+
                 nbest_hyps = nbest_hyps[: self.nbest]
             else:
                 yseq = am_scores.argmax(dim=-1)
@@ -240,33 +252,189 @@
                     [self.asr_model.sos] + yseq.tolist() + [self.asr_model.eos], device=yseq.device
                 )
                 nbest_hyps = [Hypothesis(yseq=yseq, score=score)]
-                
+
             for hyp in nbest_hyps:
                 assert isinstance(hyp, (Hypothesis)), type(hyp)
-    
+
                 # remove sos/eos and get results
                 last_pos = -1
                 if isinstance(hyp.yseq, list):
                     token_int = hyp.yseq[1:last_pos]
                 else:
                     token_int = hyp.yseq[1:last_pos].tolist()
-    
+
                 # remove blank symbol id, which is assumed to be 0
-                token_int = list(filter(lambda x: x != 0, token_int))
-    
+                token_int = list(filter(lambda x: x != 0 and x != 2, token_int))
+
                 # Change integer-ids to tokens
                 token = self.converter.ids2tokens(token_int)
-    
+
                 if self.tokenizer is not None:
                     text = self.tokenizer.tokens2text(token)
                 else:
                     text = None
-    
+
                 results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
 
         # assert check_return_type(results)
         return results
 
+
+# def inference(
+#         maxlenratio: float,
+#         minlenratio: float,
+#         batch_size: int,
+#         beam_size: int,
+#         ngpu: int,
+#         ctc_weight: float,
+#         lm_weight: float,
+#         penalty: float,
+#         log_level: Union[int, str],
+#         data_path_and_name_and_type,
+#         asr_train_config: Optional[str],
+#         asr_model_file: Optional[str],
+#         cmvn_file: Optional[str] = None,
+#         raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+#         lm_train_config: Optional[str] = None,
+#         lm_file: Optional[str] = None,
+#         token_type: Optional[str] = None,
+#         key_file: Optional[str] = None,
+#         word_lm_train_config: Optional[str] = None,
+#         bpemodel: Optional[str] = None,
+#         allow_variable_data_keys: bool = False,
+#         streaming: bool = False,
+#         output_dir: Optional[str] = None,
+#         dtype: str = "float32",
+#         seed: int = 0,
+#         ngram_weight: float = 0.9,
+#         nbest: int = 1,
+#         num_workers: int = 1,
+#         frontend_conf: dict = None,
+#         fs: Union[dict, int] = 16000,
+#         lang: Optional[str] = None,
+#         **kwargs,
+# ):
+#     assert check_argument_types()
+#
+#     if word_lm_train_config is not None:
+#         raise NotImplementedError("Word LM is not implemented")
+#     if ngpu > 1:
+#         raise NotImplementedError("only single GPU decoding is supported")
+#
+#     logging.basicConfig(
+#         level=log_level,
+#         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+#     )
+#
+#     if ngpu >= 1 and torch.cuda.is_available():
+#         device = "cuda"
+#     else:
+#         device = "cpu"
+#
+#     # 1. Set random-seed
+#     set_all_random_seed(seed)
+#
+#     # 2. Build speech2text
+#     speech2text_kwargs = dict(
+#         asr_train_config=asr_train_config,
+#         asr_model_file=asr_model_file,
+#         cmvn_file=cmvn_file,
+#         lm_train_config=lm_train_config,
+#         lm_file=lm_file,
+#         token_type=token_type,
+#         bpemodel=bpemodel,
+#         device=device,
+#         maxlenratio=maxlenratio,
+#         minlenratio=minlenratio,
+#         dtype=dtype,
+#         beam_size=beam_size,
+#         ctc_weight=ctc_weight,
+#         lm_weight=lm_weight,
+#         ngram_weight=ngram_weight,
+#         penalty=penalty,
+#         nbest=nbest,
+#         frontend_conf=frontend_conf,
+#     )
+#     speech2text = Speech2Text(**speech2text_kwargs)
+#
+#     # 3. Build data-iterator
+#     loader = ASRTask.build_streaming_iterator(
+#         data_path_and_name_and_type,
+#         dtype=dtype,
+#         batch_size=batch_size,
+#         key_file=key_file,
+#         num_workers=num_workers,
+#         preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+#         collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+#         allow_variable_data_keys=allow_variable_data_keys,
+#         inference=True,
+#     )
+#
+#     forward_time_total = 0.0
+#     length_total = 0.0
+#     finish_count = 0
+#     file_count = 1
+#     # 7 .Start for-loop
+#     # FIXME(kamo): The output format should be discussed about
+#     asr_result_list = []
+#     if output_dir is not None:
+#         writer = DatadirWriter(output_dir)
+#     else:
+#         writer = None
+#
+#     for keys, batch in loader:
+#         assert isinstance(batch, dict), type(batch)
+#         assert all(isinstance(s, str) for s in keys), keys
+#         _bs = len(next(iter(batch.values())))
+#         assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+#         # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
+#
+#         logging.info("decoding, utt_id: {}".format(keys))
+#         # N-best list of (text, token, token_int, hyp_object)
+#
+#         time_beg = time.time()
+#         results = speech2text(**batch)
+#         if len(results) < 1:
+#             hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+#             results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
+#         time_end = time.time()
+#         forward_time = time_end - time_beg
+#         lfr_factor = results[0][-1]
+#         length = results[0][-2]
+#         forward_time_total += forward_time
+#         length_total += length
+#         logging.info(
+#             "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".
+#                 format(length, forward_time, 100 * forward_time / (length*lfr_factor)))
+#
+#         for batch_id in range(_bs):
+#             result = [results[batch_id][:-2]]
+#
+#             key = keys[batch_id]
+#             for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result):
+#                 # Create a directory: outdir/{n}best_recog
+#                 if writer is not None:
+#                     ibest_writer = writer[f"{n}best_recog"]
+#
+#                     # Write the result to each file
+#                     ibest_writer["token"][key] = " ".join(token)
+#                     ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+#                     ibest_writer["score"][key] = str(hyp.score)
+#
+#                 if text is not None:
+#                     text_postprocessed = postprocess_utils.sentence_postprocess(token)
+#                     item = {'key': key, 'value': text_postprocessed}
+#                     asr_result_list.append(item)
+#                     finish_count += 1
+#                     # asr_utils.print_progress(finish_count / file_count)
+#                     if writer is not None:
+#                         ibest_writer["text"][key] = text
+#
+#                 logging.info("decoding, utt: {}, predictions: {}".format(key, text))
+#
+#     logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
+#                  format(length_total, forward_time_total, 100 * forward_time_total / (length_total*lfr_factor)))
+#     return asr_result_list
 
 def inference(
         maxlenratio: float,
@@ -281,7 +449,8 @@
         data_path_and_name_and_type,
         asr_train_config: Optional[str],
         asr_model_file: Optional[str],
-        audio_lists: Union[List[Any], bytes] = None,
+        cmvn_file: Optional[str] = None,
+        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
         lm_train_config: Optional[str] = None,
         lm_file: Optional[str] = None,
         token_type: Optional[str] = None,
@@ -296,9 +465,71 @@
         ngram_weight: float = 0.9,
         nbest: int = 1,
         num_workers: int = 1,
-        frontend_conf: dict = None,
-        fs: Union[dict, int] = 16000,
-        lang: Optional[str] = None,
+
+        **kwargs,
+):
+    inference_pipeline = inference_modelscope(
+        maxlenratio=maxlenratio,
+        minlenratio=minlenratio,
+        batch_size=batch_size,
+        beam_size=beam_size,
+        ngpu=ngpu,
+        ctc_weight=ctc_weight,
+        lm_weight=lm_weight,
+        penalty=penalty,
+        log_level=log_level,
+        asr_train_config=asr_train_config,
+        asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
+        raw_inputs=raw_inputs,
+        lm_train_config=lm_train_config,
+        lm_file=lm_file,
+        token_type=token_type,
+        key_file=key_file,
+        word_lm_train_config=word_lm_train_config,
+        bpemodel=bpemodel,
+        allow_variable_data_keys=allow_variable_data_keys,
+        streaming=streaming,
+        output_dir=output_dir,
+        dtype=dtype,
+        seed=seed,
+        ngram_weight=ngram_weight,
+        nbest=nbest,
+        num_workers=num_workers,
+
+        **kwargs,
+    )
+    return inference_pipeline(data_path_and_name_and_type, raw_inputs)
+
+
+def inference_modelscope(
+        maxlenratio: float,
+        minlenratio: float,
+        batch_size: int,
+        beam_size: int,
+        ngpu: int,
+        ctc_weight: float,
+        lm_weight: float,
+        penalty: float,
+        log_level: Union[int, str],
+        # data_path_and_name_and_type,
+        asr_train_config: Optional[str],
+        asr_model_file: Optional[str],
+        cmvn_file: Optional[str] = None,
+        lm_train_config: Optional[str] = None,
+        lm_file: Optional[str] = None,
+        token_type: Optional[str] = None,
+        key_file: Optional[str] = None,
+        word_lm_train_config: Optional[str] = None,
+        bpemodel: Optional[str] = None,
+        allow_variable_data_keys: bool = False,
+        dtype: str = "float32",
+        seed: int = 0,
+        ngram_weight: float = 0.9,
+        nbest: int = 1,
+        num_workers: int = 1,
+        output_dir: Optional[str] = None,
+        param_dict: dict = None,
         **kwargs,
 ):
     assert check_argument_types()
@@ -313,50 +544,12 @@
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
     )
 
-    if ngpu >= 1:
+    if ngpu >= 1 and torch.cuda.is_available():
         device = "cuda"
     else:
         device = "cpu"
-    hop_length: int = 160
-    sr: int = 16000
-    if isinstance(fs, int):
-        sr = fs
-    else:
-        if 'model_fs' in fs and fs['model_fs'] is not None:
-            sr = fs['model_fs']
-    # data_path_and_name_and_type for modelscope: (data from audio_lists)
-    # ['speech', 'sound', 'am.mvn']
-    # data_path_and_name_and_type for funasr:
-    # [('/mnt/data/jiangyu.xzy/exp/maas/mvn.1.scp', 'speech', 'kaldi_ark')]
-    if isinstance(data_path_and_name_and_type[0], Tuple):
-        features_type: str = data_path_and_name_and_type[0][1]
-    elif isinstance(data_path_and_name_and_type[0], str):
-        features_type: str = data_path_and_name_and_type[1]
-    else:
-        raise NotImplementedError("unknown features type:{0}".format(data_path_and_name_and_type))
-    if features_type != 'sound':
-        frontend_conf = None
-        flag_modelscope = False
-    else:
-        flag_modelscope = True
-    if frontend_conf is not None:
-        if 'hop_length' in frontend_conf:
-            hop_length = frontend_conf['hop_length']
+        batch_size = 1
 
-    finish_count = 0
-    file_count = 1
-    if flag_modelscope and not isinstance(data_path_and_name_and_type[0], Tuple):
-        data_path_and_name_and_type_new = [
-            audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1]
-        ]
-        if isinstance(audio_lists, bytes):
-            file_count = 1
-        else:
-            file_count = len(audio_lists)
-        if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None:
-            mvn_file = data_path_and_name_and_type[2]
-            mvn_data = wav_utils.extract_CMVN_featrures(mvn_file)
-            frontend_conf['mvn_data'] = mvn_data
     # 1. Set random-seed
     set_all_random_seed(seed)
 
@@ -364,6 +557,7 @@
     speech2text_kwargs = dict(
         asr_train_config=asr_train_config,
         asr_model_file=asr_model_file,
+        cmvn_file=cmvn_file,
         lm_train_config=lm_train_config,
         lm_file=lm_file,
         token_type=token_type,
@@ -378,25 +572,20 @@
         ngram_weight=ngram_weight,
         penalty=penalty,
         nbest=nbest,
-        frontend_conf=frontend_conf,
     )
     speech2text = Speech2Text(**speech2text_kwargs)
 
-    # 3. Build data-iterator
-    if flag_modelscope:
-        loader = ASRTask.build_streaming_iterator_modelscope(
-            data_path_and_name_and_type_new,
-            dtype=dtype,
-            batch_size=batch_size,
-            key_file=key_file,
-            num_workers=num_workers,
-            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
-            sample_rate=fs
-        )
-    else:
+    def _forward(
+            data_path_and_name_and_type,
+            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+            output_dir_v2: Optional[str] = None,
+            param_dict: dict = None,
+    ):
+        # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
         loader = ASRTask.build_streaming_iterator(
             data_path_and_name_and_type,
             dtype=dtype,
@@ -409,76 +598,75 @@
             inference=True,
         )
 
-    forward_time_total = 0.0
-    length_total = 0.0
-    # 7 .Start for-loop
-    # FIXME(kamo): The output format should be discussed about
-    asr_result_list = []
-    if output_dir is not None:
-        writer = DatadirWriter(output_dir)
-    else:
-        writer = None
+        forward_time_total = 0.0
+        length_total = 0.0
+        finish_count = 0
+        file_count = 1
+        # 7 .Start for-loop
+        # FIXME(kamo): The output format should be discussed about
+        asr_result_list = []
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+        else:
+            writer = None
 
-    for keys, batch in loader:
-        assert isinstance(batch, dict), type(batch)
-        assert all(isinstance(s, str) for s in keys), keys
-        _bs = len(next(iter(batch.values())))
-        assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-        # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+            # batch = {k: v for k, v in batch.items() if not k.endswith("_lengths")}
 
-        # logging.info("decoding, utt_id: {}".format(keys))
-        # N-best list of (text, token, token_int, hyp_object)
+            logging.info("decoding, utt_id: {}".format(keys))
+            # N-best list of (text, token, token_int, hyp_object)
 
-        time_beg = time.time()
-        results = speech2text(**batch)
-        time_end = time.time()
-        forward_time = time_end - time_beg
-        lfr_factor = results[0][-1]
-        length = results[0][-2]
-        forward_time_total += forward_time
-        length_total += length
-        logging.info(
-            "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".
-                format(length, forward_time, 100 * forward_time / (length*lfr_factor)))
-        
-        for batch_id in range(_bs):
-            result = [results[batch_id][:-2]]
-    
-            key = keys[batch_id]
-            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result):
-                # Create a directory: outdir/{n}best_recog
-                if writer is not None:
-                    ibest_writer = writer[f"{n}best_recog"]
-    
-                    # Write the result to each file
-                    ibest_writer["token"][key] = " ".join(token)
-                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
-                    ibest_writer["score"][key] = str(hyp.score)
-    
-                if text is not None:
-                    text_postprocessed = postprocess_utils.sentence_postprocess(token)
-                    item = {'key': key, 'value': text_postprocessed}
-                    asr_result_list.append(item)
-                    finish_count += 1
-                    asr_utils.print_progress(finish_count / file_count)
+            time_beg = time.time()
+            results = speech2text(**batch)
+            if len(results) < 1:
+                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+                results = [[" ", ["sil"], [2], hyp, 10, 6]] * nbest
+            time_end = time.time()
+            forward_time = time_end - time_beg
+            lfr_factor = results[0][-1]
+            length = results[0][-2]
+            forward_time_total += forward_time
+            length_total += length
+            rtf_cur = "decoding, feature length: {}, forward_time: {:.4f}, rtf: {:.4f}".format(length, forward_time, 100 * forward_time / (length * lfr_factor))
+            logging.info(rtf_cur)
+
+            for batch_id in range(_bs):
+                result = [results[batch_id][:-2]]
+
+                key = keys[batch_id]
+                for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), result):
+                    # Create a directory: outdir/{n}best_recog
                     if writer is not None:
-                        ibest_writer["text"][key] = text
-    
-                logging.info("decoding, utt: {}, predictions: {}".format(key, text))
+                        ibest_writer = writer[f"{n}best_recog"]
 
-    logging.info("decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".
-                 format(length_total, forward_time_total, 100 * forward_time_total / (length_total*lfr_factor)))
-    return asr_result_list
+                        # Write the result to each file
+                        ibest_writer["token"][key] = " ".join(token)
+                        # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                        ibest_writer["score"][key] = str(hyp.score)
+                        ibest_writer["rtf"][key] = rtf_cur
 
+                    if text is not None:
+                        text_postprocessed = postprocess_utils.sentence_postprocess(token)
+                        item = {'key': key, 'value': text_postprocessed}
+                        asr_result_list.append(item)
+                        finish_count += 1
+                        # asr_utils.print_progress(finish_count / file_count)
+                        if writer is not None:
+                            ibest_writer["text"][key] = text_postprocessed
 
-def set_parameters(language: str = None,
-                   sample_rate: Union[int, Dict[Any, int]] = None):
-    if language is not None:
-        global global_asr_language
-        global_asr_language = language
-    if sample_rate is not None:
-        global global_sample_rate
-        global_sample_rate = sample_rate
+                    logging.info("decoding, utt: {}, predictions: {}".format(key, text))
+        rtf_avg = "decoding, feature length total: {}, forward_time total: {:.4f}, rtf avg: {:.4f}".format(length_total, forward_time_total, 100 * forward_time_total / (length_total * lfr_factor))
+        logging.info(rtf_avg)
+        if writer is not None:
+            ibest_writer["rtf"]["rtf_avf"] = rtf_avg
+        return asr_result_list
+
+    return _forward
 
 
 def get_parser():
@@ -522,7 +710,7 @@
     group.add_argument(
         "--data_path_and_name_and_type",
         type=str2triple_str,
-        required=True,
+        required=False,
         action="append",
     )
     group.add_argument("--key_file", type=str_or_none)
@@ -538,6 +726,11 @@
         "--asr_model_file",
         type=str,
         help="ASR model parameter file",
+    )
+    group.add_argument(
+        "--cmvn_file",
+        type=str,
+        help="Global cmvn file",
     )
     group.add_argument(
         "--lm_train_config",
@@ -613,7 +806,7 @@
         default=None,
         help="",
     )
-    group.add_argument("--audio_lists", type=list, default=None)
+    group.add_argument("--raw_inputs", type=list, default=None)
     # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
 
     group = parser.add_argument_group("Text converter related")
@@ -647,3 +840,13 @@
 
 if __name__ == "__main__":
     main()
+
+    # from modelscope.pipelines import pipeline
+    # from modelscope.utils.constant import Tasks
+    #
+    # inference_16k_pipline = pipeline(
+    #     task=Tasks.auto_speech_recognition,
+    #     model='damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch')
+    #
+    # rec_result = inference_16k_pipline(audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav')
+    # print(rec_result)

--
Gitblit v1.9.1