zhifu gao
2024-09-25 2196844d1d6e5b8732c95896bb46f0eacdd9cf9d
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from __future__ import print_function
 
import argparse
import copy
import logging
import os
from shutil import copyfile
 
import torch
import yaml
from typing import Union
 
 
from funasr.models.fsmn_kws.model import FsmnKWSConvert
 
 
def count_parameters(model):
    """Return the total element count of the model's trainable parameters.

    Walks ``model.parameters()`` and adds up ``numel()`` for every parameter
    whose ``requires_grad`` flag is set; parameters with the flag cleared
    are skipped.
    """
    total = 0
    for param in model.parameters():
        if param.requires_grad:
            total += param.numel()
    return total
 
 
def get_args():
    """Parse the converter's command-line options from ``sys.argv``.

    Returns:
        The ``argparse.Namespace`` holding ``config``, ``network_file``,
        ``model_dir``, ``model_name`` and ``convert_to``. All five options
        are declared as required.
    """
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ('--config', dict(required=True, help='config file')),
        ('--network_file', dict(
            default='',
            required=True,
            help='input network, support kaldi.txt/pytorch.pt')),
        ('--model_dir', dict(required=True, help='save model dir')),
        ('--model_name', dict(required=True, help='save model name')),
        ('--convert_to', dict(
            default='kaldi',
            required=True,
            help='target network type, kaldi/pytorch')),
    )

    cli = argparse.ArgumentParser(
        description='load and convert network to each other between kaldi/pytorch format')
    for flag, spec in option_specs:
        cli.add_argument(flag, **spec)

    return cli.parse_args()
 
 
def convert_to_kaldi(
    configs,
    network_file,
    model_dir,
    model_name="convert.kaldi.txt"
):
    """Convert a PyTorch checkpoint into a Kaldi text network file.

    The input checkpoint is copied to ``<model_dir>/origin.torch.pt``, its
    weights are loaded into an ``FsmnKWSConvert`` model, and the string
    returned by ``model.to_kaldi_net()`` is written to
    ``<model_dir>/<model_name>``.

    Args:
        configs: Parsed configuration mapping; ``encoder_conf`` (including
            its ``output_dim`` entry) and ``ctc_conf`` are read from it.
        network_file: Path of the PyTorch checkpoint to convert.
        model_dir: Directory that receives the checkpoint copy and the
            converted network. It is not created here.
        model_name: File name of the Kaldi text output inside ``model_dir``.
    """
    copyfile(network_file, os.path.join(model_dir, 'origin.torch.pt'))

    model = FsmnKWSConvert(
            vocab_size=configs['encoder_conf']['output_dim'],
            encoder='FSMNConvert',
            encoder_conf=configs['encoder_conf'],
            ctc_conf=configs['ctc_conf'],
    )
    print(model)
    num_params = count_parameters(model)
    print('the number of model params: {}'.format(num_params))

    checkpoint = torch.load(network_file, map_location='cpu')
    # The weights may be wrapped under "state_dict" or under "model" (the
    # key convert_to_pytorch in this file saves with). If neither wrapper
    # key is present, the loaded object is treated as the bare state dict.
    state_dict = checkpoint
    for wrapper_key in ("state_dict", "model"):
        if wrapper_key in checkpoint:
            state_dict = checkpoint[wrapper_key]
            break
    model.load_state_dict(state_dict)

    kaldi_text = os.path.join(model_dir, model_name)
    # The with-block closes the file; no explicit close() is needed.
    with open(kaldi_text, 'w', encoding='utf8') as fout:
        fout.write(model.to_kaldi_net())
 
 
def convert_to_pytorch(
    configs,
    network_file,
    model_dir,
    model_name="convert.torch.pt"
):
    """Convert a Kaldi text network into a PyTorch checkpoint.

    Builds an ``FsmnKWSConvert`` model from ``configs``, copies the input to
    ``<model_dir>/origin.kaldi.txt``, calls ``model.to_pytorch_net`` on the
    input path, and saves ``{"model": state_dict}`` to
    ``<model_dir>/<model_name>``. The model is then written back out as
    Kaldi text to ``<model_dir>/convert.kaldi.txt``.

    Args:
        configs: Parsed configuration mapping; ``encoder_conf`` (including
            its ``output_dim`` entry) and ``ctc_conf`` are read from it.
        network_file: Path of the Kaldi text network to convert.
        model_dir: Directory that receives every output file.
        model_name: File name of the saved PyTorch checkpoint.
    """
    encoder_conf = configs['encoder_conf']
    net = FsmnKWSConvert(
        vocab_size=encoder_conf['output_dim'],
        frontend=None,
        specaug=None,
        normalize=None,
        encoder='FSMNConvert',
        encoder_conf=encoder_conf,
        ctc_conf=configs['ctc_conf'],
    )

    print('the number of model params: {}'.format(count_parameters(net)))

    # Keep a copy of the untouched input next to the converted outputs.
    origin_copy = os.path.join(model_dir, 'origin.kaldi.txt')
    copyfile(network_file, origin_copy)
    # presumably fills the model's weights from the Kaldi file — confirm
    # against FsmnKWSConvert.to_pytorch_net
    net.to_pytorch_net(network_file)

    checkpoint_path = os.path.join(model_dir, model_name)
    torch.save({"model": net.state_dict()}, checkpoint_path)

    print('convert torch format back to kaldi')
    roundtrip_path = os.path.join(model_dir, 'convert.kaldi.txt')
    with open(roundtrip_path, 'w', encoding='utf8') as kaldi_out:
        kaldi_out.write(net.to_kaldi_net())

    print('Done!')
 
 
def main():
    """Command-line entry point: parse options, load the config, convert.

    Reads the YAML file named by ``--config`` and dispatches on
    ``--convert_to``: ``pytorch`` calls ``convert_to_pytorch`` and ``kaldi``
    calls ``convert_to_kaldi``, each with the parsed config, network file,
    model directory and model name.

    Raises:
        SystemExit: With status 1 when ``--convert_to`` is neither
            ``pytorch`` nor ``kaldi``, so calling scripts can detect that
            nothing was converted.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    print(args)
    # Decode the config as UTF-8 explicitly, matching the other text files
    # this script opens, instead of relying on the platform default.
    with open(args.config, 'r', encoding='utf8') as fin:
        # NOTE(review): yaml.load with FullLoader should only be given
        # trusted config files; consider yaml.safe_load if the configs
        # need no extended tags.
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    if args.convert_to == 'pytorch':
        print('convert kaldi net to pytorch...')
        convert_to_pytorch(
            configs,
            args.network_file,
            args.model_dir,
            args.model_name
        )
    elif args.convert_to == 'kaldi':
        print('convert pytorch net to kaldi...')
        convert_to_kaldi(
            configs,
            args.network_file,
            args.model_dir,
            args.model_name
        )
    else:
        print('unsupported target network type: {}'.format(args.convert_to))
        # Report failure to the caller rather than exiting with status 0.
        raise SystemExit(1)
 
 
# Run the converter only when this file is executed as a script, not when
# it is imported as a module.
if __name__ == '__main__':
    main()