From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/bin/vad_inference_launch.py |   94 +++++++++++++++++-----------------------------
 1 files changed, 35 insertions(+), 59 deletions(-)

diff --git a/funasr/bin/vad_inference_launch.py b/funasr/bin/vad_inference_launch.py
index 2ccc716..0cde570 100644
--- a/funasr/bin/vad_inference_launch.py
+++ b/funasr/bin/vad_inference_launch.py
@@ -1,56 +1,33 @@
 #!/usr/bin/env python3
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
 
 import torch
+
 torch.set_num_threads(1)
 
 import argparse
 import logging
 import os
 import sys
-from typing import Union, Dict, Any
-
-from funasr.utils import config_argparse
-from funasr.utils.cli_utils import get_commandline_args
-from funasr.utils.types import str2bool
-from funasr.utils.types import str2triple_str
-from funasr.utils.types import str_or_none
-
-import argparse
-import logging
-import os
-import sys
 import json
-from pathlib import Path
-from typing import Any
-from typing import List
 from typing import Optional
-from typing import Sequence
-from typing import Tuple
 from typing import Union
-from typing import Dict
 
-import math
 import numpy as np
 import torch
-from typeguard import check_argument_types
-from typeguard import check_return_type
-
+from funasr.build_utils.build_streaming_iterator import build_streaming_iterator
 from funasr.fileio.datadir_writer import DatadirWriter
-from funasr.modules.scorers.scorer_interface import BatchScorerInterface
-from funasr.modules.subsampling import TooShortUttError
-from funasr.tasks.vad import VADTask
-from funasr.torch_utils.device_funcs import to_device
 from funasr.torch_utils.set_all_random_seed import set_all_random_seed
 from funasr.utils import config_argparse
 from funasr.utils.cli_utils import get_commandline_args
 from funasr.utils.types import str2bool
 from funasr.utils.types import str2triple_str
 from funasr.utils.types import str_or_none
-from funasr.utils import asr_utils, wav_utils, postprocess_utils
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
 from funasr.bin.vad_infer import Speech2VadSegment, Speech2VadSegmentOnline
+
 
 def inference_vad(
         batch_size: int,
@@ -69,10 +46,8 @@
         num_workers: int = 1,
         **kwargs,
 ):
-    assert check_argument_types()
     if batch_size > 1:
         raise NotImplementedError("batch decoding is not implemented")
-
 
     logging.basicConfig(
         level=log_level,
@@ -110,16 +85,14 @@
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = VADTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="vad",
+            preprocess_args=None,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
-            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
 
         finish_count = 0
@@ -150,10 +123,11 @@
                 vad_results.append(item)
                 if writer is not None:
                     ibest_writer["text"][keys[i]] = "{}".format(results[i])
-
+        torch.cuda.empty_cache()
         return vad_results
 
     return _forward
+
 
 def inference_vad_online(
         batch_size: int,
@@ -172,8 +146,6 @@
         num_workers: int = 1,
         **kwargs,
 ):
-    assert check_argument_types()
-
 
     logging.basicConfig(
         level=log_level,
@@ -212,16 +184,14 @@
             if isinstance(raw_inputs, torch.Tensor):
                 raw_inputs = raw_inputs.numpy()
             data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
-        loader = VADTask.build_streaming_iterator(
-            data_path_and_name_and_type,
+        loader = build_streaming_iterator(
+            task_name="vad",
+            preprocess_args=None,
+            data_path_and_name_and_type=data_path_and_name_and_type,
             dtype=dtype,
             batch_size=batch_size,
             key_file=key_file,
             num_workers=num_workers,
-            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
-            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
-            allow_variable_data_keys=allow_variable_data_keys,
-            inference=True,
         )
 
         finish_count = 0
@@ -237,9 +207,13 @@
             ibest_writer = None
 
         vad_results = []
-        batch_in_cache = param_dict['in_cache'] if param_dict is not None else dict()
-        is_final = param_dict.get('is_final', False) if param_dict is not None else False
-        max_end_sil = param_dict.get('max_end_sil', 800) if param_dict is not None else 800
+        if param_dict is None:
+            param_dict = dict()
+            param_dict['in_cache'] = dict()
+            param_dict['is_final'] = True
+        batch_in_cache = param_dict.get('in_cache', dict())
+        is_final = param_dict.get('is_final', False)
+        max_end_sil = param_dict.get('max_end_sil', 800)
         for keys, batch in loader:
             assert isinstance(batch, dict), type(batch)
             assert all(isinstance(s, str) for s in keys), keys
@@ -265,6 +239,16 @@
         return vad_results
 
     return _forward
+
+
+def inference_launch(mode, **kwargs):
+    if mode == "offline":
+        return inference_vad(**kwargs)
+    elif mode == "online":
+        return inference_vad_online(**kwargs)
+    else:
+        logging.info("Unknown decoding mode: {}".format(mode))
+        return None
 
 
 def get_parser():
@@ -358,15 +342,6 @@
     return parser
 
 
-def inference_launch(mode, **kwargs):
-    if mode == "offline":
-        return inference_vad(**kwargs)
-    elif mode == "online":
-        return inference_vad_online(**kwargs)
-    else:
-        logging.info("Unknown decoding mode: {}".format(mode))
-        return None
-
 def main(cmd=None):
     print(get_commandline_args(), file=sys.stderr)
     parser = get_parser()
@@ -397,5 +372,6 @@
     inference_pipeline = inference_launch(**kwargs)
     return inference_pipeline(kwargs["data_path_and_name_and_type"])
 
+
 if __name__ == "__main__":
     main()

--
Gitblit v1.9.1