| | |
| | | return kwargs |
| | | |
def download_fr_ms(**kwargs):
    """Resolve the model directory and point ``kwargs`` at the files in it.

    Reads ``kwargs["model"]`` (a local path or a model id) and
    ``kwargs["model_revision"]``. When the value is not an existing local
    path it is resolved through ``get_or_download_model_dir``, forwarding
    ``kwargs.get("is_training")``.

    Returns:
        ``kwargs`` with ``init_param``, ``token_list``, ``model`` and
        ``frontend_conf["cmvn_file"]`` filled in.

    Raises:
        AssertionError: if ``config.yaml`` is missing from the model
            directory.
    """
    model_or_path = kwargs.get("model")
    model_revision = kwargs.get("model_revision")
    if not os.path.exists(model_or_path):
        model_or_path = get_or_download_model_dir(
            model_or_path,
            model_revision,
            is_training=kwargs.get("is_training"),
        )

    config = os.path.join(model_or_path, "config.yaml")
    assert os.path.exists(config), "{} is not exist!".format(config)

    # NOTE(review): `cfg` is not assigned anywhere in the visible block. The
    # step that loads `config` into `cfg` (and presumably provides
    # kwargs["frontend_conf"]) appears to have been elided at this point;
    # confirm against the full file before relying on this function.
    kwargs["init_param"] = os.path.join(model_or_path, "model.pb")
    kwargs["token_list"] = os.path.join(model_or_path, "tokens.txt")
    kwargs["model"] = cfg["model"]
    kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")

    return kwargs
| | | |
def get_or_download_model_dir(
    model,
    model_revision=None,
    is_training=False,
    third_party=None,
):
    """Get local model directory or download model if necessary.

    Args:
        model (str): model id or path to local model directory.
        model_revision (str, optional): model version number; only passed
            on when the model has to be downloaded.
        is_training (bool, optional): selects the invoke key reported in
            the user agent: ``Invoke.LOCAL_TRAINER`` when truthy,
            ``Invoke.PIPELINE`` otherwise.
        third_party (str, optional): in which third party library this
            function is called; reported in the user agent. Falls back to
            ``"funasr"`` when omitted.

    Returns:
        The local model directory.
    """
    from modelscope.hub.check_model import check_local_model_is_latest
    from modelscope.hub.snapshot_download import snapshot_download
    from modelscope.utils.constant import Invoke, ThirdParty

    # Built once so the local-check and download paths report the same agent.
    user_agent = {
        Invoke.KEY: Invoke.LOCAL_TRAINER if is_training else Invoke.PIPELINE,
        ThirdParty.KEY: "funasr" if third_party is None else third_party,
    }

    if os.path.exists(model):
        # NOTE(review): the tail of this conditional was truncated in the
        # source; the parent-directory fallback for a non-directory path is
        # reconstructed -- confirm against the full file.
        if os.path.isdir(model):
            model_cache_dir = model
        else:
            model_cache_dir = os.path.dirname(model)
        check_local_model_is_latest(model_cache_dir, user_agent=user_agent)
    else:
        model_cache_dir = snapshot_download(
            model,
            revision=model_revision,
            user_agent=user_agent,
        )
    return model_cache_dir