| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # coding=utf8 |
| | | |
| | | # Copyright 2021 Jiayu DU |
| | | |
| | | import sys |
| | | import argparse |
| | | import json |
| | | import logging |
# Send log records to stderr at INFO level, prefixed with the level name.
logging.basicConfig(
    stream=sys.stderr,
    level=logging.INFO,
    format='[%(levelname)s] %(message)s',
)

# When truthy, EditDistance dumps its DP search grid to stderr.
DEBUG = None
| | | |
def GetEditType(ref_token, hyp_token):
    """Classify one alignment step from its reference and hypothesis tokens.

    A token of ``None`` means that side of the alignment step is empty.

    Args:
        ref_token: reference token, or ``None``.
        hyp_token: hypothesis token, or ``None``.

    Returns:
        'I' (insertion) when only the hypothesis token is present,
        'D' (deletion) when only the reference token is present,
        'C' (correct) when the two tokens compare equal -- this includes the
        degenerate case where both are ``None`` -- and 'S' (substitution)
        otherwise.
    """
    # Identity checks: `== None` would dispatch to the token's own __eq__ and
    # could misclassify tokens with a permissive equality.
    if ref_token is None and hyp_token is not None:
        return 'I'
    if ref_token is not None and hyp_token is None:
        return 'D'
    # Equal / not-equal is exhaustive, so no fall-through error branch is needed.
    return 'C' if ref_token == hyp_token else 'S'
| | | |
class AlignmentArc:
    """One step of an alignment path.

    Attributes (all set once by the constructor):
        src, dst: start and end of the arc, stored exactly as given.
        ref, hyp: the reference / hypothesis token carried by the arc.
        edit_type: the label that GetEditType(ref, hyp) assigns to the pair.
    """

    def __init__(self, src, dst, ref, hyp):
        self.src, self.dst = src, dst
        self.ref, self.hyp = ref, hyp
        self.edit_type = GetEditType(ref, hyp)
| | | |
def similarity_score_function(ref_token, hyp_token):
    """Score a token pair: 0 when the tokens are equal, -1.0 otherwise."""
    if ref_token == hyp_token:
        return 0
    return -1.0
| | | |
def insertion_score_function(token):
    """Score for inserting `token`: a flat -1.0, whatever the token is."""
    flat_penalty = -1.0
    return flat_penalty
| | | |
def deletion_score_function(token):
    """Score for deleting `token`: a flat -1.0, whatever the token is."""
    flat_penalty = -1.0
    return flat_penalty
| | | |
def EditDistance(
        ref,
        hyp,
        similarity_score_function = similarity_score_function,
        insertion_score_function = insertion_score_function,
        deletion_score_function = deletion_score_function):
    """Align `hyp` against `ref` by dynamic programming and return the best path.

    Scores add up along a path and the path with the highest total wins.
    With the default score functions (0 for a match, -1.0 for every edit)
    the best score is the negated number of edits.

    Args:
        ref: non-empty sequence of reference tokens.
        hyp: sequence of hypothesis tokens (may be empty).
        similarity_score_function: callable(ref_token, hyp_token) scoring a
            diagonal step (match or substitution).
        insertion_score_function: callable(hyp_token) scoring an insertion.
        deletion_score_function: callable(ref_token) scoring a deletion.

    Returns:
        (best_path, best_score): best_path is the list of AlignmentArc
        objects leading from grid state (0, 0) to (len(ref), len(hyp)) in
        forward order; best_score is the accumulated score of that final state.

    Raises:
        ValueError: if `ref` is empty.
        RuntimeError: if a backpointer does not describe a single edit step.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`.
    if len(ref) == 0:
        raise ValueError('EditDistance requires a non-empty reference')

    class DPState:
        """One grid cell: best score found so far and where it came from."""
        __slots__ = ('score', 'prev_r', 'prev_h')

        def __init__(self):
            self.score = -float('inf')
            # backpointer
            self.prev_r = None
            self.prev_h = None

    def print_search_grid(S, R, H, fstream):
        """Dump each cell as [r,h]:score:(prev_r,prev_h), one grid row per line."""
        print(file=fstream)
        for r in range(R):
            for h in range(H):
                print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream)
            print(file=fstream)

    R = len(ref) + 1
    H = len(hyp) + 1

    # DP search space, a (R x H) grid; S[r][h] covers ref[:r] and hyp[:h].
    S = [[DPState() for _ in range(H)] for _ in range(R)]

    # Grid origin S(r = 0, h = 0): nothing consumed yet, no backpointer.
    S[0][0].score = 0.0

    # REF axis: no hypothesis token consumed, so every step is a deletion.
    for r in range(1, R):
        S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1])
        S[r][0].prev_r = r-1
        S[r][0].prev_h = 0

    # HYP axis: no reference token consumed, so every step is an insertion.
    for h in range(1, H):
        S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1])
        S[0][h].prev_r = 0
        S[0][h].prev_h = h-1

    for r in range(1, R):
        for h in range(1, H):
            state = S[r][h]
            # Candidate predecessors, tried in the order sub/cor, del, ins.
            # `>=` lets a later candidate take over on a tie, so ties resolve
            # in favour of ins, then del, then sub/cor.
            candidates = (
                (S[r-1][h-1].score + similarity_score_function(ref[r-1], hyp[h-1]), r-1, h-1),
                (S[r-1][h].score + deletion_score_function(ref[r-1]), r-1, h),
                (S[r][h-1].score + insertion_score_function(hyp[h-1]), r, h-1),
            )
            for new_score, from_r, from_h in candidates:
                if new_score >= state.score:
                    state.score = new_score
                    state.prev_r = from_r
                    state.prev_h = from_h

    # The best alignment always ends in the far corner of the grid.
    best_state = (R-1, H-1)
    best_score = S[R-1][H-1].score

    if DEBUG:
        print_search_grid(S, R, H, sys.stderr)

    # Backtracing best alignment path, i.e. a list of arcs
    # arc = (src, dst, ref, hyp, edit_type)
    # src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis
    best_path = []
    r, h = best_state
    prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
    # loop invariant:
    #   (prev_r, prev_h) -> (r, h) is a "forward arc" on best alignment path
    while prev_r is not None or prev_h is not None:
        src = (prev_r, prev_h)
        dst = (r, h)
        if r == prev_r + 1 and h == prev_h + 1:  # Substitution or correct
            arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h])
        elif r == prev_r + 1 and h == prev_h:  # Deletion
            arc = AlignmentArc(src, dst, ref[prev_r], None)
        elif r == prev_r and h == prev_h + 1:  # Insertion
            arc = AlignmentArc(src, dst, None, hyp[prev_h])
        else:
            raise RuntimeError(F'Invalid backpointer {src} -> {dst}')
        best_path.append(arc)
        r, h = prev_r, prev_h
        prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h

    # Arcs were collected from the end of the path backwards.
    best_path.reverse()
    return (best_path, best_score)
| | | |
def PrettyPrintAlignment(alignment, stream = sys.stderr):
    """Write an alignment to `stream` as three column-aligned lines: REF, HYP, EDIT.

    Every arc occupies one column. A missing token (``None``) is shown as
    "*"; the EDIT row shows the arc's edit type and is left blank for
    correct ('C') arcs. Columns are padded by display width, counting East
    Asian wide and fullwidth characters as two cells and all others as one.

    Args:
        alignment: iterable of arcs exposing `ref`, `hyp` and `edit_type`.
        stream: writable text stream; defaults to sys.stderr.
    """
    # Imported locally because the module header does not import unicodedata.
    import unicodedata

    def get_token_str(token):
        return "*" if token is None else token

    def display_width(token_str):
        # 'W' (wide) and 'F' (fullwidth) characters occupy two terminal cells.
        # This covers Chinese ideographs as well as Japanese kana, Korean
        # Hangul and fullwidth punctuation.
        return sum(
            2 if unicodedata.east_asian_width(ch) in ('W', 'F') else 1
            for ch in token_str
        )

    # Collect the cells of each row and join once at the end.
    ref_cells = [' REF : ']
    hyp_cells = [' HYP : ']
    edit_cells = [' EDIT : ']
    for arc in alignment:
        r = get_token_str(arc.ref)
        h = get_token_str(arc.hyp)
        e = arc.edit_type if arc.edit_type != 'C' else ''

        nr, nh, ne = display_width(r), display_width(h), display_width(e)
        # Column width: the widest of the three cells plus one separating space.
        n = max(nr, nh, ne) + 1

        ref_cells.append(r + ' ' * (n - nr))
        hyp_cells.append(h + ' ' * (n - nh))
        edit_cells.append(e + ' ' * (n - ne))

    print(''.join(ref_cells), file=stream)
    print(''.join(hyp_cells), file=stream)
    print(''.join(edit_cells), file=stream)
| | | |
def CountEdits(alignment):
    """Tally the edit types found on an alignment path.

    Args:
        alignment: iterable of arcs exposing an `edit_type` attribute.

    Returns:
        (c, s, i, d): number of 'C', 'S', 'I' and 'D' arcs, in that order.

    Raises:
        RuntimeError: if an arc carries any other edit type.
    """
    tally = {'C': 0, 'S': 0, 'I': 0, 'D': 0}
    for arc in alignment:
        kind = arc.edit_type
        if kind not in tally:
            raise RuntimeError
        tally[kind] += 1
    return (tally['C'], tally['S'], tally['I'], tally['D'])
| | | |
def ComputeTokenErrorRate(c, s, i, d):
    """Return the token error rate in percent: 100 * (s + d + i) / (s + d + c).

    Args:
        c, s, i, d: counts of correct, substituted, inserted and deleted tokens.

    Raises:
        ZeroDivisionError: if s + d + c is zero.
    """
    num_edits = s + d + i
    num_ref_tokens = s + d + c
    return 100.0 * num_edits / num_ref_tokens
| | | |
def ComputeSentenceErrorRate(num_err_utts, num_utts):
    """Return `num_err_utts` as a percentage of `num_utts`.

    Args:
        num_err_utts: number of utterances counted as erroneous.
        num_utts: total number of utterances; must be non-zero.

    Raises:
        ValueError: if `num_utts` is zero.
    """
    # Explicit check rather than `assert`, which is stripped under `python -O`
    # and would let the division below fail with a less helpful error.
    if num_utts == 0:
        raise ValueError('num_utts must be non-zero to compute a sentence error rate')
    return 100.0 * num_err_utts / num_utts
| | | |
| | | |
class EvaluationResult:
    """Corpus-level counters and error rates, with several text renderings."""

    def __init__(self):
        # Utterance bookkeeping.
        self.num_ref_utts = 0
        self.num_hyp_utts = 0
        self.num_eval_utts = 0  # seen in both ref & hyp
        self.num_hyp_without_ref = 0

        # Token-level edit counts: correct, substitution, insertion, deletion.
        self.C = self.S = self.I = self.D = 0
        self.token_error_rate = 0.0

        # Sentence-level statistics.
        self.num_utts_with_error = 0
        self.sentence_error_rate = 0.0

    def to_json(self):
        """Serialize every instance attribute as a JSON object string."""
        return json.dumps(vars(self))

    def to_kaldi(self):
        """Render the two-line '%WER ... / %SER ...' report."""
        num_edits = self.S + self.D + self.I
        num_ref_tokens = self.C + self.S + self.D
        return (
            F'%WER {self.token_error_rate:.2f} [ {num_edits} / {num_ref_tokens}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
            F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
        )

    def to_sclite(self):
        """Placeholder: not implemented yet, returns the string "TODO"."""
        return "TODO"

    def to_espnet(self):
        """Placeholder: not implemented yet, returns the string "TODO"."""
        return "TODO"

    def to_summary(self):
        """Render the multi-line, human-readable overall statistics block."""
        num_ref_tokens = self.C + self.S + self.D
        num_edits = self.S + self.I + self.D
        return (
            '==================== Overall Statistics ====================\n'
            F'num_ref_utts: {self.num_ref_utts}\n'
            F'num_hyp_utts: {self.num_hyp_utts}\n'
            F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
            F'num_eval_utts: {self.num_eval_utts}\n'
            F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
            F'token_error_rate: {self.token_error_rate:.2f}%\n'
            F'token_stats:\n'
            F' - tokens:{num_ref_tokens:>7}\n'
            F' - edits: {num_edits:>7}\n'
            F' - cor: {self.C:>7}\n'
            F' - sub: {self.S:>7}\n'
            F' - ins: {self.I:>7}\n'
            F' - del: {self.D:>7}\n'
            '============================================================\n'
        )
| | | |
| | | |
class Utterance:
    """An utterance id paired with its text."""

    def __init__(self, uid, text):
        self.uid, self.text = uid, text
| | | |
| | | |
def LoadUtterances(filepath, format):
    """Read an utterance file into a dict mapping utterance id to Utterance.

    Only the 'text' format is supported. Each non-blank line holds
    "utt_id word1 word2 ...": the first whitespace-separated column is the
    id and the remainder of the line is the text (empty if absent).

    Args:
        filepath: path of the file to read, decoded as UTF-8.
        format: name of the file format; must be 'text'.

    Returns:
        dict of uid -> Utterance, in file order.

    Raises:
        RuntimeError: if `format` is unsupported or an utterance id repeats.
    """
    if format != 'text':
        raise RuntimeError(F'Unsupported text format {format}')

    utts = {}
    with open(filepath, 'r', encoding='utf8') as f:
        for raw_line in f:
            fields = raw_line.strip().split(maxsplit=1)
            if not fields:  # blank line
                continue
            uid = fields[0]
            text = fields[1] if len(fields) > 1 else ''
            if uid in utts:
                raise RuntimeError(F'Found duplicated utterence id {uid}')
            utts[uid] = Utterance(uid, text)
    return utts
| | | |
| | | |
def tokenize_text(text, tokenizer):
    """Split `text` into a list of tokens.

    Args:
        text: the string to tokenize.
        tokenizer: 'whitespace' to split on runs of whitespace, or 'char'
            to emit every non-whitespace-separated character on its own.

    Raises:
        RuntimeError: for any other tokenizer name.
    """
    words = text.split
    if tokenizer == 'whitespace':
        return words()
    if tokenizer == 'char':
        # Whitespace is dropped first, so it never becomes a token.
        return list(''.join(words()))
    raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}')
| | | |
| | | |
def main():
    """Command-line entry point: score a hypothesis file against a reference file.

    Per-utterance results (a stats line plus the pretty-printed alignment)
    and the overall summary are written to the positional `result_file`.
    The JSON and Kaldi-style reports are printed to stdout.
    """
    parser = argparse.ArgumentParser()
    # optional
    parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER')
    parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text')
    parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text')
    # required
    parser.add_argument('--ref', type=str, required=True, help='input reference file')
    parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file')

    parser.add_argument('result_file', type=str)
    args = parser.parse_args()
    logging.info(args)

    ref_utts = LoadUtterances(args.ref, args.ref_format)
    hyp_utts = LoadUtterances(args.hyp, args.hyp_format)

    result = EvaluationResult()

    # Keep only hypothesis utterances that have a matching, non-empty reference.
    eval_utts = []
    for uid in sorted(hyp_utts):
        if uid not in ref_utts:
            logging.warning(F'Found {uid} without reference, skipping...')
            result.num_hyp_without_ref += 1
        elif not ref_utts[uid].text.strip():
            logging.warning(F'Found {uid} with empty reference, skipping...')
        else:
            eval_utts.append(uid)

    result.num_hyp_utts = len(hyp_utts)
    result.num_ref_utts = len(ref_utts)
    result.num_eval_utts = len(eval_utts)

    if not eval_utts:
        # With nothing to evaluate, the corpus-level rates below would divide
        # by zero; fail with a clear message instead of a traceback.
        logging.error('No hypothesis utterance has a non-empty reference, nothing to evaluate.')
        sys.exit(1)

    with open(args.result_file, 'w+', encoding='utf8') as fo:
        for uid in eval_utts:
            ref = ref_utts[uid]
            hyp = hyp_utts[uid]

            alignment, score = EditDistance(
                tokenize_text(ref.text, args.tokenizer),
                tokenize_text(hyp.text, args.tokenizer)
            )

            c, s, i, d = CountEdits(alignment)
            utt_ter = ComputeTokenErrorRate(c, s, i, d)

            # utt-level evaluation result
            # NOTE(review): uid is emitted unquoted, so this line is JSON-like
            # rather than strict JSON; kept as-is to preserve the file format.
            print(F'{{"uid":{uid}, "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo)
            PrettyPrintAlignment(alignment, fo)

            result.C += c
            result.S += s
            result.I += i
            result.D += d

            if utt_ter > 0:
                result.num_utts_with_error += 1

        # corpus level evaluation result
        result.sentence_error_rate = ComputeSentenceErrorRate(result.num_utts_with_error, result.num_eval_utts)
        result.token_error_rate = ComputeTokenErrorRate(result.C, result.S, result.I, result.D)

        print(result.to_summary(), file=fo)

    print(result.to_json())
    print(result.to_kaldi())


if __name__ == '__main__':
    main()