From 2acef4bdaea588adee3098a057a395937dff4e6a Mon Sep 17 00:00:00 2001
From: 雾聪 <wucong.lyb@alibaba-inc.com>
Date: Mon, 08 Jan 2024 16:51:42 +0800
Subject: [PATCH] json stamp_sents for websocket-server

---
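Notes: this patch replaces the hard-coded CAMPPlus/campplus_cn_common.bin
speaker-verification setup with an optional per-model sv_model_config.yaml,
moves sv embedding extraction to GPU when ngpu > 0, returns a single default
speaker for utterances of 5 s or shorter instead of running clustering,
switches the torchaudio fallback from soundfile to librosa, and moves
get_parser() out to funasr.bin.argument.

As a sketch, an sv_model_config.yaml placed next to model.pb and matching
the in-code defaults would look like this (the file is optional; the values
below are assumed when it is missing):

    sv_model_class: CAMPPlus
    sv_model_file: campplus_cn_common.bin
    models_config: {}

models_config is forwarded as keyword arguments to the model class looked
up in funasr.modules.cnn, so model-specific constructor options belong
under that key.
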
 funasr/bin/asr_inference_launch.py |  337 +++++++++----------------------------------------------
 1 file changed, 56 insertions(+), 281 deletions(-)

diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index e93d740..f34bfb2 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -20,7 +20,8 @@
 import numpy as np
 import torch
 import torchaudio
-import soundfile
+# librosa replaces soundfile for the audio-decoding fallback below
+import librosa
 import yaml
 
 from funasr.bin.asr_infer import Speech2Text
@@ -47,13 +48,13 @@
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
 from funasr.utils.vad_utils import slice_padding_fbank
-from funasr.utils.speaker_utils import (check_audio_list, 
-                                        sv_preprocess, 
-                                        sv_chunk, 
-                                        CAMPPlus, 
-                                        extract_feature, 
+from funasr.utils.speaker_utils import (check_audio_list,
+                                        sv_preprocess,
+                                        sv_chunk,
+                                        extract_feature,
                                         postprocess,
                                         distribute_spk)
+import funasr.modules.cnn as sv_module
 from funasr.build_utils.build_model_from_file import build_model_from_file
 from funasr.utils.cluster_backend import ClusterBackend
 from funasr.utils.modelscope_utils import get_cache_dir
@@ -675,11 +676,13 @@
                 beg_idx = end_idx
                 batch = {"speech": speech_j, "speech_lengths": speech_lengths_j}
                 batch = to_device(batch, device=device)
-                # print("batch: ", speech_j.shape[0])
+
                 beg_asr = time.time()
                 results = speech2text(**batch)
                 end_asr = time.time()
-                # print("time cost asr: ", end_asr - beg_asr)
+                if speech2text.device != "cpu":
+                    print("batch: ", speech_j.shape[0])
+                    print("time cost asr: ", end_asr - beg_asr)
 
                 if len(results) < 1:
                     results = [["", [], [], [], [], [], []]]
@@ -815,7 +818,15 @@
         format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
     )
 
-    sv_model_file = asr_model_file.replace("model.pb", "campplus_cn_common.bin")
+    sv_model_config_path = asr_model_file.replace("model.pb", "sv_model_config.yaml")
+    if not os.path.exists(sv_model_config_path):
+        sv_model_config = {'sv_model_class': 'CAMPPlus', 'sv_model_file': 'campplus_cn_common.bin', 'models_config': {}}
+    else:
+        with open(sv_model_config_path, 'r') as f:
+            sv_model_config = yaml.load(f, Loader=yaml.FullLoader)
+    if sv_model_config.get('models_config') is None:
+        sv_model_config['models_config'] = {}
+    sv_model_file = asr_model_file.replace("model.pb", sv_model_config['sv_model_file'])
 
     if param_dict is not None:
         hotword_list_or_file = param_dict.get('hotword')
@@ -941,9 +952,15 @@
             #####  speaker_verification  #####
             ##################################
             # load sv model
-            sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
-            sv_model = CAMPPlus()
+            # instantiate the sv model class named in sv_model_config (looked up in funasr.modules.cnn)
+            sv_model = getattr(sv_module, sv_model_config['sv_model_class'])(**sv_model_config['models_config'])
+            if ngpu > 0:
+                sv_model_dict = torch.load(sv_model_file)
+                sv_model.cuda()
+            else:
+                sv_model_dict = torch.load(sv_model_file, map_location=torch.device('cpu'))
             sv_model.load_state_dict(sv_model_dict)
+            logging.info(f'loaded sv model params: {sv_model_file}')
             sv_model.eval()
             cb_model = ClusterBackend()
             vad_segments = []
@@ -953,24 +970,31 @@
                 ed = int(vadsegment[1]) / 1000
                 vad_segments.append(
                     [st, ed, audio[int(st * 16000):int(ed * 16000)]])
-            check_audio_list(vad_segments)
-            # sv pipeline
-            segments = sv_chunk(vad_segments)
-            embeddings = []
-            for s in segments:
-                #_, embs = self.sv_pipeline([s[2]], output_emb=True)
-                # embeddings.append(embs)
-                wavs = sv_preprocess([s[2]])
-                # embs = self.forward(wavs)
-                embs = []
-                for x in wavs:
-                    x = extract_feature([x])
-                    embs.append(sv_model(x))
-                embs = torch.cat(embs)
-                embeddings.append(embs.detach().numpy())
-            embeddings = np.concatenate(embeddings)
-            labels = cb_model(embeddings)
-            sv_output = postprocess(segments, vad_segments, labels, embeddings)
+            audio_dur = check_audio_list(vad_segments)  # total valid speech duration in seconds
+            if audio_dur > 5:  # run speaker clustering only when there is enough speech
+                # sv pipeline
+                segments = sv_chunk(vad_segments)
+                embeddings = []
+                for s in segments:
+                    #_, embs = self.sv_pipeline([s[2]], output_emb=True)
+                    # embeddings.append(embs)
+                    wavs = sv_preprocess([s[2]])
+                    # embs = self.forward(wavs)
+                    embs = []
+                    for x in wavs:
+                        x = extract_feature([x])
+                        if ngpu > 0:
+                            x = x.cuda()
+                        embs.append(sv_model(x))
+                    embs = torch.cat(embs)
+                    embeddings.append(embs.cpu().detach().numpy())
+                embeddings = np.concatenate(embeddings)
+                labels = cb_model(embeddings)
+                sv_output = postprocess(segments, vad_segments, labels, embeddings)
+            else:
+                # fall back to a single default speaker for audio too short to cluster
+                sv_output = [[0.0, vadsegments[-1][-1] / 1000.0, 0]]
+                logging.warning("Too short utterance found: {}, return default speaker results.".format(keys))
 
             speech, speech_lengths = batch["speech"], batch["speech_lengths"]
 
@@ -1279,7 +1303,8 @@
             try:
                 raw_inputs = torchaudio.load(data_path_and_name_and_type[0])[0][0]
             except:
-                raw_inputs = soundfile.read(data_path_and_name_and_type[0], dtype='float32')[0]
+                # sr=None keeps the file's native sample rate, matching the old soundfile.read behaviour
+                raw_inputs, sr = librosa.load(data_path_and_name_and_type[0], sr=None, dtype='float32')
                 if raw_inputs.ndim == 2:
                     raw_inputs = raw_inputs[:, 0]
                 raw_inputs = torch.tensor(raw_inputs)
@@ -2218,259 +2243,9 @@
         logging.info("Unknown decoding mode: {}".format(mode))
         return None
 
-
-def get_parser():
-    parser = config_argparse.ArgumentParser(
-        description="ASR Decoding",
-        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
-    )
-
-    # Note(kamo): Use '_' instead of '-' as separator.
-    # '-' is confusing if written in yaml.
-    parser.add_argument(
-        "--log_level",
-        type=lambda x: x.upper(),
-        default="INFO",
-        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
-        help="The verbose level of logging",
-    )
-
-    parser.add_argument("--output_dir", type=str, required=True)
-    parser.add_argument(
-        "--ngpu",
-        type=int,
-        default=0,
-        help="The number of gpus. 0 indicates CPU mode",
-    )
-    parser.add_argument(
-        "--njob",
-        type=int,
-        default=1,
-        help="The number of jobs for each gpu",
-    )
-    parser.add_argument(
-        "--gpuid_list",
-        type=str,
-        default="",
-        help="The visible gpus",
-    )
-    parser.add_argument("--seed", type=int, default=0, help="Random seed")
-    parser.add_argument(
-        "--dtype",
-        default="float32",
-        choices=["float16", "float32", "float64"],
-        help="Data type",
-    )
-    parser.add_argument(
-        "--num_workers",
-        type=int,
-        default=1,
-        help="The number of workers used for DataLoader",
-    )
-
-    group = parser.add_argument_group("Input data related")
-    group.add_argument(
-        "--data_path_and_name_and_type",
-        type=str2triple_str,
-        required=True,
-        action="append",
-    )
-    group.add_argument("--key_file", type=str_or_none)
-    parser.add_argument(
-        "--hotword",
-        type=str_or_none,
-        default=None,
-        help="hotword file path or hotwords seperated by space"
-    )
-    group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
-    group.add_argument(
-        "--mc",
-        type=bool,
-        default=False,
-        help="MultiChannel input",
-    )
-
-    group = parser.add_argument_group("The model configuration related")
-    group.add_argument(
-        "--vad_infer_config",
-        type=str,
-        help="VAD infer configuration",
-    )
-    group.add_argument(
-        "--vad_model_file",
-        type=str,
-        help="VAD model parameter file",
-    )
-    group.add_argument(
-        "--punc_infer_config",
-        type=str,
-        help="PUNC infer configuration",
-    )
-    group.add_argument(
-        "--punc_model_file",
-        type=str,
-        help="PUNC model parameter file",
-    )
-    group.add_argument(
-        "--cmvn_file",
-        type=str,
-        help="Global CMVN file",
-    )
-    group.add_argument(
-        "--asr_train_config",
-        type=str,
-        help="ASR training configuration",
-    )
-    group.add_argument(
-        "--asr_model_file",
-        type=str,
-        help="ASR model parameter file",
-    )
-    group.add_argument(
-        "--sv_model_file",
-        type=str,
-        help="SV model parameter file",
-    )
-    group.add_argument(
-        "--lm_train_config",
-        type=str,
-        help="LM training configuration",
-    )
-    group.add_argument(
-        "--lm_file",
-        type=str,
-        help="LM parameter file",
-    )
-    group.add_argument(
-        "--word_lm_train_config",
-        type=str,
-        help="Word LM training configuration",
-    )
-    group.add_argument(
-        "--word_lm_file",
-        type=str,
-        help="Word LM parameter file",
-    )
-    group.add_argument(
-        "--ngram_file",
-        type=str,
-        help="N-gram parameter file",
-    )
-    group.add_argument(
-        "--model_tag",
-        type=str,
-        help="Pretrained model tag. If specify this option, *_train_config and "
-             "*_file will be overwritten",
-    )
-    group.add_argument(
-        "--beam_search_config",
-        default={},
-        help="The keyword arguments for transducer beam search.",
-    )
-
-    group = parser.add_argument_group("Beam-search related")
-    group.add_argument(
-        "--batch_size",
-        type=int,
-        default=1,
-        help="The batch size for inference",
-    )
-    group.add_argument("--nbest", type=int, default=5, help="Output N-best hypotheses")
-    group.add_argument("--beam_size", type=int, default=20, help="Beam size")
-    group.add_argument("--penalty", type=float, default=0.0, help="Insertion penalty")
-    group.add_argument(
-        "--maxlenratio",
-        type=float,
-        default=0.0,
-        help="Input length ratio to obtain max output length. "
-             "If maxlenratio=0.0 (default), it uses a end-detect "
-             "function "
-             "to automatically find maximum hypothesis lengths."
-             "If maxlenratio<0.0, its absolute value is interpreted"
-             "as a constant max output length",
-    )
-    group.add_argument(
-        "--minlenratio",
-        type=float,
-        default=0.0,
-        help="Input length ratio to obtain min output length",
-    )
-    group.add_argument(
-        "--ctc_weight",
-        type=float,
-        default=0.0,
-        help="CTC weight in joint decoding",
-    )
-    group.add_argument("--lm_weight", type=float, default=1.0, help="RNNLM weight")
-    group.add_argument("--ngram_weight", type=float, default=0.9, help="ngram weight")
-    group.add_argument("--streaming", type=str2bool, default=False)
-    group.add_argument("--fake_streaming", type=str2bool, default=False)
-    group.add_argument("--full_utt", type=str2bool, default=False)
-    group.add_argument("--chunk_size", type=int, default=16)
-    group.add_argument("--left_context", type=int, default=16)
-    group.add_argument("--right_context", type=int, default=0)
-    group.add_argument(
-        "--display_partial_hypotheses",
-        type=bool,
-        default=False,
-        help="Whether to display partial hypotheses during chunk-by-chunk inference.",
-    )
-
-    group = parser.add_argument_group("Dynamic quantization related")
-    group.add_argument(
-        "--quantize_asr_model",
-        type=bool,
-        default=False,
-        help="Apply dynamic quantization to ASR model.",
-    )
-    group.add_argument(
-        "--quantize_modules",
-        nargs="*",
-        default=None,
-        help="""Module names to apply dynamic quantization on.
-        The module names are provided as a list, where each name is separated
-        by a comma (e.g.: --quantize-config=[Linear,LSTM,GRU]).
-        Each specified name should be an attribute of 'torch.nn', e.g.:
-        torch.nn.Linear, torch.nn.LSTM, torch.nn.GRU, ...""",
-    )
-    group.add_argument(
-        "--quantize_dtype",
-        type=str,
-        default="qint8",
-        choices=["float16", "qint8"],
-        help="Dtype for dynamic quantization.",
-    )
-
-    group = parser.add_argument_group("Text converter related")
-    group.add_argument(
-        "--token_type",
-        type=str_or_none,
-        default=None,
-        choices=["char", "bpe", None],
-        help="The token type for ASR model. "
-             "If not given, refers from the training args",
-    )
-    group.add_argument(
-        "--bpemodel",
-        type=str_or_none,
-        default=None,
-        help="The model path of sentencepiece. "
-             "If not given, refers from the training args",
-    )
-    group.add_argument("--token_num_relax", type=int, default=1, help="")
-    group.add_argument("--decoding_ind", type=int, default=0, help="")
-    group.add_argument("--decoding_mode", type=str, default="model1", help="")
-    group.add_argument(
-        "--ctc_weight2",
-        type=float,
-        default=0.0,
-        help="CTC weight in joint decoding",
-    )
-    return parser
-
-
 def main(cmd=None):
     print(get_commandline_args(), file=sys.stderr)
+    from funasr.bin.argument import get_parser
     parser = get_parser()
     parser.add_argument(
         "--mode",

--
Gitblit v1.9.1