From 8dab6d184a034ca86eafa644ea0d2100aadfe27d Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 09 May 2023 10:58:33 +0800
Subject: [PATCH] Merge pull request #473 from alibaba-damo-academy/dev_smohan

---
 funasr/bin/asr_inference.py |   27 +++++++++++++++++++++------
 1 file changed, 21 insertions(+), 6 deletions(-)

diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index 4722602..a52e94a 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -41,6 +41,7 @@
 from funasr.utils.types import str_or_none
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
 from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.tasks.asr import frontend_choices
 
 
 header_colors = '\033[95m'
@@ -92,7 +93,11 @@
         )
         frontend = None
         if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+            if asr_train_args.frontend=='wav_frontend':
+                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+            else:
+                frontend_class=frontend_choices.get_class(asr_train_args.frontend)
+                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
 
         logging.info("asr_model: {}".format(asr_model))
         logging.info("asr_train_args: {}".format(asr_train_args))
@@ -111,7 +116,7 @@
         # 2. Build Language model
         if lm_train_config is not None:
             lm, lm_train_args = LMTask.build_model_from_file(
-                lm_train_config, lm_file, device
+                lm_train_config, lm_file, None, device
             )
             scorers["lm"] = lm.lm
 
@@ -193,7 +198,7 @@
 
         """
         assert check_argument_types()
-
+        
         # Input as audio signal
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
@@ -280,6 +285,7 @@
         ngram_weight: float = 0.9,
         nbest: int = 1,
         num_workers: int = 1,
+        mc: bool = False,
         **kwargs,
 ):
     inference_pipeline = inference_modelscope(
@@ -310,6 +316,7 @@
         ngram_weight=ngram_weight,
         nbest=nbest,
         num_workers=num_workers,
+        mc=mc,
         **kwargs,
     )
     return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -342,6 +349,7 @@
     ngram_weight: float = 0.9,
     nbest: int = 1,
     num_workers: int = 1,
+    mc: bool = False,
     param_dict: dict = None,
     **kwargs,
 ):
@@ -355,6 +363,9 @@
     if ngpu > 1:
         raise NotImplementedError("only single GPU decoding is supported")
     
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+
     logging.basicConfig(
         level=log_level,
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
@@ -408,6 +419,7 @@
             data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
+            mc=mc,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
@@ -416,7 +428,7 @@
             allow_variable_data_keys=allow_variable_data_keys,
             inference=True,
         )
-        
+
         finish_count = 0
         file_count = 1
         # 7 .Start for-loop
@@ -452,7 +464,7 @@
                     
                     # Write the result to each file
                     ibest_writer["token"][key] = " ".join(token)
-                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                     ibest_writer["score"][key] = str(hyp.score)
                 
                 if text is not None:
@@ -463,6 +475,9 @@
                     asr_utils.print_progress(finish_count / file_count)
                     if writer is not None:
                         ibest_writer["text"][key] = text
+
+                logging.info("uttid: {}".format(key))
+                logging.info("text predictions: {}\n".format(text))
         return asr_result_list
     
     return _forward
@@ -637,4 +652,4 @@
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file

--
Gitblit v1.9.1