From a73123bcfc14370b74b17084bc124f00c48613e4 Mon Sep 17 00:00:00 2001
From: smohan-speech <smohan@mail.ustc.edu.cn>
Date: Sat, 06 May 2023 16:17:48 +0800
Subject: [PATCH] add speaker-attributed ASR task for alimeeting

---
 funasr/bin/asr_inference.py |   34 ++++++++++++++++++----------------
 1 file changed, 18 insertions(+), 16 deletions(-)

diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index 4722602..c18472f 100644
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -40,7 +40,6 @@
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend
 
 
 header_colors = '\033[95m'
@@ -91,8 +90,6 @@
             asr_train_config, asr_model_file, cmvn_file, device
         )
         frontend = None
-        if asr_train_args.frontend is not None and asr_train_args.frontend_conf is not None:
-            frontend = WavFrontend(cmvn_file=cmvn_file, **asr_train_args.frontend_conf)
 
         logging.info("asr_model: {}".format(asr_model))
         logging.info("asr_train_args: {}".format(asr_train_args))
@@ -111,7 +108,7 @@
         # 2. Build Language model
         if lm_train_config is not None:
             lm, lm_train_args = LMTask.build_model_from_file(
-                lm_train_config, lm_file, device
+                lm_train_config, lm_file, None, device
             )
             scorers["lm"] = lm.lm
 
@@ -141,6 +138,13 @@
             token_list=token_list,
             pre_beam_score_key=None if ctc_weight == 1.0 else "full",
         )
+
+        beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
+        for scorer in scorers.values():
+            if isinstance(scorer, torch.nn.Module):
+                scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
+        logging.info(f"Beam_search: {beam_search}")
+        logging.info(f"Decoding device={device}, dtype={dtype}")
 
         # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
         if token_type is None:
@@ -198,16 +202,7 @@
         if isinstance(speech, np.ndarray):
             speech = torch.tensor(speech)
 
-        if self.frontend is not None:
-            feats, feats_len = self.frontend.forward(speech, speech_lengths)
-            feats = to_device(feats, device=self.device)
-            feats_len = feats_len.int()
-            self.asr_model.frontend = None
-        else:
-            feats = speech
-            feats_len = speech_lengths
-        lfr_factor = max(1, (feats.size()[-1] // 80) - 1)
-        batch = {"speech": feats, "speech_lengths": feats_len}
+        batch = {"speech": speech, "speech_lengths": speech_lengths}
 
         # a. To device
         batch = to_device(batch, device=self.device)
@@ -355,6 +350,9 @@
     if ngpu > 1:
         raise NotImplementedError("only single GPU decoding is supported")
     
+    for handler in logging.root.handlers[:]:
+        logging.root.removeHandler(handler)
+
     logging.basicConfig(
         level=log_level,
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
@@ -408,6 +406,7 @@
             data_path_and_name_and_type,
             dtype=dtype,
             fs=fs,
+            mc=True,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
@@ -452,7 +451,7 @@
                     
                     # Write the result to each file
                     ibest_writer["token"][key] = " ".join(token)
-                    # ibest_writer["token_int"][key] = " ".join(map(str, token_int))
+                    ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                     ibest_writer["score"][key] = str(hyp.score)
                 
                 if text is not None:
@@ -463,6 +462,9 @@
                     asr_utils.print_progress(finish_count / file_count)
                     if writer is not None:
                         ibest_writer["text"][key] = text
+
+                logging.info("uttid: {}".format(key))
+                logging.info("text predictions: {}\n".format(text))
         return asr_result_list
     
     return _forward
@@ -637,4 +639,4 @@
 
 
 if __name__ == "__main__":
-    main()
+    main()
\ No newline at end of file

--
Gitblit v1.9.1