From 28ccfbfc51068a663a80764e14074df5edf2b5ba Mon Sep 17 00:00:00 2001
From: kongdeqiang <kongdeqiang960204@163.com>
Date: Fri, 13 Mar 2026 17:41:41 +0800
Subject: [PATCH] 提交

---
 funasr/__init__.py |  144 +++++++++---------------------------------------
 1 files changed, 27 insertions(+), 117 deletions(-)

diff --git a/funasr/__init__.py b/funasr/__init__.py
index 1f31505..8fa29d0 100644
--- a/funasr/__init__.py
+++ b/funasr/__init__.py
@@ -1,9 +1,6 @@
 """Initialize funasr package."""
 
 import os
-from pathlib import Path
-import torch
-import numpy as np
 
 dirname = os.path.dirname(__file__)
 version_file = os.path.join(dirname, "version.txt")
@@ -11,122 +8,35 @@
     __version__ = f.read().strip()
 
 
-def prepare_model(
-    model: str = None,
-    # mode: str = None,
-    vad_model: str = None,
-    punc_model: str = None,
-    model_hub: str = "ms",
-    cache_dir: str = None,
-    **kwargs,
-):
-    if not Path(model).exists():
-        if model_hub == "ms" or model_hub == "modelscope":
-            try:
-                from modelscope.hub.snapshot_download import snapshot_download as download_tool
-                model = name_maps_ms[model] if model is not None else None
-                vad_model = name_maps_ms[vad_model] if vad_model is not None else None
-                punc_model = name_maps_ms[punc_model] if punc_model is not None else None
-            except:
-                raise "You are exporting model from modelscope, please install modelscope and try it again. To install modelscope, you could:\n" \
-                      "\npip3 install -U modelscope\n" \
-                      "For the users in China, you could install with the command:\n" \
-                      "\npip3 install -U modelscope -i https://mirror.sjtu.edu.cn/pypi/web/simple"
-        elif model_hub == "hf" or model_hub == "huggingface":
-            download_tool = 0
-        else:
-            raise "model_hub must be on of ms or hf, but get {}".format(model_hub)
+import importlib
+import pkgutil
+
+
+def import_submodules(package, recursive=True):
+    if isinstance(package, str):
         try:
-            model = download_tool(model, cache_dir=cache_dir, revision=kwargs.get("revision", None))
-            print("model have been downloaded to: {}".format(model))
-        except:
-            raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(
-                model)
-        
-        if vad_model is not None and not Path(vad_model).exists():
-            vad_model = download_tool(vad_model, cache_dir=cache_dir)
-            print("model have been downloaded to: {}".format(vad_model))
-        if punc_model is not None and not Path(punc_model).exists():
-            punc_model = download_tool(punc_model, cache_dir=cache_dir)
-            print("model have been downloaded to: {}".format(punc_model))
-        
-        # asr
-        kwargs.update({"cmvn_file": None if model is None else os.path.join(model, "am.mvn"),
-                       "asr_model_file": None if model is None else os.path.join(model, "model.pb"),
-                       "asr_train_config": None if model is None else os.path.join(model, "config.yaml"),
-                       })
-        mode = kwargs.get("mode", None)
-        if mode is None:
-            import json
-            json_file = os.path.join(model, 'configuration.json')
-            with open(json_file, 'r') as f:
-                config_data = json.load(f)
-                if config_data['task'] == "punctuation":
-                    mode = config_data['model']['punc_model_config']['mode']
-                else:
-                    mode = config_data['model']['model_config']['mode']
-        if vad_model is not None and "vad" not in mode:
-            mode = "paraformer_vad"
-        kwargs["mode"] = mode
-        # vad
-        kwargs.update({"vad_cmvn_file": None if vad_model is None else os.path.join(vad_model, "vad.mvn"),
-                       "vad_model_file": None if vad_model is None else os.path.join(vad_model, "vad.pb"),
-                       "vad_infer_config": None if vad_model is None else os.path.join(vad_model, "vad.yaml"),
-                       })
-        # punc
-        kwargs.update({
-            "punc_model_file": None if punc_model is None else os.path.join(punc_model, "punc.pb"),
-            "punc_infer_config": None if punc_model is None else os.path.join(punc_model, "punc.yaml"),
-        })
-        
-        
-        return model, vad_model, punc_model, kwargs
+            package = importlib.import_module(package)
+        except Exception as e:
+            # To see the details of the import error, uncomment the line below.
+            # print(f"Failed to import {package}: {e}")
+            pass
+    results = {}
+    if not isinstance(package, str):
+        for loader, name, is_pkg in pkgutil.walk_packages(package.__path__, package.__name__ + "."):
+            try:
+                results[name] = importlib.import_module(name)
+            except Exception as e:
+                # To see the details of the import error, uncomment the line below.
+                # print(f"Failed to import {name}: {e}")
+                pass
+            if recursive and is_pkg:
+                results.update(import_submodules(name))
+    return results
 
-name_maps_ms = {
-    "paraformer-zh": "damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
-    "paraformer-zh-spk": "damo/speech_paraformer-large-vad-punc-spk_asr_nat-zh-cn",
-    "paraformer-en": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
-    "paraformer-en-spk": "damo/speech_paraformer-large-vad-punc_asr_nat-en-16k-common-vocab10020",
-    "paraformer-zh-streaming": "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online",
-    "fsmn-vad": "damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-    "ct-punc": "damo/punc_ct-transformer_cn-en-common-vocab471067-large",
-    "fa-zh": "damo/speech_timestamp_prediction-v1-16k-offline",
-}
 
-def infer(task_name: str = "asr",
-            model: str = None,
-            # mode: str = None,
-            vad_model: str = None,
-            punc_model: str = None,
-            model_hub: str = "ms",
-            cache_dir: str = None,
-            **kwargs,
-          ):
+import_submodules(__name__)
 
-    model, vad_model, punc_model, kwargs = prepare_model(model, vad_model, punc_model, model_hub, cache_dir, **kwargs)
-    if task_name == "asr":
-        from funasr.bin.asr_inference_launch import inference_launch
+from funasr.auto.auto_model import AutoModel
+from funasr.auto.auto_frontend import AutoFrontend
 
-        inference_pipeline = inference_launch(**kwargs)
-    elif task_name == "":
-        pipeline = 1
-    elif task_name == "":
-        pipeline = 2
-    elif task_name == "":
-        pipeline = 2
-    
-    def _infer_fn(input, **kwargs):
-        data_type = kwargs.get('data_type', 'sound')
-        data_path_and_name_and_type = [input, 'speech', data_type]
-        raw_inputs = None
-        if isinstance(input, torch.Tensor):
-            input = input.numpy()
-        if isinstance(input, np.ndarray):
-            data_path_and_name_and_type = None
-            raw_inputs = input
-            
-
-        
-        return inference_pipeline(data_path_and_name_and_type, raw_inputs=raw_inputs, **kwargs)
-    
-    return _infer_fn
\ No newline at end of file
+os.environ["HYDRA_FULL_ERROR"] = "1"

--
Gitblit v1.9.1