From 559cc2c6e296bc80917a7408911f671dfcc2b68b Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期五, 12 五月 2023 17:25:54 +0800
Subject: [PATCH] update repo

---
 egs/aishell2/transformer/utils/error_rate_zh |  370 ++++++++++++++++++++++++++++++++++++++++++++++++++++
 1 files changed, 370 insertions(+), 0 deletions(-)

diff --git a/egs/aishell2/transformer/utils/error_rate_zh b/egs/aishell2/transformer/utils/error_rate_zh
new file mode 100755
index 0000000..6871a07
--- /dev/null
+++ b/egs/aishell2/transformer/utils/error_rate_zh
@@ -0,0 +1,370 @@
+#!/usr/bin/env python3
+# coding=utf8
+
+# Copyright  2021  Jiayu DU
+
+import sys
+import argparse
+import json
+import logging
+logging.basicConfig(stream=sys.stderr, level=logging.INFO, format='[%(levelname)s] %(message)s')
+
+DEBUG = None
+
+def GetEditType(ref_token, hyp_token):
+    if ref_token == None and hyp_token != None:
+        return 'I'
+    elif ref_token != None and hyp_token == None:
+        return 'D'
+    elif ref_token == hyp_token:
+        return 'C'
+    elif ref_token != hyp_token:
+        return 'S'
+    else:
+        raise RuntimeError
+
+class AlignmentArc:
+    def __init__(self, src, dst, ref, hyp):
+        self.src = src
+        self.dst = dst
+        self.ref = ref
+        self.hyp = hyp
+        self.edit_type = GetEditType(ref, hyp)
+
+def similarity_score_function(ref_token, hyp_token):
+    return 0 if (ref_token == hyp_token) else -1.0
+
+def insertion_score_function(token):
+    return -1.0
+
+def deletion_score_function(token):
+    return -1.0
+
+def EditDistance(
+        ref,
+        hyp, 
+        similarity_score_function = similarity_score_function,
+        insertion_score_function = insertion_score_function,
+        deletion_score_function = deletion_score_function):
+    assert(len(ref) != 0)
+    class DPState:
+        def __init__(self):
+            self.score = -float('inf')
+            # backpointer
+            self.prev_r = None
+            self.prev_h = None
+    
+    def print_search_grid(S, R, H, fstream):
+        print(file=fstream)
+        for r in range(R):
+            for h in range(H):
+                print(F'[{r},{h}]:{S[r][h].score:4.3f}:({S[r][h].prev_r},{S[r][h].prev_h}) ', end='', file=fstream)
+            print(file=fstream)
+
+    R = len(ref) + 1
+    H = len(hyp) + 1
+
+    # Construct DP search space, a (R x H) grid
+    S = [ [] for r in range(R) ]
+    for r in range(R):
+        S[r] = [ DPState() for x in range(H) ]
+
+    # initialize DP search grid origin, S(r = 0, h = 0)
+    S[0][0].score = 0.0
+    S[0][0].prev_r = None
+    S[0][0].prev_h = None
+
+    # initialize REF axis
+    for r in range(1, R):
+        S[r][0].score = S[r-1][0].score + deletion_score_function(ref[r-1])
+        S[r][0].prev_r = r-1
+        S[r][0].prev_h = 0
+
+    # initialize HYP axis
+    for h in range(1, H):
+        S[0][h].score = S[0][h-1].score + insertion_score_function(hyp[h-1])
+        S[0][h].prev_r = 0
+        S[0][h].prev_h = h-1
+
+    best_score = S[0][0].score
+    best_state = (0, 0)
+
+    for r in range(1, R):
+        for h in range(1, H):
+            sub_or_cor_score = similarity_score_function(ref[r-1], hyp[h-1])
+            new_score = S[r-1][h-1].score + sub_or_cor_score
+            if new_score >= S[r][h].score:
+                S[r][h].score = new_score
+                S[r][h].prev_r = r-1
+                S[r][h].prev_h = h-1
+
+            del_score = deletion_score_function(ref[r-1])
+            new_score = S[r-1][h].score + del_score
+            if new_score >= S[r][h].score:
+                S[r][h].score = new_score
+                S[r][h].prev_r = r - 1
+                S[r][h].prev_h = h
+
+            ins_score = insertion_score_function(hyp[h-1])
+            new_score = S[r][h-1].score + ins_score
+            if new_score >= S[r][h].score:
+                S[r][h].score = new_score
+                S[r][h].prev_r = r
+                S[r][h].prev_h = h-1
+
+    best_score = S[R-1][H-1].score
+    best_state = (R-1, H-1)
+
+    if DEBUG:
+        print_search_grid(S, R, H, sys.stderr)
+
+    # Backtracing best alignment path, i.e. a list of arcs
+    # arc = (src, dst, ref, hyp, edit_type)
+    # src/dst = (r, h), where r/h refers to search grid state-id along Ref/Hyp axis
+    best_path = []
+    r, h = best_state[0], best_state[1]
+    prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
+    score = S[r][h].score
+    # loop invariant:
+    #   1. (prev_r, prev_h) -> (r, h) is a "forward arc" on best alignment path
+    #   2. score is the value of point(r, h) on DP search grid
+    while prev_r != None or prev_h != None:
+        src = (prev_r, prev_h)
+        dst = (r, h)
+        if (r == prev_r + 1 and h == prev_h + 1): # Substitution or correct
+            arc = AlignmentArc(src, dst, ref[prev_r], hyp[prev_h])
+        elif (r == prev_r + 1 and h == prev_h): # Deletion
+            arc = AlignmentArc(src, dst, ref[prev_r], None)
+        elif (r == prev_r and h == prev_h + 1): # Insertion
+            arc = AlignmentArc(src, dst, None, hyp[prev_h])
+        else:
+            raise RuntimeError
+        best_path.append(arc)
+        r, h = prev_r, prev_h
+        prev_r, prev_h = S[r][h].prev_r, S[r][h].prev_h
+        score = S[r][h].score
+    
+    best_path.reverse()
+    return (best_path, best_score)
+
+def PrettyPrintAlignment(alignment, stream = sys.stderr):
+    def get_token_str(token):
+        if token == None:
+            return "*"
+        return token
+    
+    def is_double_width_char(ch):
+        if (ch >= '\u4e00') and (ch <= '\u9fa5'): # codepoint ranges for Chinese chars
+            return True
+        # TODO: support other double-width-char language such as Japanese, Korean 
+        else:
+            return False
+    
+    def display_width(token_str):
+        m = 0
+        for c in token_str:
+            if is_double_width_char(c):
+                m += 2
+            else:
+                m += 1
+        return m
+
+    R = '  REF  : '
+    H = '  HYP  : '
+    E = '  EDIT : '
+    for arc in alignment:
+        r = get_token_str(arc.ref)
+        h = get_token_str(arc.hyp)
+        e = arc.edit_type if arc.edit_type != 'C' else ''
+
+        nr, nh, ne = display_width(r), display_width(h), display_width(e)
+        n = max(nr, nh, ne) + 1
+
+        R += r + ' ' * (n-nr)
+        H += h + ' ' * (n-nh)
+        E += e + ' ' * (n-ne)
+
+    print(R, file=stream)
+    print(H, file=stream)
+    print(E, file=stream)
+
+def CountEdits(alignment):
+    c, s, i, d = 0, 0, 0, 0
+    for arc in alignment:
+        if arc.edit_type == 'C':
+            c += 1
+        elif arc.edit_type == 'S':
+            s += 1
+        elif arc.edit_type == 'I':
+            i += 1
+        elif arc.edit_type == 'D':
+            d += 1
+        else:
+            raise RuntimeError
+    return (c, s, i, d)
+
+def ComputeTokenErrorRate(c, s, i, d):
+    return 100.0 * (s + d + i) / (s + d + c)
+
+def ComputeSentenceErrorRate(num_err_utts, num_utts):
+    assert(num_utts != 0)
+    return 100.0 * num_err_utts / num_utts
+
+
+class EvaluationResult:
+    def __init__(self):
+        self.num_ref_utts = 0
+        self.num_hyp_utts = 0
+        self.num_eval_utts = 0 # seen in both ref & hyp
+        self.num_hyp_without_ref = 0
+
+        self.C = 0
+        self.S = 0
+        self.I = 0
+        self.D = 0
+        self.token_error_rate = 0.0
+
+        self.num_utts_with_error = 0
+        self.sentence_error_rate = 0.0
+    
+    def to_json(self):
+        return json.dumps(self.__dict__)
+    
+    def to_kaldi(self):
+        info = (
+            F'%WER {self.token_error_rate:.2f} [ {self.S + self.D + self.I} / {self.C + self.S + self.D}, {self.I} ins, {self.D} del, {self.S} sub ]\n'
+            F'%SER {self.sentence_error_rate:.2f} [ {self.num_utts_with_error} / {self.num_eval_utts} ]\n'
+        )
+        return info
+    
+    def to_sclite(self):
+        return "TODO"
+    
+    def to_espnet(self):
+        return "TODO"
+    
+    def to_summary(self):
+        #return json.dumps(self.__dict__, indent=4)
+        summary = (
+            '==================== Overall Statistics ====================\n'
+            F'num_ref_utts: {self.num_ref_utts}\n'
+            F'num_hyp_utts: {self.num_hyp_utts}\n'
+            F'num_hyp_without_ref: {self.num_hyp_without_ref}\n'
+            F'num_eval_utts: {self.num_eval_utts}\n'
+            F'sentence_error_rate: {self.sentence_error_rate:.2f}%\n'
+            F'token_error_rate: {self.token_error_rate:.2f}%\n'
+            F'token_stats:\n'
+            F'  - tokens:{self.C + self.S + self.D:>7}\n'
+            F'  - edits: {self.S + self.I + self.D:>7}\n'
+            F'  - cor:   {self.C:>7}\n'
+            F'  - sub:   {self.S:>7}\n'
+            F'  - ins:   {self.I:>7}\n'
+            F'  - del:   {self.D:>7}\n'
+            '============================================================\n'
+        )
+        return summary
+
+
+class Utterance:
+    def __init__(self, uid, text):
+        self.uid = uid
+        self.text = text
+
+
+def LoadUtterances(filepath, format):
+    utts = {}
+    if format == 'text': # utt_id word1 word2 ...
+        with open(filepath, 'r', encoding='utf8') as f:
+            for line in f:
+                line = line.strip()
+                if line:
+                    cols = line.split(maxsplit=1)
+                    assert(len(cols) == 2 or len(cols) == 1)
+                    uid = cols[0]
+                    text = cols[1] if len(cols) == 2 else ''
+                    if utts.get(uid) != None:
+                        raise RuntimeError(F'Found duplicated utterence id {uid}')
+                    utts[uid] = Utterance(uid, text)
+    else:
+        raise RuntimeError(F'Unsupported text format {format}')
+    return utts
+
+
+def tokenize_text(text, tokenizer):
+    if tokenizer == 'whitespace':
+        return text.split()
+    elif tokenizer == 'char':
+        return [ ch for ch in ''.join(text.split()) ]
+    else:
+        raise RuntimeError(F'ERROR: Unsupported tokenizer {tokenizer}')
+
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    # optional
+    parser.add_argument('--tokenizer', choices=['whitespace', 'char'], default='whitespace', help='whitespace for WER, char for CER')
+    parser.add_argument('--ref-format', choices=['text'], default='text', help='reference format, first col is utt_id, the rest is text')
+    parser.add_argument('--hyp-format', choices=['text'], default='text', help='hypothesis format, first col is utt_id, the rest is text')
+    # required
+    parser.add_argument('--ref', type=str, required=True, help='input reference file')
+    parser.add_argument('--hyp', type=str, required=True, help='input hypothesis file')
+
+    parser.add_argument('result_file', type=str)
+    args = parser.parse_args()
+    logging.info(args)
+
+    ref_utts = LoadUtterances(args.ref, args.ref_format)
+    hyp_utts = LoadUtterances(args.hyp, args.hyp_format)
+
+    r = EvaluationResult()
+
+    # check valid utterances in hyp that have matched non-empty reference
+    eval_utts = []
+    r.num_hyp_without_ref = 0
+    for uid in sorted(hyp_utts.keys()):
+        if uid in ref_utts.keys(): # TODO: efficiency
+            if ref_utts[uid].text.strip(): # non-empty reference
+                eval_utts.append(uid)
+            else:
+                logging.warn(F'Found {uid} with empty reference, skipping...')
+        else:
+            logging.warn(F'Found {uid} without reference, skipping...')
+            r.num_hyp_without_ref += 1
+
+    r.num_hyp_utts = len(hyp_utts)
+    r.num_ref_utts = len(ref_utts)
+    r.num_eval_utts = len(eval_utts)
+
+    with open(args.result_file, 'w+', encoding='utf8') as fo:
+        for uid in eval_utts:
+            ref = ref_utts[uid]
+            hyp = hyp_utts[uid]
+
+            alignment, score = EditDistance(
+                tokenize_text(ref.text, args.tokenizer),
+                tokenize_text(hyp.text, args.tokenizer)
+            )
+            
+            c, s, i, d = CountEdits(alignment)
+            utt_ter = ComputeTokenErrorRate(c, s, i, d)
+
+            # utt-level evaluation result
+            print(F'{{"uid":{uid}, "score":{score}, "ter":{utt_ter:.2f}, "cor":{c}, "sub":{s}, "ins":{i}, "del":{d}}}', file=fo)
+            PrettyPrintAlignment(alignment, fo)
+
+            r.C += c
+            r.S += s
+            r.I += i
+            r.D += d
+
+            if utt_ter > 0:
+                r.num_utts_with_error += 1
+
+        # corpus level evaluation result
+        r.sentence_error_rate = ComputeSentenceErrorRate(r.num_utts_with_error, r.num_eval_utts)
+        r.token_error_rate = ComputeTokenErrorRate(r.C, r.S, r.I, r.D)
+
+        print(r.to_summary(), file=fo)
+
+    print(r.to_json())
+    print(r.to_kaldi())

--
Gitblit v1.9.1