1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
| import os
| import shutil
|
| from modelscope.hub.snapshot_download import snapshot_download
|
|
def _install_checkpoint(model, checkpoint_dir, checkpoint_name):
    """Copy a local checkpoint and its config into a ModelScope model directory.

    Args:
        model: Model identifier passed to ``snapshot_download``.
        checkpoint_dir: Directory that holds ``checkpoint_name`` and
            ``config.yaml``; it is also passed to ``snapshot_download`` as
            ``cache_dir``.
        checkpoint_name: File name of the checkpoint inside ``checkpoint_dir``.

    Returns:
        The path returned by ``snapshot_download`` (used here as the
        destination directory).

    Raises:
        RuntimeError: If ``snapshot_download`` fails; the original error is
            chained as ``__cause__``.
        OSError: If either source file cannot be copied.
    """
    try:
        pretrained_model_path = snapshot_download(model, cache_dir=checkpoint_dir)
    except Exception as err:
        # Catch Exception (not BaseException) so Ctrl-C / SystemExit are not
        # misreported as a download failure, and chain the real cause.
        raise RuntimeError(
            "Please download pretrain model from ModelScope firstly."
        ) from err

    # The checkpoint is stored under the fixed name "model.pb" in the model
    # directory; the config keeps its own file name.
    shutil.copy(
        os.path.join(checkpoint_dir, checkpoint_name),
        os.path.join(pretrained_model_path, "model.pb"),
    )
    shutil.copy(
        os.path.join(checkpoint_dir, "config.yaml"),
        os.path.join(pretrained_model_path, "config.yaml"),
    )
    return pretrained_model_path


if __name__ == '__main__':
    import sys

    # Fail with a usage message instead of an IndexError when arguments are
    # missing. Extra trailing arguments are still tolerated, as before.
    if len(sys.argv) < 4:
        sys.exit(f"Usage: {sys.argv[0]} <model> <checkpoint_dir> <checkpoint_name>")

    model = sys.argv[1]
    checkpoint_dir = sys.argv[2]
    checkpoint_name = sys.argv[3]

    _install_checkpoint(model, checkpoint_dir, checkpoint_name)
|
|