zhifu gao
2024-03-05 753d579531e102e0c05358883af5d5ace02004e1
funasr/download/download_from_hub.py
@@ -13,10 +13,16 @@
        pass
    elif hub == "openai":
        model_or_path = kwargs.get("model")
        if model_or_path in name_maps_openai:
            model_or_path = name_maps_openai[model_or_path]
        kwargs["model_path"] = model_or_path
        if os.path.exists(model_or_path):
            # local path
            kwargs["model_path"] = model_or_path
            kwargs["model"] = "WhisperWarp"
        else:
            # model name
            if model_or_path in name_maps_openai:
                model_or_path = name_maps_openai[model_or_path]
            kwargs["model_path"] = model_or_path
    return kwargs
def download_from_ms(**kwargs):
@@ -24,7 +30,7 @@
    if model_or_path in name_maps_ms:
        model_or_path = name_maps_ms[model_or_path]
    model_revision = kwargs.get("model_revision")
    if not os.path.exists(model_or_path):
    if not os.path.exists(model_or_path) and "model_path" not in kwargs:
        try:
            model_or_path = get_or_download_model_dir(model_or_path, model_revision,
                                                      is_training=kwargs.get("is_training"),
@@ -32,7 +38,7 @@
        except Exception as e:
            print(f"Download: {model_or_path} failed!: {e}")
    
    kwargs["model_path"] = model_or_path
    kwargs["model_path"] = model_or_path if "model_path" not in kwargs else kwargs["model_path"]
    
    if os.path.exists(os.path.join(model_or_path, "configuration.json")):
        with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f: