| | |
| | | import json |
| | | import os |
| | | from omegaconf import OmegaConf |
| | | import torch |
| | |
| | | model_or_path = get_or_download_model_dir(model_or_path, model_revision, is_training=kwargs.get("is_training")) |
| | | |
| | | config = os.path.join(model_or_path, "config.yaml") |
| | | assert os.path.exists(config), "{} is not exist!".format(config) |
| | | cfg = OmegaConf.load(config) |
| | | kwargs = OmegaConf.merge(cfg, kwargs) |
| | | init_param = os.path.join(model_or_path, "model.pb") |
| | | kwargs["init_param"] = init_param |
| | | if os.path.exists(os.path.join(model_or_path, "tokens.txt")): |
| | | kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt") |
| | | if os.path.exists(os.path.join(model_or_path, "tokens.json")): |
| | | kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json") |
| | | if os.path.exists(os.path.join(model_or_path, "seg_dict")): |
| | | kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict") |
| | | if os.path.exists(os.path.join(model_or_path, "bpe.model")): |
| | | kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model") |
| | | kwargs["model"] = cfg["model"] |
| | | if os.path.exists(os.path.join(model_or_path, "am.mvn")): |
| | | kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn") |
| | | |
| | | if os.path.exists(config) and os.path.exists(os.path.join(model_or_path, "model.pb")): |
| | | # config = os.path.join(model_or_path, "config.yaml") |
| | | # assert os.path.exists(config), "{} is not exist!".format(config) |
| | | cfg = OmegaConf.load(config) |
| | | kwargs = OmegaConf.merge(cfg, kwargs) |
| | | init_param = os.path.join(model_or_path, "model.pb") |
| | | kwargs["init_param"] = init_param |
| | | if os.path.exists(os.path.join(model_or_path, "tokens.txt")): |
| | | kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt") |
| | | if os.path.exists(os.path.join(model_or_path, "tokens.json")): |
| | | kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.json") |
| | | if os.path.exists(os.path.join(model_or_path, "seg_dict")): |
| | | kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict") |
| | | if os.path.exists(os.path.join(model_or_path, "bpe.model")): |
| | | kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model") |
| | | kwargs["model"] = cfg["model"] |
| | | if os.path.exists(os.path.join(model_or_path, "am.mvn")): |
| | | kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn") |
| | | else:# configuration.json |
| | | assert os.path.exists(os.path.join(model_or_path, "configuration.json")) |
| | | with open(os.path.join(model_or_path, "configuration.json"), 'r', encoding='utf-8') as f: |
| | | conf_json = json.load(f) |
| | | config = os.path.join(model_or_path, conf_json["model"]["model_config"]) |
| | | cfg = OmegaConf.load(config) |
| | | kwargs = OmegaConf.merge(cfg, kwargs) |
| | | init_param = os.path.join(model_or_path, conf_json["model"]["model_name"]) |
| | | kwargs["init_param"] = init_param |
| | | kwargs["model"] = cfg["model"] |
| | | return OmegaConf.to_container(kwargs, resolve=True) |
| | | |
| | | def get_or_download_model_dir( |
| | |
| | | if os.path.exists(model): |
| | | model_cache_dir = model if os.path.isdir( |
| | | model) else os.path.dirname(model) |
| | | check_local_model_is_latest( |
| | | model_cache_dir, |
| | | user_agent={ |
| | | Invoke.KEY: key, |
| | | ThirdParty.KEY: "funasr" |
| | | }) |
| | | try: |
| | | check_local_model_is_latest( |
| | | model_cache_dir, |
| | | user_agent={ |
| | | Invoke.KEY: key, |
| | | ThirdParty.KEY: "funasr" |
| | | }) |
| | | except: |
| | | print("could not check the latest version") |
| | | else: |
| | | model_cache_dir = snapshot_download( |
| | | model, |