游雁
2023-12-11 d77910eb6d171727f2350e45c31c91436c4c8891
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import os
from omegaconf import OmegaConf
import torch
from funasr.utils.name_maps_from_hub import name_maps_ms, name_maps_hf
 
def download_model(**kwargs):
    model_hub = kwargs.get("model_hub", "ms")
    if model_hub == "ms":
        kwargs = download_fr_ms(**kwargs)
    
    return kwargs
 
def download_fr_ms(**kwargs):
    model_or_path = kwargs.get("model_pretrain")
    model_revision = kwargs.get("model_pretrain_revision")
    if not os.path.exists(model_or_path):
        model_or_path = get_or_download_model_dir(model_or_path, model_revision, third_party="funasr")
    
    config = os.path.join(model_or_path, "config.yaml")
    assert os.path.exists(config), "{} is not exist!".format(config)
    cfg = OmegaConf.load(config)
    kwargs = OmegaConf.merge(cfg, kwargs)
    init_param = os.path.join(model_or_path, "model.pb")
    kwargs["init_param"] = init_param
    kwargs["token_list"] = os.path.join(model_or_path, "tokens.txt")
    
    return kwargs
 
def get_or_download_model_dir(
                              model,
                              model_revision=None,
                              third_party=None):
    """ Get local model directory or download model if necessary.
 
    Args:
        model (str): model id or path to local model directory.
        model_revision  (str, optional): model version number.
        third_party (str, optional): in which third party library
            this function is called.
    """
    from modelscope.hub.check_model import check_local_model_is_latest
    from modelscope.hub.snapshot_download import snapshot_download
 
    from modelscope.utils.constant import Invoke, ThirdParty
    
    if os.path.exists(model):
        model_cache_dir = model if os.path.isdir(
            model) else os.path.dirname(model)
        check_local_model_is_latest(
            model_cache_dir,
            user_agent={
                Invoke.KEY: Invoke.LOCAL_TRAINER,
                ThirdParty.KEY: third_party
            })
    else:
        model_cache_dir = snapshot_download(
            model,
            revision=model_revision,
            user_agent={
                Invoke.KEY: Invoke.TRAINER,
                ThirdParty.KEY: third_party
            })
    return model_cache_dir