| | |
| | | import os |
| | | import torch |
| | | from torch.nn import functional as F |
| | | |
| | | import yaml |
| | | import numpy as np |
| | | |
| | | def sequence_mask(lengths, maxlen=None, dtype=torch.float32, device=None): |
| | |
| | | outputs += inputs |
| | | return outputs |
| | | |
| | | |
def proc_tf_vocab(vocab_path):
    """Read a vocabulary file into a list of tokens.

    Each line of the UTF-8 file at ``vocab_path`` becomes one token, with
    trailing whitespace (including the newline) stripped.

    Args:
        vocab_path: Path to the vocabulary text file.

    Returns:
        The tokens in file order; ``'<unk>'`` is appended at the end when the
        file does not already contain it.
    """
    with open(vocab_path, encoding="utf-8") as vocab_file:
        tokens = list(map(str.rstrip, vocab_file))

    # Nothing to add when the vocabulary already carries the unknown token.
    if '<unk>' in tokens:
        return tokens

    tokens.append('<unk>')
    return tokens
| | | |
| | | |
def gen_config_for_tfmodel(config_path, vocab_path, output_dir):
    """Write a copy of a YAML config with the vocabulary embedded in it.

    The YAML file at ``config_path`` is loaded, its ``token_list`` entry is
    set (or overwritten) with the tokens returned by ``proc_tf_vocab`` for
    ``vocab_path``, and the result is saved as ``config.yaml`` inside
    ``output_dir``.

    Args:
        config_path: Path to the source YAML config (read as UTF-8).
        vocab_path: Path to the vocabulary file handed to ``proc_tf_vocab``.
        output_dir: Directory that receives ``config.yaml``; it is created,
            including missing parents, when it does not exist yet.
    """
    token_list = proc_tf_vocab(vocab_path)
    with open(config_path, encoding="utf-8") as f:
        config = yaml.safe_load(f)

    config['token_list'] = token_list

    # exist_ok=True replaces a separate exists() check: a directory created
    # by another process between the check and the call is no longer an error.
    os.makedirs(output_dir, exist_ok=True)

    with open(os.path.join(output_dir, "config.yaml"), "w", encoding="utf-8") as f:
        yaml_no_alias_safe_dump(config, f, indent=4, sort_keys=False)
| | | |
| | | |
class NoAliasSafeDumper(yaml.SafeDumper):
    """Safe YAML dumper with anchors/aliases disabled.

    Anchors and aliases are turned off because they look ugly in the
    generated file.
    """

    def ignore_aliases(self, data):
        """Report every object as alias-free, whatever ``data`` is."""
        return True
| | | |
| | | |
def yaml_no_alias_safe_dump(data, stream=None, **kwargs):
    """Dump ``data`` as YAML through ``NoAliasSafeDumper`` (no anchor/alias).

    Unicode output is always enabled. Any extra keyword arguments are
    forwarded unchanged to ``yaml.dump``, whose return value is passed back
    to the caller.
    """
    result = yaml.dump(
        data,
        stream,
        Dumper=NoAliasSafeDumper,
        allow_unicode=True,
        **kwargs,
    )
    return result
| | | |
| | | |
if __name__ == '__main__':
    import sys

    # Expected invocation: <script> <config_path> <vocab_path> <output_dir>.
    # Exit with a usage message instead of an IndexError traceback when an
    # argument is missing; extra trailing arguments are still ignored.
    if len(sys.argv) < 4:
        sys.exit(f"usage: {sys.argv[0]} <config_path> <vocab_path> <output_dir>")

    config_path, vocab_path, output_dir = sys.argv[1:4]
    gen_config_for_tfmodel(config_path, vocab_path, output_dir)