游雁
2023-09-13 33d3d2084403fd34b79c835d2f2fe04f6cd8f738
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import codecs
import pdb
import sys
import torch
 
char1 = sys.argv[1]
char2 = sys.argv[2]
model1 = torch.load(sys.argv[3], map_location='cpu')
model2_path = sys.argv[4]
 
d_new = model1
char1_list = []
map_list = []
 
 
with codecs.open(char1) as f:
    for line in f.readlines():
        char1_list.append(line.strip())
 
with codecs.open(char2) as f:
    for line in f.readlines():
        map_list.append(char1_list.index(line.strip()))
print(map_list)
 
for k, v in d_new.items():
    if k == 'ctc.ctc_lo.weight' or k == 'ctc.ctc_lo.bias' or k == 'decoder.output_layer.weight' or k == 'decoder.output_layer.bias' or k == 'decoder.embed.0.weight':
        d_new[k] = v[map_list]
    
torch.save(d_new, model2_path)