from __future__ import print_function
|
|
import argparse
|
import copy
|
import logging
|
import os
|
from shutil import copyfile
|
|
import torch
|
import yaml
|
from typing import Union
|
from funasr.models.fsmn_kws_mt.encoder import FSMNMTConvert
|
from funasr.models.fsmn_kws_mt.model import FsmnKWSMTConvert
|
|
|
def count_parameters(model):
    """Return the total number of elements in the trainable parameters.

    Args:
        model: Object exposing ``parameters()``; each yielded parameter must
            provide ``requires_grad`` and ``numel()``.

    Returns:
        Sum of ``numel()`` over the parameters whose ``requires_grad`` is
        truthy (``0`` when there are none).
    """
    total = 0
    for param in model.parameters():
        # Parameters excluded from gradient computation are not counted.
        if param.requires_grad:
            total += param.numel()
    return total
|
|
|
def get_args():
    """Parse the converter's command-line options from ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``config``, ``network_file``,
        ``model_dir``, ``model_name``, ``model_name2`` and ``convert_to``.
    """
    parser = argparse.ArgumentParser(
        description='load and convert network to each other between kaldi/pytorch format')

    # (flag, extra keyword arguments) in the order they appear in --help.
    # Every option is registered as required, so the defaults listed for
    # --network_file and --convert_to are never applied by argparse.
    option_specs = (
        ('--config', {'help': 'config file'}),
        ('--network_file', {
            'default': '',
            'help': 'input network, support kaldi.txt/pytorch.pt',
        }),
        ('--model_dir', {'help': 'save model dir'}),
        ('--model_name', {'help': 'save model name'}),
        ('--model_name2', {'help': 'save model name'}),
        ('--convert_to', {
            'default': 'kaldi',
            'help': 'target network type, kaldi/pytorch',
        }),
    )
    for flag, extra in option_specs:
        parser.add_argument(flag, required=True, **extra)

    return parser.parse_args()
|
|
|
def convert_to_kaldi(
    configs,
    network_file,
    model_dir,
    model_name="convert.kaldi.txt",
    model_name2="convert.kaldi2.txt"
):
    """Convert a PyTorch checkpoint into two Kaldi text network files.

    The checkpoint is first copied to ``<model_dir>/origin.torch.pt``. A
    ``FsmnKWSMTConvert`` model is then built from ``configs``, the checkpoint
    weights are loaded into it, and the strings returned by
    ``model.to_kaldi_net()`` and ``model.to_kaldi_net2()`` are written to
    ``model_name`` and ``model_name2`` inside ``model_dir``.

    Args:
        configs: Mapping providing the ``encoder_conf`` and ``ctc_conf``
            entries used to build the model.
        network_file: Path of the PyTorch checkpoint to convert.
        model_dir: Output directory; it is not created here.
        model_name: File name for the ``to_kaldi_net()`` output.
        model_name2: File name for the ``to_kaldi_net2()`` output.

    Raises:
        KeyError: If the checkpoint has neither a ``state_dict`` nor a
            ``model`` entry.
    """
    copyfile(network_file, os.path.join(model_dir, 'origin.torch.pt'))

    model = FsmnKWSMTConvert(
        encoder='FSMNMTConvert',
        encoder_conf=configs['encoder_conf'],
        ctc_conf=configs['ctc_conf'],
    )
    print(model)
    num_params = count_parameters(model)
    print('the number of model params: {}'.format(num_params))

    # NOTE(review): torch.load unpickles the file; only feed trusted checkpoints.
    states = torch.load(network_file, map_location='cpu')
    # Training checkpoints keep the weights under "state_dict", while
    # convert_to_pytorch() in this file saves them under "model"; accept both
    # so files produced by this tool can be converted back.
    if "state_dict" in states:
        weights = states["state_dict"]
    elif "model" in states:
        weights = states["model"]
    else:
        raise KeyError(
            "checkpoint {} has neither a 'state_dict' nor a 'model' entry".format(
                network_file))
    model.load_state_dict(weights)

    # The with-statement closes each file; no explicit close() is needed.
    kaldi_text = os.path.join(model_dir, model_name)
    with open(kaldi_text, 'w', encoding='utf8') as fout:
        fout.write(model.to_kaldi_net())

    kaldi_text2 = os.path.join(model_dir, model_name2)
    with open(kaldi_text2, 'w', encoding='utf8') as fout:
        fout.write(model.to_kaldi_net2())
|
|
|
def convert_to_pytorch(
    configs,
    network_file,
    model_dir,
    model_name="convert.torch.pt"
):
    """Convert a Kaldi text network into a PyTorch checkpoint.

    Builds a ``FsmnKWSMTConvert`` model from ``configs``, copies the input to
    ``<model_dir>/origin.kaldi.txt``, calls ``model.to_pytorch_net`` on the
    input file, and saves ``{"model": model.state_dict()}`` to
    ``<model_dir>/<model_name>``. The model is then exported back with
    ``model.to_kaldi_net()`` into ``<model_dir>/convert.kaldi.txt``.

    Args:
        configs: Mapping providing the ``encoder_conf`` and ``ctc_conf``
            entries used to build the model.
        network_file: Path of the Kaldi text network to convert.
        model_dir: Output directory; it is not created here.
        model_name: File name of the saved PyTorch checkpoint.
    """
    model = FsmnKWSMTConvert(
        encoder='FSMNMTConvert',
        encoder_conf=configs['encoder_conf'],
        ctc_conf=configs['ctc_conf'],
    )

    num_params = count_parameters(model)
    print('the number of model params: {}'.format(num_params))

    copyfile(network_file, os.path.join(model_dir, 'origin.kaldi.txt'))
    # Presumably loads the Kaldi weights into the model in place -- confirm
    # against FsmnKWSMTConvert.to_pytorch_net.
    model.to_pytorch_net(network_file)

    save_model_path = os.path.join(model_dir, model_name)
    torch.save({"model": model.state_dict()}, save_model_path)

    print('convert torch format back to kaldi')
    kaldi_text = os.path.join(model_dir, 'convert.kaldi.txt')
    # The with-statement closes the file; no explicit close() is needed.
    with open(kaldi_text, 'w', encoding='utf8') as fout:
        fout.write(model.to_kaldi_net())

    print('Done!')
|
|
|
def main():
    """Run the converter from the command line.

    Parses the CLI options, configures logging, loads the YAML config and
    dispatches to ``convert_to_pytorch`` or ``convert_to_kaldi`` according to
    ``--convert_to``. Any other target only prints a message.
    """
    args = get_args()
    logging.basicConfig(level=logging.DEBUG,
                        format='%(asctime)s %(levelname)s %(message)s')
    print(args)
    # NOTE(review): yaml.FullLoader is not the restricted safe loader; only
    # point --config at trusted files.
    with open(args.config, 'r') as fin:
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    if args.convert_to == 'pytorch':
        print('convert kaldi net to pytorch...')
        # convert_to_pytorch accepts a single output file name, so only
        # model_name is forwarded (a fifth positional argument is a TypeError).
        convert_to_pytorch(
            configs,
            args.network_file,
            args.model_dir,
            args.model_name,
        )
    elif args.convert_to == 'kaldi':
        print('convert pytorch net to kaldi...')
        # convert_to_kaldi writes two files; forward both names so that
        # --model_name2 is honored instead of silently ignored.
        convert_to_kaldi(
            configs,
            args.network_file,
            args.model_dir,
            args.model_name,
            args.model_name2,
        )
    else:
        print('unsupported target network type: {}'.format(args.convert_to))


if __name__ == '__main__':
    main()
|