jmwang66
2022-12-09 0b8348376a20a6888d116982e346ada5fa5d15ab
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
from pathlib import Path
 
import torch
import yaml
 
 
class NoAliasSafeDumper(yaml.SafeDumper):
    # Disable anchor/alias in yaml because looks ugly
    def ignore_aliases(self, data):
        return True
 
 
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
    """Safe-dump in yaml with no anchor/alias"""
    return yaml.dump(
        data, stream, allow_unicode=True, Dumper=NoAliasSafeDumper, **kwargs
    )
 
 
def gen_conf(file, out_dir):
    conf = torch.load(file)["config"]
    conf["oss_bucket"] = "null"
    print(conf)
    output_dir = Path(out_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with (output_dir / "config.yaml").open("w", encoding="utf-8") as f:
        yaml_no_alias_safe_dump(conf, f, indent=4, sort_keys=False)
 
 
if __name__ == "__main__":
    import sys
 
    in_f = sys.argv[1]
    out_f = sys.argv[2]
    gen_conf(in_f, out_f)