import editdistance
|
import sys
|
import os
|
from itertools import permutations
|
|
|
def load_transcripts(file_path):
|
trans_list = []
|
for one_line in open(file_path, "rt"):
|
meeting_id, trans = one_line.strip().split(" ")
|
trans_list.append((meeting_id.strip(), trans.strip()))
|
|
return trans_list
|
|
def calc_spk_trans(trans):
|
spk_trans_ = [x.strip() for x in trans.split("$")]
|
spk_trans = []
|
for i in range(len(spk_trans_)):
|
spk_trans.append((str(i), spk_trans_[i]))
|
return spk_trans
|
|
def calc_cer(ref_trans, hyp_trans):
|
ref_spk_trans = calc_spk_trans(ref_trans)
|
hyp_spk_trans = calc_spk_trans(hyp_trans)
|
ref_spk_num, hyp_spk_num = len(ref_spk_trans), len(hyp_spk_trans)
|
num_spk = max(len(ref_spk_trans), len(hyp_spk_trans))
|
ref_spk_trans.extend([("", "")] * (num_spk - len(ref_spk_trans)))
|
hyp_spk_trans.extend([("", "")] * (num_spk - len(hyp_spk_trans)))
|
|
errors, counts, permutes = [], [], []
|
min_error = 0
|
cost_dict = {}
|
for perm in permutations(range(num_spk)):
|
flag = True
|
p_err, p_count = 0, 0
|
for idx, p in enumerate(perm):
|
if abs(len(ref_spk_trans[idx][1]) - len(hyp_spk_trans[p][1])) > min_error > 0:
|
flag = False
|
break
|
cost_key = "{}-{}".format(idx, p)
|
if cost_key in cost_dict:
|
_e = cost_dict[cost_key]
|
else:
|
_e = editdistance.eval(ref_spk_trans[idx][1], hyp_spk_trans[p][1])
|
cost_dict[cost_key] = _e
|
if _e > min_error > 0:
|
flag = False
|
break
|
p_err += _e
|
p_count += len(ref_spk_trans[idx][1])
|
|
if flag:
|
if p_err < min_error or min_error == 0:
|
min_error = p_err
|
|
errors.append(p_err)
|
counts.append(p_count)
|
permutes.append(perm)
|
|
sd_cer = [(err, cnt, err/cnt, permute)
|
for err, cnt, permute in zip(errors, counts, permutes)]
|
best_rst = min(sd_cer, key=lambda x: x[2])
|
|
return best_rst[0], best_rst[1], ref_spk_num, hyp_spk_num
|
|
|
def main():
|
ref=sys.argv[1]
|
hyp=sys.argv[2]
|
result_path="/".join(hyp.split("/")[:-1]) + "/text_cpcer"
|
ref_list = load_transcripts(ref)
|
hyp_list = load_transcripts(hyp)
|
result_file = open(result_path,'w')
|
record_2_spk = [0, 0]
|
record_3_spk = [0, 0]
|
record_4_spk = [0, 0]
|
error, count = 0, 0
|
for (ref_id, ref_trans), (hyp_id, hyp_trans) in zip(ref_list, hyp_list):
|
assert ref_id == hyp_id
|
mid = ref_id
|
dist, length, ref_spk_num, hyp_spk_num = calc_cer(ref_trans, hyp_trans)
|
error, count = error + dist, count + length
|
result_file.write("{} {:.2f} {} {}\n".format(mid, dist / length * 100.0, ref_spk_num, hyp_spk_num))
|
ref_spk = len(ref_trans.split("$"))
|
hyp_spk = len(hyp_trans.split("$"))
|
if ref_spk == 2:
|
record_2_spk[0] += dist
|
record_2_spk[1] += length
|
elif ref_spk == 3:
|
record_3_spk[0] += dist
|
record_3_spk[1] += length
|
else:
|
record_4_spk[0] += dist
|
record_4_spk[1] += length
|
print(record_2_spk[0]/record_2_spk[1]*100.0)
|
print(record_3_spk[0]/record_3_spk[1]*100.0)
|
print(record_4_spk[0]/record_4_spk[1]*100.0)
|
result_file.write("CP-CER: {:.2f}\n".format(error / count * 100.0))
|
result_file.close()
|
|
|
if __name__ == '__main__':
|
main()
|