From b15db52e4e67da8a133a67e8ffa415386de48b40 Mon Sep 17 00:00:00 2001
From: zhuyunfeng <10596244@qq.com>
Date: Tue, 09 May 2023 23:03:15 +0800
Subject: [PATCH] asr_inference: build frontend from train args and make mc configurable

---
 funasr/bin/asr_inference.py |   35 ++++++++++++++++++++++++-----------
 1 file changed, 24 insertions(+), 11 deletions(-)

diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index c18472f..a52e94a 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -40,6 +40,8 @@
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.tasks.asr import frontend_choices
 
 
 header_colors = '\033[95m'
@@ -90,6 +92,12 @@
             asr_train_config, asr_model_file, cmvn_file, device
         )
         frontend = None
+        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
+            if asr_train_args.frontend=='wav_frontend':
+                frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
+            else:
+                frontend_class=frontend_choices.get_class(asr_train_args.frontend)
+                frontend = frontend_class(**asr_train_args.frontend_conf).eval()
 
         logging.info("asr_model: {}".format(asr_model))
         logging.info("asr_train_args: {}".format(asr_train_args))
@@ -138,13 +146,6 @@
             token_list=token_list,
             pre_beam_score_key=None if ctc_weight == 1.0 else "full",
         )
-
-        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-        for scorer in scorers.values():
-            if isinstance(scorer, torch.nn.Module):
-                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-        logging.info(f"Beam_search: {beam_search}")
-        logging.info(f"Decoding device={device}, dtype={dtype}")
 
         # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
         if token_type is None:
@@ -197,12 +198,21 @@
 
         """
         assert check_argument_types()
-
+        
         # Input as audio signal
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
 
-        batch = {"speech": speech, "speech_lengths": speech_lengths}
+        if self.frontend is not None:
+            feats, feats_len = self.frontend.forward(speech, speech_lengths)
+            feats = to_device(feats, device=self.device)
+            feats_len = feats_len.int()
+            self.asr_model.frontend = None
+        else:
+            feats = speech
+            feats_len = speech_lengths
+        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
+        batch = {"speech": feats, "speech_lengths": feats_len}
 
         # a. To device
         batch = to_device(batch, device=self.device)
@@ -275,6 +285,7 @@
         ngram_weight: float = 0.9,
         nbest: int = 1,
         num_workers: int = 1,
+        mc: bool = False,
         **kwargs,
 ):
     inference_pipeline = inference_modelscope(
@@ -305,6 +316,7 @@
         ngram_weight=ngram_weight,
         nbest=nbest,
         num_workers=num_workers,
+        mc=mc,
         **kwargs,
     )
     return inference_pipeline(data_path_and_name_and_type, raw_inputs)
@@ -337,6 +349,7 @@
     ngram_weight: float = 0.9,
     nbest: int = 1,
     num_workers: int = 1,
+    mc: bool = False,
     param_dict: dict = None,
     **kwargs,
 ):
@@ -406,7 +419,7 @@
             data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
-            mc=True,
+            mc=mc,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
@@ -415,7 +428,7 @@
             allow_variable_data_keys=allow_variable_data_keys,
             inference=True,
         )
-        
+
         finish_count = 0
         file_count = 1
         # 7 .Start for-loop

--
Gitblit v1.9.1