From f9fed09e96f43e7eab88378fc444c4987933badb Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期五, 09 十二月 2022 23:57:51 +0800
Subject: [PATCH] Merge pull request #10 from alibaba-damo-academy/dev

---
 funasr/bin/asr_inference.py |  225 +++++++++++++++++++++++++++++++++++++-------------------
 1 files changed, 149 insertions(+), 76 deletions(-)

diff --git a/funasr/bin/asr_inference.py b/funasr/bin/asr_inference.py
index 6ee0ffe..bd5d7f4 100755
--- a/funasr/bin/asr_inference.py
+++ b/funasr/bin/asr_inference.py
@@ -12,6 +12,7 @@
 from typing import Sequence
 from typing import Tuple
 from typing import Union
+from typing import Dict
 
 import numpy as np
 import torch
@@ -38,7 +39,21 @@
 from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
+from funasr.utils import asr_utils, wav_utils, postprocess_utils
+from funasr.models.frontend.wav_frontend import WavFrontend
 
+from modelscope.utils.logger import get_logger
+
+logger = get_logger()
+
+header_colors = '\033[95m'
+end_colors = '\033[0m'
+
+global_asr_language: str = 'zh-cn'
+global_sample_rate: Union[int, Dict[Any, int]] = {
+    'audio_fs': 16000,
+    'model_fs': 16000
+}
 
 class Speech2Text:
     """Speech2Text class
@@ -72,6 +87,7 @@
             penalty: float = 0.0,
             nbest: int = 1,
             streaming: bool = False,
+            frontend_conf: dict = None,
             **kwargs,
     ):
         assert check_argument_types()
@@ -81,6 +97,9 @@
         asr_model, asr_train_args = ASRTask.build_model_from_file(
             asr_train_config, asr_model_file, device
         )
+        if asr_model.frontend is None and frontend_conf is not None:
+            frontend = WavFrontend(**frontend_conf)
+            asr_model.frontend = frontend
         logging.info("asr_model: {}".format(asr_model))
         logging.info("asr_train_args: {}".format(asr_train_args))
         asr_model.to(dtype=getattr(torch, dtype)).eval()
@@ -129,36 +148,6 @@
             pre_beam_score_key=None if ctc_weight == 1.0 else "full",
         )
 
-        # TODO(karita): make all scorers batchfied
-        if batch_size == 1:
-            non_batch = [
-                k
-                for k, v in beam_search.full_scorers.items()
-                if not isinstance(v, BatchScorerInterface)
-            ]
-            if len(non_batch) == 0:
-                if streaming:
-                    beam_search.__class__ = BatchBeamSearchOnlineSim
-                    beam_search.set_streaming_config(asr_train_config)
-                    logging.info(
-                        "BatchBeamSearchOnlineSim implementation is selected."
-                    )
-                else:
-                    beam_search.__class__ = BatchBeamSearch
-                    logging.info("BatchBeamSearch implementation is selected.")
-            else:
-                logging.warning(
-                    f"As non-batch scorers {non_batch} are found, "
-                    f"fall back to non-batch implementation."
-                )
-
-            beam_search.to(device=device, dtype=getattr(torch, dtype)).eval()
-            for scorer in scorers.values():
-                if isinstance(scorer, torch.nn.Module):
-                    scorer.to(device=device, dtype=getattr(torch, dtype)).eval()
-            logging.info(f"Beam_search: {beam_search}")
-            logging.info(f"Decoding device={device}, dtype={dtype}")
-
         # 5. [Optional] Build Text converter: e.g. bpe-sym -> Text
         if token_type is None:
             token_type = asr_train_args.token_type
@@ -203,7 +192,7 @@
         """Inference
 
         Args:
-            data: Input speech data
+            speech: Input speech data
         Returns:
             text, token, token_int, hyp
 
@@ -216,6 +205,7 @@
 
         # data: (Nsamples,) -> (1, Nsamples)
         speech = speech.unsqueeze(0).to(getattr(torch, self.dtype))
+        lfr_factor = max(1, (speech.size()[-1] // 80) - 1)
         # lengths: (1,)
         lengths = speech.new_full([1], dtype=torch.long, fill_value=speech.size(1))
         batch = {"speech": speech, "speech_lengths": lengths}
@@ -264,32 +254,36 @@
 
 
 def inference(
-        output_dir: str,
         maxlenratio: float,
         minlenratio: float,
         batch_size: int,
-        dtype: str,
         beam_size: int,
         ngpu: int,
-        seed: int,
         ctc_weight: float,
         lm_weight: float,
-        ngram_weight: float,
         penalty: float,
-        nbest: int,
-        num_workers: int,
         log_level: Union[int, str],
-        data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
-        key_file: Optional[str],
+        data_path_and_name_and_type,
         asr_train_config: Optional[str],
         asr_model_file: Optional[str],
-        lm_train_config: Optional[str],
-        lm_file: Optional[str],
-        word_lm_train_config: Optional[str],
-        token_type: Optional[str],
-        bpemodel: Optional[str],
-        allow_variable_data_keys: bool,
-        streaming: bool,
+        audio_lists: Union[List[Any], bytes] = None,
+        lm_train_config: Optional[str] = None,
+        lm_file: Optional[str] = None,
+        token_type: Optional[str] = None,
+        key_file: Optional[str] = None,
+        word_lm_train_config: Optional[str] = None,
+        bpemodel: Optional[str] = None,
+        allow_variable_data_keys: bool = False,
+        streaming: bool = False,
+        output_dir: Optional[str] = None,
+        dtype: str = "float32",
+        seed: int = 0,
+        ngram_weight: float = 0.9,
+        nbest: int = 1,
+        num_workers: int = 1,
+        frontend_conf: dict = None,
+        fs: Union[dict, int] = 16000,
+        lang: Optional[str] = None,
         **kwargs,
 ):
     assert check_argument_types()
@@ -309,7 +303,46 @@
         device = "cuda"
     else:
         device = "cpu"
+    hop_length: int = 160
+    sr: int = 16000
+    if isinstance(fs, int):
+        sr = fs
+    else:
+        if 'model_fs' in fs and fs['model_fs'] is not None:
+            sr = fs['model_fs']
+    # data_path_and_name_and_type for modelscope: (data from audio_lists)
+    # ['speech', 'sound', 'am.mvn']
+    # data_path_and_name_and_type for funasr:
+    # [('/mnt/data/jiangyu.xzy/exp/maas/mvn.1.scp', 'speech', 'kaldi_ark')]
+    if isinstance(data_path_and_name_and_type[0], Tuple):
+        features_type: str = data_path_and_name_and_type[0][1]
+    elif isinstance(data_path_and_name_and_type[0], str):
+        features_type: str = data_path_and_name_and_type[1]
+    else:
+        raise NotImplementedError("unknown features type:{0}".format(data_path_and_name_and_type))
+    if features_type != 'sound':
+        frontend_conf = None
+        flag_modelscope = False
+    else:
+        flag_modelscope = True
+    if frontend_conf is not None:
+        if 'hop_length' in frontend_conf:
+            hop_length = frontend_conf['hop_length']
 
+    finish_count = 0
+    file_count = 1
+    if flag_modelscope and not isinstance(data_path_and_name_and_type[0], Tuple):
+        data_path_and_name_and_type_new = [
+            audio_lists, data_path_and_name_and_type[0], data_path_and_name_and_type[1]
+        ]
+        if isinstance(audio_lists, bytes):
+            file_count = 1
+        else:
+            file_count = len(audio_lists)
+        if len(data_path_and_name_and_type) >= 3 and frontend_conf is not None:
+            mvn_file = data_path_and_name_and_type[2]
+            mvn_data = wav_utils.extract_CMVN_featrures(mvn_file)
+            frontend_conf['mvn_data'] = mvn_data
     # 1. Set random-seed
     set_all_random_seed(seed)
 
@@ -332,45 +365,66 @@
         penalty=penalty,
         nbest=nbest,
         streaming=streaming,
+        frontend_conf=frontend_conf,
     )
     logging.info("speech2text_kwargs: {}".format(speech2text_kwargs))
     speech2text = Speech2Text(**speech2text_kwargs)
 
     # 3. Build data-iterator
-    loader = ASRTask.build_streaming_iterator(
-        data_path_and_name_and_type,
-        dtype=dtype,
-        batch_size=batch_size,
-        key_file=key_file,
-        num_workers=num_workers,
-        preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
-        collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
-        allow_variable_data_keys=allow_variable_data_keys,
-        inference=True,
-    )
+    if flag_modelscope:
+        loader = ASRTask.build_streaming_iterator_modelscope(
+            data_path_and_name_and_type_new,
+            dtype=dtype,
+            batch_size=batch_size,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+            sample_rate=fs
+        )
+    else:
+        loader = ASRTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            batch_size=batch_size,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=ASRTask.build_preprocess_fn(speech2text.asr_train_args, False),
+            collate_fn=ASRTask.build_collate_fn(speech2text.asr_train_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
 
     # 7 .Start for-loop
     # FIXME(kamo): The output format should be discussed about
-    with DatadirWriter(output_dir) as writer:
-        for keys, batch in loader:
-            assert isinstance(batch, dict), type(batch)
-            assert all(isinstance(s, str) for s in keys), keys
-            _bs = len(next(iter(batch.values())))
-            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
-            batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+    asr_result_list = []
+    if output_dir is not None:
+        writer = DatadirWriter(output_dir)
+    else:
+        writer = None
 
-            # N-best list of (text, token, token_int, hyp_object)
-            try:
-                results = speech2text(**batch)
-            except TooShortUttError as e:
-                logging.warning(f"Utterance {keys} {e}")
-                hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
-                results = [[" ", ["<space>"], [2], hyp]] * nbest
+    for keys, batch in loader:
+        assert isinstance(batch, dict), type(batch)
+        assert all(isinstance(s, str) for s in keys), keys
+        _bs = len(next(iter(batch.values())))
+        assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+        batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
 
-            # Only supporting batch_size==1
-            key = keys[0]
-            for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
-                # Create a directory: outdir/{n}best_recog
+        # N-best list of (text, token, token_int, hyp_object)
+        try:
+            results = speech2text(**batch)
+        except TooShortUttError as e:
+            logging.warning(f"Utterance {keys} {e}")
+            hyp = Hypothesis(score=0.0, scores={}, states={}, yseq=[])
+            results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+        # Only supporting batch_size==1
+        key = keys[0]
+        for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+            # Create a directory: outdir/{n}best_recog
+            if writer is not None:
                 ibest_writer = writer[f"{n}best_recog"]
 
                 # Write the result to each file
@@ -378,8 +432,25 @@
                 ibest_writer["token_int"][key] = " ".join(map(str, token_int))
                 ibest_writer["score"][key] = str(hyp.score)
 
-                if text is not None:
+            if text is not None:
+                text_postprocessed = postprocess_utils.sentence_postprocess(token)
+                item = {'key': key, 'value': text_postprocessed}
+                asr_result_list.append(item)
+                finish_count += 1
+                asr_utils.print_progress(finish_count / file_count)
+                if writer is not None:
                     ibest_writer["text"][key] = text
+    return asr_result_list
+
+
+def set_parameters(language: str = None,
+                   sample_rate: Union[int, Dict[Any, int]] = None):
+    if language is not None:
+        global global_asr_language
+        global_asr_language = language
+    if sample_rate is not None:
+        global global_sample_rate
+        global_sample_rate = sample_rate
 
 
 def get_parser():
@@ -432,6 +503,8 @@
         required=True,
         action="append",
     )
+    group.add_argument("--audio_lists", type=list, default=None)
+    # example=[{'key':'EdevDEWdIYQ_0021','file':'/mnt/data/jiangyu.xzy/test_data/speech_io/SPEECHIO_ASR_ZH00007_zhibodaihuo/wav/EdevDEWdIYQ_0021.wav'}])
     group.add_argument("--key_file", type=str_or_none)
     group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
 

--
Gitblit v1.9.1