语帆
2024-03-04 920331972a136834a560d78917de60f6c6623d96
commit
10 files modified
11 files added
1947 lines changed

Files changed:
examples/industrial_data_pretraining/contextual_paraformer/demo.py
examples/industrial_data_pretraining/contextual_paraformer/demo.sh        2
examples/industrial_data_pretraining/contextual_paraformer/demo2.sh       9
examples/industrial_data_pretraining/contextual_paraformer/path.sh        6
examples/industrial_data_pretraining/lcbnet/compute_wer_details.py      702
examples/industrial_data_pretraining/lcbnet/demo.py                      13
examples/industrial_data_pretraining/lcbnet/demo.sh                      72
examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh           11
examples/industrial_data_pretraining/lcbnet/utils                         1
examples/industrial_data_pretraining/seaco_paraformer/demo.py            10
funasr/auto/auto_model.py                                                12
funasr/frontends/default.py                                              20
funasr/models/conformer/encoder.py                                        2
funasr/models/contextual_paraformer/model.py                             39
funasr/models/lcbnet/__init__.py
funasr/models/lcbnet/attention.py                                       112
funasr/models/lcbnet/encoder.py                                         392
funasr/models/lcbnet/model.py                                           495
funasr/models/seaco_paraformer/model.py                                  18
funasr/train_utils/load_pretrained_model.py                               3
funasr/utils/load_utils.py                                               28
examples/industrial_data_pretraining/contextual_paraformer/demo.py
old mode 100644
new mode 100755
examples/industrial_data_pretraining/contextual_paraformer/demo.sh
old mode 100644
new mode 100755
@@ -2,7 +2,7 @@
model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
model_revision="v2.0.4"
python funasr/bin/inference.py \
python ../../../funasr/bin/inference.py \
+model=${model} \
+model_revision=${model_revision} \
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
examples/industrial_data_pretraining/contextual_paraformer/demo2.sh
New file
@@ -0,0 +1,9 @@
python -m funasr.bin.inference \
--config-path="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
--config-name="config.yaml" \
++init_param="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/model.pb" \
++tokenizer_conf.token_list="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/tokens.txt" \
++frontend_conf.cmvn_file="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/am.mvn" \
++input="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/asr_example_zh.wav" \
++output_dir="./outputs/debug2" \
++device="" \
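Note: the leading `+`/`++` on the overrides above are Hydra command-line syntax, not diff markers. As a quick reference (general Hydra behavior, not specific to this commit):

key=value     # override an existing config entry
+key=value    # append a new entry (errors if the key already exists)
++key=value   # append the entry, overriding it if it already exists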
examples/industrial_data_pretraining/contextual_paraformer/path.sh
New file
@@ -0,0 +1,6 @@
export FUNASR_DIR=$PWD/../../../
# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
export PYTHONIOENCODING=UTF-8
export PATH=$FUNASR_DIR/funasr/bin:$PATH
export PYTHONPATH=$FUNASR_DIR/funasr/bin:$FUNASR_DIR/funasr:$FUNASR_DIR:$PYTHONPATH
examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
New file
@@ -0,0 +1,702 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
from enum import Enum
import re, sys, unicodedata
import codecs
import argparse
from tqdm import tqdm
import os
remove_tag = False
spacelist = [" ", "\t", "\r", "\n"]
puncts = [
    "!",
    ",",
    "?",
    "、",
    "。",
    "!",
    ",",
    ";",
    "?",
    ":",
    "「",
    "」",
    "︰",
    "『",
    "』",
    "《",
    "》",
]
class Code(Enum):
    match = 1
    substitution = 2
    insertion = 3
    deletion = 4
class WordError(object):
    def __init__(self):
        self.errors = {
            Code.substitution: 0,
            Code.insertion: 0,
            Code.deletion: 0,
        }
        self.ref_words = 0
    def get_wer(self):
        assert self.ref_words != 0
        errors = (
                self.errors[Code.substitution]
                + self.errors[Code.insertion]
                + self.errors[Code.deletion]
        )
        return 100.0 * errors / self.ref_words
    def get_result_string(self):
        return (
            f"error_rate={self.get_wer():.4f}, "
            f"ref_words={self.ref_words}, "
            f"subs={self.errors[Code.substitution]}, "
            f"ins={self.errors[Code.insertion]}, "
            f"dels={self.errors[Code.deletion]}"
        )
def characterize(string):
    res = []
    i = 0
    while i < len(string):
        char = string[i]
        if char in puncts:
            i += 1
            continue
        cat1 = unicodedata.category(char)
        # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
        if cat1 == "Zs" or cat1 == "Cn" or char in spacelist:  # space or not assigned
            i += 1
            continue
        if cat1 == "Lo":  # letter-other
            res.append(char)
            i += 1
        else:
            # some input looks like: <unk><noise>, we want to separate it to two words.
            sep = " "
            if char == "<":
                sep = ">"
            j = i + 1
            while j < len(string):
                c = string[j]
                if ord(c) >= 128 or (c in spacelist) or (c == sep):
                    break
                j += 1
            if j < len(string) and string[j] == ">":
                j += 1
            res.append(string[i:j])
            i = j
    return res
def stripoff_tags(x):
    if not x:
        return ""
    chars = []
    i = 0
    T = len(x)
    while i < T:
        if x[i] == "<":
            while i < T and x[i] != ">":
                i += 1
            i += 1
        else:
            chars.append(x[i])
            i += 1
    return "".join(chars)
def normalize(sentence, ignore_words, cs, split=None):
    """sentence, ignore_words are both in unicode"""
    new_sentence = []
    for token in sentence:
        x = token
        if not cs:
            x = x.upper()
        if x in ignore_words:
            continue
        if remove_tag:
            x = stripoff_tags(x)
        if not x:
            continue
        if split and x in split:
            new_sentence += split[x]
        else:
            new_sentence.append(x)
    return new_sentence
class Calculator:
    def __init__(self):
        self.data = {}
        self.space = []
        self.cost = {}
        self.cost["cor"] = 0
        self.cost["sub"] = 1
        self.cost["del"] = 1
        self.cost["ins"] = 1
    def calculate(self, lab, rec):
        # Initialization
        lab.insert(0, "")
        rec.insert(0, "")
        while len(self.space) < len(lab):
            self.space.append([])
        for row in self.space:
            for element in row:
                element["dist"] = 0
                element["error"] = "non"
            while len(row) < len(rec):
                row.append({"dist": 0, "error": "non"})
        for i in range(len(lab)):
            self.space[i][0]["dist"] = i
            self.space[i][0]["error"] = "del"
        for j in range(len(rec)):
            self.space[0][j]["dist"] = j
            self.space[0][j]["error"] = "ins"
        self.space[0][0]["error"] = "non"
        for token in lab:
            if token not in self.data and len(token) > 0:
                self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
        for token in rec:
            if token not in self.data and len(token) > 0:
                self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
        # Computing edit distance
        for i, lab_token in enumerate(lab):
            for j, rec_token in enumerate(rec):
                if i == 0 or j == 0:
                    continue
                min_dist = sys.maxsize
                min_error = "none"
                dist = self.space[i - 1][j]["dist"] + self.cost["del"]
                error = "del"
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                dist = self.space[i][j - 1]["dist"] + self.cost["ins"]
                error = "ins"
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                if lab_token == rec_token.replace("<BIAS>", ""):
                    dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"]
                    error = "cor"
                else:
                    dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"]
                    error = "sub"
                if dist < min_dist:
                    min_dist = dist
                    min_error = error
                self.space[i][j]["dist"] = min_dist
                self.space[i][j]["error"] = min_error
        # Tracing back
        result = {
            "lab": [],
            "rec": [],
            "code": [],
            "all": 0,
            "cor": 0,
            "sub": 0,
            "ins": 0,
            "del": 0,
        }
        i = len(lab) - 1
        j = len(rec) - 1
        while True:
            if self.space[i][j]["error"] == "cor":  # correct
                if len(lab[i]) > 0:
                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
                    self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1
                    result["all"] = result["all"] + 1
                    result["cor"] = result["cor"] + 1
                result["lab"].insert(0, lab[i])
                result["rec"].insert(0, rec[j])
                result["code"].insert(0, Code.match)
                i = i - 1
                j = j - 1
            elif self.space[i][j]["error"] == "sub":  # substitution
                if len(lab[i]) > 0:
                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
                    self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1
                    result["all"] = result["all"] + 1
                    result["sub"] = result["sub"] + 1
                result["lab"].insert(0, lab[i])
                result["rec"].insert(0, rec[j])
                result["code"].insert(0, Code.substitution)
                i = i - 1
                j = j - 1
            elif self.space[i][j]["error"] == "del":  # deletion
                if len(lab[i]) > 0:
                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
                    self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1
                    result["all"] = result["all"] + 1
                    result["del"] = result["del"] + 1
                result["lab"].insert(0, lab[i])
                result["rec"].insert(0, "")
                result["code"].insert(0, Code.deletion)
                i = i - 1
            elif self.space[i][j]["error"] == "ins":  # insertion
                if len(rec[j]) > 0:
                    self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1
                    result["ins"] = result["ins"] + 1
                result["lab"].insert(0, "")
                result["rec"].insert(0, rec[j])
                result["code"].insert(0, Code.insertion)
                j = j - 1
            elif self.space[i][j]["error"] == "non":  # starting point
                break
            else:  # shouldn't reach here
                print(
                    "this should not happen , i = {i} , j = {j} , error = {error}".format(
                        i=i, j=j, error=self.space[i][j]["error"]
                    )
                )
                break  # avoid an infinite loop on an unexpected state
        return result
    def overall(self):
        result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
        for token in self.data:
            result["all"] = result["all"] + self.data[token]["all"]
            result["cor"] = result["cor"] + self.data[token]["cor"]
            result["sub"] = result["sub"] + self.data[token]["sub"]
            result["ins"] = result["ins"] + self.data[token]["ins"]
            result["del"] = result["del"] + self.data[token]["del"]
        return result
    def cluster(self, data):
        result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
        for token in data:
            if token in self.data:
                result["all"] = result["all"] + self.data[token]["all"]
                result["cor"] = result["cor"] + self.data[token]["cor"]
                result["sub"] = result["sub"] + self.data[token]["sub"]
                result["ins"] = result["ins"] + self.data[token]["ins"]
                result["del"] = result["del"] + self.data[token]["del"]
        return result
    def keys(self):
        return list(self.data.keys())
def width(string):
    return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
def default_cluster(word):
    unicode_names = [unicodedata.name(char) for char in word]
    for i in reversed(range(len(unicode_names))):
        if unicode_names[i].startswith("DIGIT"):  # 1
            unicode_names[i] = "Number"  # 'DIGIT'
        elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
            i
        ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
            # 明 / 郎
            unicode_names[i] = "Mandarin"  # 'CJK IDEOGRAPH'
        elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
            i
        ].startswith("LATIN SMALL LETTER"):
            # A / a
            unicode_names[i] = "English"  # 'LATIN LETTER'
        elif unicode_names[i].startswith("HIRAGANA LETTER"):  # は こ め
            unicode_names[i] = "Japanese"  # 'GANA LETTER'
        elif (
                unicode_names[i].startswith("AMPERSAND")
                or unicode_names[i].startswith("APOSTROPHE")
                or unicode_names[i].startswith("COMMERCIAL AT")
                or unicode_names[i].startswith("DEGREE CELSIUS")
                or unicode_names[i].startswith("EQUALS SIGN")
                or unicode_names[i].startswith("FULL STOP")
                or unicode_names[i].startswith("HYPHEN-MINUS")
                or unicode_names[i].startswith("LOW LINE")
                or unicode_names[i].startswith("NUMBER SIGN")
                or unicode_names[i].startswith("PLUS SIGN")
                or unicode_names[i].startswith("SEMICOLON")
        ):
            # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
            del unicode_names[i]
        else:
            return "Other"
    if len(unicode_names) == 0:
        return "Other"
    if len(unicode_names) == 1:
        return unicode_names[0]
    for i in range(len(unicode_names) - 1):
        if unicode_names[i] != unicode_names[i + 1]:
            return "Other"
    return unicode_names[0]
def get_args():
    parser = argparse.ArgumentParser(description="compute WER/U-WER/B-WER and hotword recall details")
    parser.add_argument("--ref", type=str, help="reference text path")
    parser.add_argument("--ref_ocr", type=str, help="reference OCR text path")
    parser.add_argument("--rec_name", type=str, action="append", default=[], help="name of a recognition system (repeatable)")
    parser.add_argument("--rec_file", type=str, action="append", default=[], help="hypothesis file of a recognition system (repeatable)")
    parser.add_argument("--verbose", type=int, default=1, help="verbosity level")
    parser.add_argument("--char", type=bool, default=True, help="split input into characters (note: any non-empty string parses as True)")
    args = parser.parse_args()
    return args
def main(args):
    cluster_file = ""
    ignore_words = set()
    tochar = args.char
    verbose = args.verbose
    padding_symbol = " "
    case_sensitive = False
    max_words_per_line = sys.maxsize
    split = None
    if not case_sensitive:
        ig = set([w.upper() for w in ignore_words])
        ignore_words = ig
    default_clusters = {}
    default_words = {}
    ref_file = args.ref
    ref_ocr = args.ref_ocr
    rec_files = args.rec_file
    rec_names = args.rec_name
    assert len(rec_files) == len(rec_names)
    # load ocr
    ref_ocr_dict = {}
    with codecs.open(ref_ocr, "r", "utf-8") as fh:
        for line in fh:
            if "$" in line:
                line = line.replace("$", " ")
            if tochar:
                array = characterize(line)
            else:
                array = line.strip().split()
            if len(array) == 0:
                continue
            fid = array[0]
            ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
    if split and not case_sensitive:
        newsplit = dict()
        for w in split:
            words = split[w]
            for i in range(len(words)):
                words[i] = words[i].upper()
            newsplit[w.upper()] = words
        split = newsplit
    rec_sets = {}
    calculators_dict = dict()
    ub_wer_dict = dict()
    hotwords_related_dict = dict()  # bookkeeping for hotword recall statistics
    for i, hyp_file in enumerate(rec_files):
        rec_sets[rec_names[i]] = dict()
        with codecs.open(hyp_file, "r", "utf-8") as fh:
            for line in fh:
                if tochar:
                    array = characterize(line)
                else:
                    array = line.strip().split()
                if len(array) == 0:
                    continue
                fid = array[0]
                rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
        calculators_dict[rec_names[i]] = Calculator()
        ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
        hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
        # tp: hotword in the label and in the rec
        # tn: hotword not in the label and not in the rec
        # fp: hotword not in the label but in the rec
        # fn: hotword in the label but not in the rec
    # record wrong label but in ocr
    wrong_rec_but_in_ocr_dict = {}
    for rec_name in rec_names:
        wrong_rec_but_in_ocr_dict[rec_name] = 0
    _file_total_len = 0
    with os.popen("cat {} | wc -l".format(ref_file)) as pipe:
        _file_total_len = int(pipe.read().strip())
    # compute error rate on the interaction of reference file and hyp file
    for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
        if tochar:
            array = characterize(line)
        else:
            array = line.rstrip('\n').split()
        if len(array) == 0: continue
        fid = array[0]
        lab = normalize(array[1:], ignore_words, case_sensitive, split)
        if verbose:
            print('\nutt: %s' % fid)
        ocr_text = ref_ocr_dict[fid]
        ocr_set = set(ocr_text)
        print('ocr: {}'.format(" ".join(ocr_text)))
        list_match = []  # label tokens that also appear in the OCR text
        list_not_match = []
        tmp_error = 0
        tmp_match = 0
        for index in range(len(lab)):
            if lab[index] not in ocr_set:
                tmp_error += 1
                list_not_match.append(lab[index])
            else:
                tmp_match += 1
                list_match.append(lab[index])
        print('label in ocr: {}'.format(" ".join(list_match)))
        # for each reco file
        base_wrong_ocr_wer = None
        ocr_wrong_ocr_wer = None
        for rec_name in rec_names:
            rec_set = rec_sets[rec_name]
            if fid not in rec_set:
                continue
            rec = rec_set[fid]
            # print(rec)
            for word in rec + lab:
                if word not in default_words:
                    default_cluster_name = default_cluster(word)
                    if default_cluster_name not in default_clusters:
                        default_clusters[default_cluster_name] = {}
                    if word not in default_clusters[default_cluster_name]:
                        default_clusters[default_cluster_name][word] = 1
                    default_words[word] = default_cluster_name
            result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
            # wer is needed by the unconditional print below, so compute it unconditionally
            if result['all'] != 0:
                wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
            else:
                wer = 0.0
            print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
            print('N=%d C=%d S=%d D=%d I=%d' %
                  (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
            # print(result['rec'])
            wrong_rec_but_in_ocr = []
            for idx in range(len(result['lab'])):
                if result['lab'][idx] != "":
                    if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
                        if result['lab'][idx] in list_match:
                            wrong_rec_but_in_ocr.append(result['lab'][idx])
                            wrong_rec_but_in_ocr_dict[rec_name] += 1
            print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
            if rec_name == "base":
                base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
            if "ocr" in rec_name or "hot" in rec_name:
                ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
                if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
                    print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
                elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
                    print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
            # tp: hotword in the label and in the rec
            # tn: hotword not in the label and not in the rec
            # fp: hotword not in the label but in the rec
            # fn: hotword in the label but not in the rec
            _rec_list = [word.replace("<BIAS>", "") for word in rec]
            _label_list = [word for word in lab]
            _tp = _tn = _fp = _fn = 0
            hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list]
            hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list]
            for badhotword in hot_bad_list:
                count = len([word for word in _rec_list if word == badhotword])
                # print(f"bad {badhotword} count: {count}")
                # for word in _rec_list:
                #     if badhotword == word:
                #         count += 1
                if count == 0:
                    hotwords_related_dict[rec_name]['tn'] += 1
                    _tn += 1
                    # fp: 0
                else:
                    hotwords_related_dict[rec_name]['fp'] += count
                    _fp += count
                    # tn: 0
            for hotword in hot_true_list:
                true_count = len([word for word in _label_list if hotword == word])
                rec_count = len([word for word in _rec_list if hotword == word])
                # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
                if rec_count == true_count:
                    hotwords_related_dict[rec_name]['tp'] += true_count
                    _tp += true_count
                elif rec_count > true_count:
                    hotwords_related_dict[rec_name]['tp'] += true_count
                    # fp: not in the label but in the rec
                    hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
                    _tp += true_count
                    _fp += rec_count - true_count
                else:
                    hotwords_related_dict[rec_name]['tp'] += rec_count
                    # fn: hotword in the label but not in the rec
                    hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
                    _tp += rec_count
                    _fn += true_count - rec_count
            print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
                _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
            ))
            # compute U-WER, B-WER, and overall WER
            for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]):
                if code == Code.match:
                    ub_wer_dict[rec_name]["wer"].ref_words += 1
                    if lab_word in hot_true_list:
                        # tmp_ref.append(ref_tokens[ref_idx])
                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
                    else:
                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
                elif code == Code.substitution:
                    ub_wer_dict[rec_name]["wer"].ref_words += 1
                    ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1
                    if lab_word in hot_true_list:
                        # tmp_ref.append(ref_tokens[ref_idx])
                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
                        ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1
                    else:
                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
                        ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1
                elif code == Code.deletion:
                    ub_wer_dict[rec_name]["wer"].ref_words += 1
                    ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1
                    if lab_word in hot_true_list:
                        # tmp_ref.append(ref_tokens[ref_idx])
                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
                        ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1
                    else:
                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
                        ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1
                elif code == Code.insertion:
                    ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1
                    if rec_word in hot_true_list:
                        ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1
                    else:
                        ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
            space = {}
            space['lab'] = []
            space['rec'] = []
            for idx in range(len(result['lab'])):
                len_lab = width(result['lab'][idx])
                len_rec = width(result['rec'][idx])
                length = max(len_lab, len_rec)
                space['lab'].append(length - len_lab)
                space['rec'].append(length - len_rec)
            upper_lab = len(result['lab'])
            upper_rec = len(result['rec'])
            lab1, rec1 = 0, 0
            while lab1 < upper_lab or rec1 < upper_rec:
                if verbose > 1:
                    print('lab(%s):' % fid, end=' ')
                else:
                    print('lab:', end=' ')
                lab2 = min(upper_lab, lab1 + max_words_per_line)
                for idx in range(lab1, lab2):
                    token = result['lab'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['lab'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                print()
                if verbose > 1:
                    print('rec(%s):' % fid, end=' ')
                else:
                    print('rec:', end=' ')
                rec2 = min(upper_rec, rec1 + max_words_per_line)
                for idx in range(rec1, rec2):
                    token = result['rec'][idx]
                    print('{token}'.format(token=token), end='')
                    for n in range(space['rec'][idx]):
                        print(padding_symbol, end='')
                    print(' ', end='')
                print()
                # print('\n', end='\n')
                lab1 = lab2
                rec1 = rec2
        print('\n', end='\n')
        # break
    if verbose:
        print('===========================================================================')
        print()
    print(wrong_rec_but_in_ocr_dict)
    for rec_name in rec_names:
        result = calculators_dict[rec_name].overall()
        if result['all'] != 0:
            wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
        else:
            wer = 0.0
        print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
        print('N=%d C=%d S=%d D=%d I=%d' %
              (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
        print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
        print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
        print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
        print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
            hotwords_related_dict[rec_name]['tp'],
            hotwords_related_dict[rec_name]['tn'],
            hotwords_related_dict[rec_name]['fp'],
            hotwords_related_dict[rec_name]['fn'],
            sum([v for k, v in hotwords_related_dict[rec_name].items()]),
            hotwords_related_dict[rec_name]['tp'] / (
                    hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
            ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
        ))
        # tp: hotword in the label and in the rec
        # tn: hotword not in the label and not in the rec
        # fp: hotword not in the label but in the rec
        # fn: hotword in the label but not in the rec
        if not verbose:
            print()
        print()
if __name__ == "__main__":
    args = get_args()
    # print("")
    print(args)
    main(args)
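For reference, a typical invocation of this script (paths are placeholders; run_bwer_recall.sh below drives it the same way, and --rec_name/--rec_file may be repeated to compare several systems):

python compute_wer_details.py --verbose 1 \
    --ref exp/token.ref \
    --ref_ocr exp/ocr.list \
    --rec_name base --rec_file exp/base/token.proc \
    --rec_name lcbnet --rec_file exp/lcbnet/token.proc \
    > exp/BWER-UWER.results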
examples/industrial_data_pretraining/lcbnet/demo.py
New file
@@ -0,0 +1,13 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
from funasr import AutoModel
model = AutoModel(model="iic/LCB-NET",
                  model_revision="v1.0.0")
res = model.generate(
    input=("https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav",
           "https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt"),
    data_type=("sound", "text"),
)
print(res)
examples/industrial_data_pretraining/lcbnet/demo.sh
New file
@@ -0,0 +1,72 @@
file_dir="/home/yf352572/.cache/modelscope/hub/iic/LCB-NET/"
CUDA_VISIBLE_DEVICES="0,1"
inference_device="cuda"
if [ ${inference_device} == "cuda" ]; then
    nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
else
    nj=1  # assumed default; the original script uses nj on this CPU path without setting it
    inference_batch_size=1
    CUDA_VISIBLE_DEVICES=""
    for JOB in $(seq ${nj}); do
        CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
    done
fi
inference_dir="outputs/slidespeech_dev"
_logdir="${inference_dir}/logdir"
echo "inference_dir: ${inference_dir}"
mkdir -p "${_logdir}"
key_file1=${file_dir}/dev/wav.scp
key_file2=${file_dir}/dev/ocr.txt
split_scps1=
split_scps2=
for JOB in $(seq "${nj}"); do
    split_scps1+=" ${_logdir}/wav.${JOB}.scp"
    split_scps2+=" ${_logdir}/ocr.${JOB}.txt"
done
utils/split_scp.pl "${key_file1}" ${split_scps1}
utils/split_scp.pl "${key_file2}" ${split_scps2}
gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
for JOB in $(seq ${nj}); do
    {
        id=$((JOB-1))
        gpuid=${gpuid_list_array[$id]}
        export CUDA_VISIBLE_DEVICES=${gpuid}
        python -m funasr.bin.inference \
        --config-path=${file_dir} \
        --config-name="config.yaml" \
        ++init_param=${file_dir}/model.pt \
        ++tokenizer_conf.token_list=${file_dir}/tokens.txt \
        ++input=[${_logdir}/wav.${JOB}.scp,${_logdir}/ocr.${JOB}.txt] \
        +data_type='["kaldi_ark", "text"]' \
        ++tokenizer_conf.bpemodel=${file_dir}/bpe.pt \
        ++normalize_conf.stats_file=${file_dir}/am.mvn \
        ++output_dir="${inference_dir}/${JOB}" \
        ++device="${inference_device}" \
        ++ncpu=1 \
        ++disable_log=true  &> ${_logdir}/log.${JOB}.txt
    }&
done
wait
mkdir -p ${inference_dir}/1best_recog
for JOB in $(seq "${nj}"); do
   cat "${inference_dir}/${JOB}/1best_recog/token" >> "${inference_dir}/1best_recog/token"
done
echo "Computing WER ..."
sed -e 's/ /\t/' -e 's/ //g' -e 's/▁/ /g' -e 's/\t /\t/'  ${inference_dir}/1best_recog/token > ${inference_dir}/1best_recog/token.proc
cp  ${file_dir}/dev/text ${inference_dir}/1best_recog/token.ref
cp  ${file_dir}/dev/ocr.list ${inference_dir}/1best_recog/ocr.list
python utils/compute_wer.py ${inference_dir}/1best_recog/token.ref ${inference_dir}/1best_recog/token.proc ${inference_dir}/1best_recog/token.cer
tail -n 3 ${inference_dir}/1best_recog/token.cer
./run_bwer_recall.sh  ${inference_dir}/1best_recog/
tail -n 6 ${inference_dir}/1best_recog/BWER-UWER.results |head -n 5
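The sed pipeline above detokenizes the BPE output: the first space becomes a tab separating the utterance id, remaining spaces are dropped, and ▁ markers become word boundaries. A rough Python equivalent, for illustration only (the function name is mine):

def detokenize(line):
    # "uttid ▁he llo ▁world" -> "uttid\thello world"
    uttid, tokens = line.rstrip("\n").split(" ", 1)
    text = tokens.replace(" ", "").replace("▁", " ").strip()
    return uttid + "\t" + text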
examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh
New file
@@ -0,0 +1,11 @@
#now_result_name=asr_conformer_acc1_lr002_warm20000/decode_asr_asr_model_valid.acc.ave
#hotword_type=ocr_1ngram_top10_hotwords_list
hot_exp_suf=$1
python compute_wer_details.py --verbose 1 \
   --ref ${hot_exp_suf}/token.ref \
   --ref_ocr ${hot_exp_suf}/ocr.list  \
   --rec_name base \
   --rec_file ${hot_exp_suf}/token.proc \
   > ${hot_exp_suf}/BWER-UWER.results
examples/industrial_data_pretraining/lcbnet/utils
New file
@@ -0,0 +1 @@
../../aishell/paraformer/utils
examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -7,10 +7,10 @@
model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                  model_revision="v2.0.4",
                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  vad_model_revision="v2.0.4",
                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  punc_model_revision="v2.0.4",
                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
                  # vad_model_revision="v2.0.4",
                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
                  # punc_model_revision="v2.0.4",
                  # spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
                  # spk_model_revision="v2.0.2",
                  )
@@ -43,4 +43,4 @@
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
res = model.generate(input=[speech], batch_size_s=300, is_final=True)
'''
'''
funasr/auto/auto_model.py
@@ -28,7 +28,7 @@
    from funasr.models.campplus.cluster_backend import ClusterBackend
except:
    print("If you want to use the speaker diarization, please `pip install hdbscan`")
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
    """
@@ -46,6 +46,7 @@
    chars = string.ascii_letters + string.digits
    if isinstance(data_in, str) and data_in.startswith('http'): # url
        data_in = download_from_url(data_in)
    if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
        _, file_extension = os.path.splitext(data_in)
        file_extension = file_extension.lower()
@@ -146,7 +147,7 @@
            kwargs = download_model(**kwargs)
        
        set_all_random_seed(kwargs.get("seed", 0))
        device = kwargs.get("device", "cuda")
        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
            device = "cpu"
@@ -168,7 +169,6 @@
            vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
        else:
            vocab_size = -1
        # build frontend
        frontend = kwargs.get("frontend", None)
        kwargs["input_size"] = None
@@ -181,7 +181,6 @@
        # build model
        model_class = tables.model_classes.get(kwargs["model"])
        model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
        model.to(device)
        
        # init_param
@@ -224,9 +223,9 @@
        batch_size = kwargs.get("batch_size", 1)
        # if kwargs.get("device", "cpu") == "cpu":
        #     batch_size = 1
        key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)
        speed_stats = {}
        asr_result_list = []
        num_samples = len(data_list)
@@ -239,6 +238,7 @@
            data_batch = data_list[beg_idx:end_idx]
            key_batch = key_list[beg_idx:end_idx]
            batch = {"data_in": data_batch, "key": key_batch}
            if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank
                batch["data_in"] = data_batch[0]
                batch["data_lengths"] = input_len
funasr/frontends/default.py
@@ -3,7 +3,6 @@
from typing import Tuple
from typing import Union
import logging
import humanfriendly
import numpy as np
import torch
import torch.nn as nn
@@ -16,8 +15,10 @@
from funasr.frontends.utils.stft import Stft
from funasr.frontends.utils.frontend import Frontend
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.register import tables
@tables.register("frontend_classes", "DefaultFrontend")
class DefaultFrontend(nn.Module):
    """Conventional frontend structure for ASR.
    Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@@ -25,7 +26,7 @@
    def __init__(
            self,
            fs: Union[int, str] = 16000,
            fs: int = 16000,
            n_fft: int = 512,
            win_length: int = None,
            hop_length: int = 128,
@@ -40,14 +41,14 @@
            frontend_conf: Optional[dict] = None,
            apply_stft: bool = True,
            use_channel: int = None,
            **kwargs,
    ):
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)
        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        self.hop_length = hop_length
        self.fs = fs
        if apply_stft:
            self.stft = Stft(
@@ -84,8 +85,12 @@
        return self.n_mels
    def forward(
            self, input: torch.Tensor, input_lengths: torch.Tensor
            self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list]
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        if isinstance(input_lengths, list):
            input_lengths = torch.tensor(input_lengths)
        if input.dtype == torch.float64:
            input = input.float()
        # 1. Domain-conversion: e.g. Stft: time -> time-freq
        if self.stft is not None:
            input_stft, feats_lens = self._compute_stft(input, input_lengths)
@@ -145,7 +150,7 @@
    def __init__(
            self,
            fs: Union[int, str] = 16000,
            fs: int = 16000,
            n_fft: int = 512,
            win_length: int = None,
            hop_length: int = None,
@@ -168,9 +173,6 @@
            mc: bool = True
    ):
        super().__init__()
        if isinstance(fs, str):
            fs = humanfriendly.parse_size(fs)
        # Deepcopy (In general, dict shouldn't be used as default arg)
        frontend_conf = copy.deepcopy(frontend_conf)
        if win_length is None and hop_length is None:
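With the forward change above, DefaultFrontend now accepts a plain Python list of lengths and casts float64 audio to float32. A minimal sketch of the relaxed interface (constructor arguments beyond fs are assumed defaults, not taken from this diff):

import torch
from funasr.frontends.default import DefaultFrontend

frontend = DefaultFrontend(fs=16000)
wavs = torch.randn(2, 16000, dtype=torch.float64)   # float64 is cast to float32 internally
feats, feats_lens = frontend(wavs, [16000, 16000])  # a list is converted to a tensor internally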
funasr/models/conformer/encoder.py
@@ -47,7 +47,7 @@
from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
from funasr.models.transformer.utils.subsampling import StreamingConvInput
from funasr.register import tables
class ConvolutionModule(nn.Module):
    """ConvolutionModule in Conformer model.
funasr/models/contextual_paraformer/model.py
@@ -29,7 +29,7 @@
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
@@ -62,7 +62,6 @@
        crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
        crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
        bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
        if bias_encoder_type == 'lstm':
            self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
@@ -103,17 +102,16 @@
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]
        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        loss_ctc, cer_ctc = None, None
        
        stats = dict()
@@ -128,12 +126,11 @@
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        
        # 2b. Attention decoder branch
        loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
            encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
        )
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att + loss_pre * self.predictor_weight
@@ -171,22 +168,24 @@
    ):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        if self.predictor_bias == 1:
            _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
            ys_pad_lens = ys_pad_lens + self.predictor_bias
        pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                     ignore_id=self.ignore_id)
        # -1. bias encoder
        if self.use_decoder_embedding:
            hw_embed = self.decoder.embed(hotword_pad)
        else:
            hw_embed = self.bias_embed(hotword_pad)
        hw_embed, (_, _) = self.bias_encoder(hw_embed)
        _ind = np.arange(0, hotword_pad.shape[0]).tolist()
        selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
        contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
        # 0. sampler
        decoder_out_1st = None
        if self.sampling_ratio > 0.0:
@@ -195,7 +194,7 @@
                                                           pre_acoustic_embeds, contextual_info)
        else:
            sematic_embeds = pre_acoustic_embeds
        # 1. Forward decoder
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
@@ -211,7 +210,7 @@
            loss_ideal = None
        '''
        loss_ideal = None
        if decoder_out_1st is None:
            decoder_out_1st = decoder_out
        # 2. Compute attention loss
@@ -288,10 +287,11 @@
                                                               enforce_sorted=False)
            _, (h_n, _) = self.bias_encoder(hw_embed)
            hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
        decoder_outs = self.decoder(
            encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
        )
        decoder_out = decoder_outs[0]
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        return decoder_out, ys_pad_lens
@@ -305,38 +305,42 @@
                 **kwargs,
                 ):
        # init beamsearch
        is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc is not None
        is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        meta_data = {}
        
        # extract fbank feats
        time1 = time.perf_counter()
        audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
        time2 = time.perf_counter()
        meta_data["load_data"] = f"{time2 - time1:0.3f}"
        speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                               frontend=frontend)
        time3 = time.perf_counter()
        meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
        meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # hotword
        self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
@@ -344,8 +348,7 @@
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
        decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
                                                                 pre_acoustic_embeds,
                                                                 pre_token_length,
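The hotword list consumed by generate_hotwords_list above is supplied at inference time. A minimal sketch based on the contextual_paraformer demo in this commit (the hotword value is an illustrative placeholder):

from funasr import AutoModel

model = AutoModel(model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
                  model_revision="v2.0.4")
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
    hotword="魔搭",  # space-separated hotwords bias decoding via the bias encoder
)
print(res)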
funasr/models/lcbnet/__init__.py
funasr/models/lcbnet/attention.py
New file
@@ -0,0 +1,112 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
# Copyright 2024 yufan
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""Multi-Head Attention Return Weight layer definition."""
import math
import torch
from torch import nn
class MultiHeadedAttentionReturnWeight(nn.Module):
    """Multi-Head Attention layer.
    Args:
        n_head (int): The number of heads.
        n_feat (int): The number of features.
        dropout_rate (float): Dropout rate.
    """
    def __init__(self, n_head, n_feat, dropout_rate):
        """Construct an MultiHeadedAttentionReturnWeight object."""
        super(MultiHeadedAttentionReturnWeight, self).__init__()
        assert n_feat % n_head == 0
        # We assume d_v always equals d_k
        self.d_k = n_feat // n_head
        self.h = n_head
        self.linear_q = nn.Linear(n_feat, n_feat)
        self.linear_k = nn.Linear(n_feat, n_feat)
        self.linear_v = nn.Linear(n_feat, n_feat)
        self.linear_out = nn.Linear(n_feat, n_feat)
        self.attn = None
        self.dropout = nn.Dropout(p=dropout_rate)
    def forward_qkv(self, query, key, value):
        """Transform query, key and value.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
        Returns:
            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
        """
        n_batch = query.size(0)
        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
        return q, k, v
    def forward_attention(self, value, scores, mask):
        """Compute attention context vector.
        Args:
            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
        Returns:
            torch.Tensor: Transformed value (#batch, time1, d_model)
                weighted by the attention score (#batch, time1, time2).
        """
        n_batch = value.size(0)
        if mask is not None:
            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
            min_value = torch.finfo(scores.dtype).min
            scores = scores.masked_fill(mask, min_value)
            self.attn = torch.softmax(scores, dim=-1).masked_fill(
                mask, 0.0
            )  # (batch, head, time1, time2)
        else:
            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
        p_attn = self.dropout(self.attn)
        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
        x = (
            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
        )  # (batch, time1, d_model)
        return self.linear_out(x), self.attn  # (batch, time1, d_model), (batch, head, time1, time2)
    def forward(self, query, key, value, mask):
        """Compute scaled dot product attention.
        Args:
            query (torch.Tensor): Query tensor (#batch, time1, size).
            key (torch.Tensor): Key tensor (#batch, time2, size).
            value (torch.Tensor): Value tensor (#batch, time2, size).
            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
                (#batch, time1, time2).
        Returns:
            torch.Tensor: Output tensor (#batch, time1, d_model).
            torch.Tensor: Attention weight tensor (#batch, n_head, time1, time2).
        """
        q, k, v = self.forward_qkv(query, key, value)
        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
        return self.forward_attention(v, scores, mask)
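A minimal smoke test (mine, not part of the commit) showing that, unlike the standard MultiHeadedAttention, this variant returns the attention weights alongside the output:

import torch
from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight

attn = MultiHeadedAttentionReturnWeight(n_head=4, n_feat=256, dropout_rate=0.0)
q = torch.randn(2, 10, 256)   # (batch, time1, feat)
kv = torch.randn(2, 20, 256)  # (batch, time2, feat)
out, weights = attn(q, kv, kv, mask=None)
print(out.shape, weights.shape)  # (2, 10, 256) and (2, 4, 10, 20)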
funasr/models/lcbnet/encoder.py
New file
@@ -0,0 +1,392 @@
# Copyright 2019 Shigeki Karita
#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
"""Transformer encoder definition."""
from typing import List
from typing import Optional
from typing import Tuple
import torch
from torch import nn
import logging
from funasr.models.transformer.attention import MultiHeadedAttention
from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight
from funasr.models.transformer.embedding import PositionalEncoding
from funasr.models.transformer.layer_norm import LayerNorm
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
from funasr.models.transformer.utils.repeat import repeat
from funasr.register import tables
class EncoderLayer(nn.Module):
    """Encoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
            can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
        stochastic_depth_rate (float): Probability to skip this layer.
            During training, the layer may skip residual computation and return input
            as-is with given probability.
    """
    def __init__(
            self,
            size,
            self_attn,
            feed_forward,
            dropout_rate,
            normalize_before=True,
            concat_after=False,
            stochastic_depth_rate=0.0,
    ):
        """Construct an EncoderLayer object."""
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.size = size
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear = nn.Linear(size + size, size)
        self.stochastic_depth_rate = stochastic_depth_rate
    def forward(self, x, mask, cache=None):
        """Compute encoded features.
        Args:
            x (torch.Tensor): Input tensor (#batch, time, size).
            mask (torch.Tensor): Mask tensor for the input (#batch, time).
            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
        Returns:
            torch.Tensor: Output tensor (#batch, time, size).
            torch.Tensor: Mask tensor (#batch, time).
        """
        skip_layer = False
        # with stochastic depth, residual connection `x + f(x)` becomes
        # `x <- x + 1 / (1 - p) * f(x)` at training time.
        stoch_layer_coeff = 1.0
        if self.training and self.stochastic_depth_rate > 0:
            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
        if skip_layer:
            if cache is not None:
                x = torch.cat([cache, x], dim=1)
            return x, mask
        residual = x
        if self.normalize_before:
            x = self.norm1(x)
        if cache is None:
            x_q = x
        else:
            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
            x_q = x[:, -1:, :]
            residual = residual[:, -1:, :]
            mask = None if mask is None else mask[:, -1:, :]
        if self.concat_after:
            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
        else:
            x = residual + stoch_layer_coeff * self.dropout(
                self.self_attn(x_q, x, x, mask)
            )
        if not self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm2(x)
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        return x, mask
@tables.register("encoder_classes", "TransformerTextEncoder")
class TransformerTextEncoder(nn.Module):
    """Transformer text encoder module.
    Args:
        input_size: input dim (vocabulary size of the embedding layer)
        output_size: dimension of attention
        attention_heads: the number of heads of multi head attention
        linear_units: the number of units of position-wise feed forward
        num_blocks: the number of encoder blocks
        dropout_rate: dropout rate
        attention_dropout_rate: dropout rate in attention
        positional_dropout_rate: dropout rate after adding positional encoding
        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
        normalize_before: whether to use layer_norm before the first block
        concat_after: whether to concat attention layer's input and output
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied.
            i.e. x -> x + att(x)
    """
    def __init__(
            self,
            input_size: int,
            output_size: int = 256,
            attention_heads: int = 4,
            linear_units: int = 2048,
            num_blocks: int = 6,
            dropout_rate: float = 0.1,
            positional_dropout_rate: float = 0.1,
            attention_dropout_rate: float = 0.0,
            pos_enc_class=PositionalEncoding,
            normalize_before: bool = True,
            concat_after: bool = False,
    ):
        super().__init__()
        self._output_size = output_size
        self.embed = torch.nn.Sequential(
            torch.nn.Embedding(input_size, output_size),
            pos_enc_class(output_size, positional_dropout_rate),
        )
        self.normalize_before = normalize_before
        positionwise_layer = PositionwiseFeedForward
        positionwise_layer_args = (
            output_size,
            linear_units,
            dropout_rate,
        )
        self.encoders = repeat(
            num_blocks,
            lambda lnum: EncoderLayer(
                output_size,
                MultiHeadedAttention(
                    attention_heads, output_size, attention_dropout_rate
                ),
                positionwise_layer(*positionwise_layer_args),
                dropout_rate,
                normalize_before,
                concat_after,
            ),
        )
        if self.normalize_before:
            self.after_norm = LayerNorm(output_size)
    def output_size(self) -> int:
        return self._output_size
    def forward(
            self,
            xs_pad: torch.Tensor,
            ilens: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
        """Embed positions in tensor.
        Args:
            xs_pad: input tensor (B, L, D)
            ilens: input length (B)
        Returns:
            position embedded tensor and mask
        """
        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
        xs_pad = self.embed(xs_pad)
        xs_pad, masks = self.encoders(xs_pad, masks)
        if self.normalize_before:
            xs_pad = self.after_norm(xs_pad)
        olens = masks.squeeze(1).sum(1)
        return xs_pad, olens, None
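A usage sketch for the text branch (the vocabulary size and token ids are illustrative):

import torch

text_encoder = TransformerTextEncoder(input_size=8404, output_size=256, num_blocks=6)
tokens = torch.randint(1, 8404, (2, 30))       # (B, L) OCR/keyword token ids
lengths = torch.tensor([30, 25])               # (B,)
enc, olens, _ = text_encoder(tokens, lengths)  # enc: (2, 30, 256), olens: (2,)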
@tables.register("encoder_classes", "FusionSANEncoder")
class SelfSrcAttention(nn.Module):
    """Single decoder layer module.
    Args:
        size (int): Input dimension.
        self_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        src_attn (torch.nn.Module): Self-attention module instance.
            `MultiHeadedAttention` instance can be used as the argument.
        feed_forward (torch.nn.Module): Feed-forward module instance.
            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
            can be used as the argument.
        dropout_rate (float): Dropout rate.
        normalize_before (bool): Whether to use layer_norm before the first block.
        concat_after (bool): Whether to concat attention layer's input and output.
            if True, additional linear will be applied.
            i.e. x -> x + linear(concat(x, att(x)))
            if False, no additional linear will be applied. i.e. x -> x + att(x)
    """
    def __init__(
            self,
            size,
            attention_heads,
            attention_dim,
            linear_units,
            self_attention_dropout_rate,
            src_attention_dropout_rate,
            positional_dropout_rate,
            dropout_rate,
            normalize_before=True,
            concat_after=False,
    ):
        """Construct an SelfSrcAttention object."""
        super(SelfSrcAttention, self).__init__()
        self.size = size
        self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate)
        self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate)
        self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate)
        self.norm1 = LayerNorm(size)
        self.norm2 = LayerNorm(size)
        self.norm3 = LayerNorm(size)
        self.dropout = nn.Dropout(dropout_rate)
        self.normalize_before = normalize_before
        self.concat_after = concat_after
        if self.concat_after:
            self.concat_linear1 = nn.Linear(size + size, size)
            self.concat_linear2 = nn.Linear(size + size, size)
    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
        """Compute decoded features.
        Args:
            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
            cache (List[torch.Tensor]): List of cached tensors.
                Each tensor shape should be (#batch, maxlen_out - 1, size).
        Returns:
            torch.Tensor: Output tensor(#batch, maxlen_out, size).
            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
        """
        residual = tgt
        if self.normalize_before:
            tgt = self.norm1(tgt)
        if cache is None:
            tgt_q = tgt
            tgt_q_mask = tgt_mask
        else:
            # compute only the last frame query keeping dim: max_time_out -> 1
            assert cache.shape == (
                tgt.shape[0],
                tgt.shape[1] - 1,
                self.size,
            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
            tgt_q = tgt[:, -1:, :]
            residual = residual[:, -1:, :]
            tgt_q_mask = None
            if tgt_mask is not None:
                tgt_q_mask = tgt_mask[:, -1:, :]
        if self.concat_after:
            tgt_concat = torch.cat(
                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
            )
            x = residual + self.concat_linear1(tgt_concat)
        else:
            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
        if not self.normalize_before:
            x = self.norm1(x)
        residual = x
        if self.normalize_before:
            x = self.norm2(x)
        if self.concat_after:
            x_concat = torch.cat(
                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
            )
            x = residual + self.concat_linear2(x_concat)
        else:
            x, score = self.src_attn(x, memory, memory, memory_mask)
            x = residual + self.dropout(x)
        if not self.normalize_before:
            x = self.norm2(x)
        residual = x
        if self.normalize_before:
            x = self.norm3(x)
        x = residual + self.dropout(self.feed_forward(x))
        if not self.normalize_before:
            x = self.norm3(x)
        if cache is not None:
            x = torch.cat([cache, x], dim=1)
        return x, tgt_mask, memory, memory_mask
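A sketch of how this fusion layer is driven in LCBNet.inference below: the acoustic encoding queries the text-encoder memory, and the result is added back residually (shapes are illustrative):

import torch

fusion = SelfSrcAttention(
    size=256, attention_heads=4, attention_dim=256, linear_units=2048,
    self_attention_dropout_rate=0.0, src_attention_dropout_rate=0.0,
    positional_dropout_rate=0.1, dropout_rate=0.1,
)
audio = torch.randn(1, 100, 256)  # acoustic encoder output (B, T, D)
text = torch.randn(1, 30, 256)    # text encoder output (B, L, D)
fused, _, _, _ = fusion(audio, None, text, None)
biased = audio + fused            # residual fusion, as in LCBNet.inference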
@tables.register("encoder_classes", "ConvBiasPredictor")
class ConvPredictor(nn.Module):
    def __init__(
            self, size=256, l_order=3, r_order=3,
            attention_heads=4, attention_dropout_rate=0.1, linear_units=2048,
    ):
        super().__init__()
        self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
        self.norm1 = LayerNorm(size)
        self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate)
        self.norm2 = LayerNorm(size)
        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
        self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size)
        self.output_linear = nn.Linear(size, 1)
    def forward(self, text_enc, asr_enc):
        # stage1 cross-attention
        residual = text_enc
        text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None)
        # stage2 FFN
        residual = text_enc
        text_enc = self.norm1(text_enc)
        text_enc = residual + self.feed_forward(text_enc)
        # stage Conv predictor
        text_enc = self.norm2(text_enc)
        context = text_enc.transpose(1, 2)
        queries = self.pad(context)
        memory = self.conv1d(queries)
        output = memory + context
        output = output.transpose(1, 2)
        output = torch.relu(output)
        output = self.output_linear(output)
        if output.dim() == 3:
            output = output.squeeze(2)
        return output
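The bias predictor attends each text token to the acoustic encoding, refines it with an FFN and a depthwise convolution, and emits one relevance score per token; a usage sketch:

import torch

predictor = ConvPredictor(size=256)
text_enc = torch.randn(1, 30, 256)     # text encoder output (B, L, D)
asr_enc = torch.randn(1, 100, 256)     # acoustic encoder output (B, T, D)
scores = predictor(text_enc, asr_enc)  # (B, L): one score per text token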
funasr/models/lcbnet/model.py
New file
@@ -0,0 +1,495 @@
#!/usr/bin/env python3
# -*- encoding: utf-8 -*-
# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
#  MIT License  (https://opensource.org/licenses/MIT)
import logging
from typing import Union, Dict, List, Tuple, Optional
import time
import torch
import torch.nn as nn
from torch.cuda.amp import autocast
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.ctc.ctc import CTC
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.metrics.compute_acc import th_accuracy
# from funasr.models.e2e_asr_common import ErrorCalculator
from funasr.train_utils.device_funcs import force_gatherable
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils import postprocess_utils
from funasr.utils.datadir_writer import DatadirWriter
from funasr.register import tables
@tables.register("model_classes", "LCBNet")
class LCBNet(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    LCB-NET: LONG-CONTEXT BIASING FOR AUDIO-VISUAL SPEECH RECOGNITION
    https://arxiv.org/abs/2401.06390
    """
    def __init__(
        self,
        specaug: str = None,
        specaug_conf: dict = None,
        normalize: str = None,
        normalize_conf: dict = None,
        encoder: str = None,
        encoder_conf: dict = None,
        decoder: str = None,
        decoder_conf: dict = None,
        text_encoder: str = None,
        text_encoder_conf: dict = None,
        bias_predictor: str = None,
        bias_predictor_conf: dict = None,
        fusion_encoder: str = None,
        fusion_encoder_conf: dict = None,
        ctc: str = None,
        ctc_conf: dict = None,
        ctc_weight: float = 0.5,
        interctc_weight: float = 0.0,
        select_num: int = 2,
        select_length: int = 3,
        insert_blank: bool = True,
        input_size: int = 80,
        vocab_size: int = -1,
        ignore_id: int = -1,
        blank_id: int = 0,
        sos: int = 1,
        eos: int = 2,
        lsm_weight: float = 0.0,
        length_normalized_loss: bool = False,
        report_cer: bool = True,
        report_wer: bool = True,
        sym_space: str = "<space>",
        sym_blank: str = "<blank>",
        # extract_feats_in_collect_stats: bool = True,
        share_embedding: bool = False,
        # preencoder: Optional[AbsPreEncoder] = None,
        # postencoder: Optional[AbsPostEncoder] = None,
        **kwargs,
    ):
        super().__init__()
        if specaug is not None:
            specaug_class = tables.specaug_classes.get(specaug)
            specaug = specaug_class(**specaug_conf)
        if normalize is not None:
            normalize_class = tables.normalize_classes.get(normalize)
            normalize = normalize_class(**normalize_conf)
        encoder_class = tables.encoder_classes.get(encoder)
        encoder = encoder_class(input_size=input_size, **encoder_conf)
        encoder_output_size = encoder.output_size()
        # lcbnet modules: text encoder, fusion encoder and bias predictor
        text_encoder_class = tables.encoder_classes.get(text_encoder)
        text_encoder = text_encoder_class(input_size=vocab_size, **text_encoder_conf)
        fusion_encoder_class = tables.encoder_classes.get(fusion_encoder)
        fusion_encoder = fusion_encoder_class(**fusion_encoder_conf)
        bias_predictor_class = tables.encoder_classes.get(bias_predictor)
        bias_predictor = bias_predictor_class(**bias_predictor_conf)
        if decoder is not None:
            decoder_class = tables.decoder_classes.get(decoder)
            decoder = decoder_class(
                vocab_size=vocab_size,
                encoder_output_size=encoder_output_size,
                **decoder_conf,
            )
        if ctc_weight > 0.0:
            if ctc_conf is None:
                ctc_conf = {}
            ctc = CTC(
                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
            )
        self.blank_id = blank_id
        self.sos = vocab_size - 1
        self.eos = vocab_size - 1
        self.vocab_size = vocab_size
        self.ignore_id = ignore_id
        self.ctc_weight = ctc_weight
        self.specaug = specaug
        self.normalize = normalize
        self.encoder = encoder
        # lcbnet
        self.text_encoder = text_encoder
        self.fusion_encoder = fusion_encoder
        self.bias_predictor = bias_predictor
        self.select_num = select_num
        self.select_length = select_length
        self.insert_blank = insert_blank
        if not hasattr(self.encoder, "interctc_use_conditioning"):
            self.encoder.interctc_use_conditioning = False
        if self.encoder.interctc_use_conditioning:
            self.encoder.conditioning_layer = torch.nn.Linear(
                vocab_size, self.encoder.output_size()
            )
        self.interctc_weight = interctc_weight
        # self.error_calculator = None
        if ctc_weight == 1.0:
            self.decoder = None
        else:
            self.decoder = decoder
        self.criterion_att = LabelSmoothingLoss(
            size=vocab_size,
            padding_idx=ignore_id,
            smoothing=lsm_weight,
            normalize_length=length_normalized_loss,
        )
        #
        # if report_cer or report_wer:
        #     self.error_calculator = ErrorCalculator(
        #         token_list, sym_space, sym_blank, report_cer, report_wer
        #     )
        #
        self.error_calculator = None
        if ctc_weight == 0.0:
            self.ctc = None
        else:
            self.ctc = ctc
        self.share_embedding = share_embedding
        if self.share_embedding:
            self.decoder.embed = None
        self.length_normalized_loss = length_normalized_loss
        self.beam_search = None
    def forward(
        self,
        speech: torch.Tensor,
        speech_lengths: torch.Tensor,
        text: torch.Tensor,
        text_lengths: torch.Tensor,
        **kwargs,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
                text: (Batch, Length)
                text_lengths: (Batch,)
        """
        if len(text_lengths.size()) > 1:
            text_lengths = text_lengths[:, 0]
        if len(speech_lengths.size()) > 1:
            speech_lengths = speech_lengths[:, 0]
        batch_size = speech.shape[0]
        # 1. Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        loss_att, acc_att, cer_att, wer_att = None, None, None, None
        loss_ctc, cer_ctc = None, None
        stats = dict()
        # decoder: CTC branch
        if self.ctc_weight != 0.0:
            loss_ctc, cer_ctc = self._calc_ctc_loss(
                encoder_out, encoder_out_lens, text, text_lengths
            )
            # Collect CTC branch stats
            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
            stats["cer_ctc"] = cer_ctc
        # Intermediate CTC (optional)
        loss_interctc = 0.0
        if self.interctc_weight != 0.0 and intermediate_outs is not None:
            for layer_idx, intermediate_out in intermediate_outs:
                # we assume intermediate_out has the same length & padding
                # as those of encoder_out
                loss_ic, cer_ic = self._calc_ctc_loss(
                    intermediate_out, encoder_out_lens, text, text_lengths
                )
                loss_interctc = loss_interctc + loss_ic
                # Collect intermediate CTC stats
                stats["loss_interctc_layer{}".format(layer_idx)] = (
                    loss_ic.detach() if loss_ic is not None else None
                )
                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
            loss_interctc = loss_interctc / len(intermediate_outs)
            # calculate whole encoder loss
            loss_ctc = (
                1 - self.interctc_weight
            ) * loss_ctc + self.interctc_weight * loss_interctc
        # decoder: Attention decoder branch
        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
            encoder_out, encoder_out_lens, text, text_lengths
        )
        # 3. CTC-Att loss definition
        if self.ctc_weight == 0.0:
            loss = loss_att
        elif self.ctc_weight == 1.0:
            loss = loss_ctc
        else:
            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
        # Collect Attn branch stats
        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
        stats["acc"] = acc_att
        stats["cer"] = cer_att
        stats["wer"] = wer_att
        # Collect total loss stats
        stats["loss"] = torch.clone(loss.detach())
        # force_gatherable: to-device and to-tensor if scalar for DataParallel
        if self.length_normalized_loss:
            batch_size = int((text_lengths + 1).sum())
        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
        return loss, stats, weight
    def encode(
        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
        """
        with autocast(False):
            # Data augmentation
            if self.specaug is not None and self.training:
                speech, speech_lengths = self.specaug(speech, speech_lengths)
            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
            if self.normalize is not None:
                speech, speech_lengths = self.normalize(speech, speech_lengths)
        # Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
        if self.encoder.interctc_use_conditioning:
            encoder_out, encoder_out_lens, _ = self.encoder(
                speech, speech_lengths, ctc=self.ctc
            )
        else:
            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
        intermediate_outs = None
        if isinstance(encoder_out, tuple):
            intermediate_outs = encoder_out[1]
            encoder_out = encoder_out[0]
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
        return encoder_out, encoder_out_lens
    def _calc_att_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
        ys_in_lens = ys_pad_lens + 1
        # 1. Forward decoder
        decoder_out, _ = self.decoder(
            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
        )
        # 2. Compute attention loss
        loss_att = self.criterion_att(decoder_out, ys_out_pad)
        acc_att = th_accuracy(
            decoder_out.view(-1, self.vocab_size),
            ys_out_pad,
            ignore_label=self.ignore_id,
        )
        # Compute cer/wer using attention-decoder
        if self.training or self.error_calculator is None:
            cer_att, wer_att = None, None
        else:
            ys_hat = decoder_out.argmax(dim=-1)
            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
        return loss_att, acc_att, cer_att, wer_att
    def _calc_ctc_loss(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        ys_pad: torch.Tensor,
        ys_pad_lens: torch.Tensor,
    ):
        # Calc CTC loss
        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
        # Calc CER using CTC
        cer_ctc = None
        if not self.training and self.error_calculator is not None:
            ys_hat = self.ctc.argmax(encoder_out).data
            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
        return loss_ctc, cer_ctc
    def init_beam_search(self,
                         **kwargs,
                         ):
        from funasr.models.transformer.search import BeamSearch
        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
        from funasr.models.transformer.scorers.length_bonus import LengthBonus
        # 1. Build ASR model
        scorers = {}
        if self.ctc is not None:
            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
            scorers.update(ctc=ctc)
        token_list = kwargs.get("token_list")
        scorers.update(
            decoder=self.decoder,
            length_bonus=LengthBonus(len(token_list)),
        )
        # 3. Build ngram model
        # ngram is not supported now
        ngram = None
        scorers["ngram"] = ngram
        weights = dict(
            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.3),
            ctc=kwargs.get("decoding_ctc_weight", 0.3),
            lm=kwargs.get("lm_weight", 0.0),
            ngram=kwargs.get("ngram_weight", 0.0),
            length_bonus=kwargs.get("penalty", 0.0),
        )
        beam_search = BeamSearch(
            beam_size=kwargs.get("beam_size", 20),
            weights=weights,
            scorers=scorers,
            sos=self.sos,
            eos=self.eos,
            vocab_size=len(token_list),
            token_list=token_list,
            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
        )
        self.beam_search = beam_search
    def inference(self,
             data_in,
             data_lengths=None,
             key: list=None,
             tokenizer=None,
             frontend=None,
             **kwargs,
             ):
        if kwargs.get("batch_size", 1) > 1:
            raise NotImplementedError("batch decoding is not implemented")
        # init beamsearch
        if self.beam_search is None:
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        meta_data = {}
        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
            speech, speech_lengths = data_in, data_lengths
            if len(speech.shape) < 3:
                speech = speech[None, :, :]
            if speech_lengths is None:
                speech_lengths = speech.shape[1]
        else:
            # extract fbank feats
            time1 = time.perf_counter()
            sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
                                                      data_type=kwargs.get("data_type", "sound"),
                                                      tokenizer=tokenizer)
            time2 = time.perf_counter()
            meta_data["load_data"] = f"{time2 - time1:0.3f}"
            audio_sample_list = sample_list[0]
            if len(sample_list) > 1:
                ocr_sample_list = sample_list[1]
            else:
                ocr_sample_list = [[294, 0]]
            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                   frontend=frontend)
            time3 = time.perf_counter()
            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
            frame_shift = 10
            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift / 1000
        speech = speech.to(device=kwargs["device"])
        speech_lengths = speech_lengths.to(device=kwargs["device"])
        # Encoder
        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        ocr_list_new = [[x + 1 if x != 0 else x for x in sublist] for sublist in ocr_sample_list]
        ocr = torch.tensor(ocr_list_new).to(device=kwargs["device"])
        ocr_lengths = ocr.new_full([1], dtype=torch.long, fill_value=ocr.size(1)).to(device=kwargs["device"])
        ocr, ocr_lens, _ = self.text_encoder(ocr, ocr_lengths)
        fusion_out, _, _, _ = self.fusion_encoder(encoder_out, None, ocr, None)
        encoder_out = encoder_out + fusion_out
        # c. Pass the encoder result to the beam search
        nbest_hyps = self.beam_search(
            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
        )
        nbest_hyps = nbest_hyps[: self.nbest]
        results = []
        b, n, d = encoder_out.size()
        for i in range(b):
            for nbest_idx, hyp in enumerate(nbest_hyps):
                ibest_writer = None
                if kwargs.get("output_dir") is not None:
                    if not hasattr(self, "writer"):
                        self.writer = DatadirWriter(kwargs.get("output_dir"))
                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
                # remove sos/eos and get results
                last_pos = -1
                if isinstance(hyp.yseq, list):
                    token_int = hyp.yseq[1:last_pos]
                else:
                    token_int = hyp.yseq[1:last_pos].tolist()
                # remove blank symbol id, which is assumed to be 0
                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
                # Change integer-ids to tokens
                token = tokenizer.ids2tokens(token_int)
                text = tokenizer.tokens2text(token)
                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
                results.append(result_i)
                if ibest_writer is not None:
                    ibest_writer["token"][key[i]] = " ".join(token)
                    ibest_writer["text"][key[i]] = text_postprocessed
        return results, meta_data
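An end-to-end inference sketch; the model id is a placeholder, and the paired audio + OCR-text input layout is inferred from the multi-modal list handling in load_audio_text_image_video below:

from funasr import AutoModel

# hypothetical model id; substitute the released LCBNet checkpoint
model = AutoModel(model="damo/speech_lcbnet_asr-en-16k-vocab5000")
res = model.generate(
    input=[["/path/to/audio.wav", "/path/to/ocr_text.txt"]],  # one (audio, biasing text) pair
    data_type=["sound", "text"],
)
print(res)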
funasr/models/seaco_paraformer/model.py
@@ -30,7 +30,7 @@
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
import pdb
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
else:
@@ -128,7 +128,7 @@
        hotword_pad = kwargs.get("hotword_pad")
        hotword_lengths = kwargs.get("hotword_lengths")
        dha_pad = kwargs.get("dha_pad")
        batch_size = speech.shape[0]
        # for data-parallel
        text = text[:, : text_lengths.max()]
@@ -209,17 +209,20 @@
                               nfilter=50,
                               seaco_weight=1.0):
        # decoder forward
        decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
        decoder_pred = torch.log_softmax(decoder_out, dim=-1)
        if hw_list is not None:
            hw_lengths = [len(i) for i in hw_list]
            hw_list_ = [torch.Tensor(i).long() for i in hw_list]
            hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device)
            selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device))
            contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
            num_hot_word = contextual_info.shape[1]
            _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
            # ASF Core
            if nfilter > 0 and nfilter < num_hot_word:
                hotword_scores = self.seaco_decoder.forward_asf6(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
@@ -239,7 +242,7 @@
            cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
            dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
            merged = self._merge(cif_attended, dec_attended)
            dha_output = self.hotword_output_layer(merged)  # remove the last token in loss calculation
            dha_pred = torch.log_softmax(dha_output, dim=-1)
            def _merge_res(dec_output, dha_output):
@@ -253,8 +256,8 @@
                # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask)
                logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
                return logits
            merged_pred = _merge_res(decoder_pred, dha_pred)
            # import pdb; pdb.set_trace()
            return merged_pred
        else:
            return decoder_pred
@@ -304,7 +307,6 @@
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
            self.nbest = kwargs.get("nbest", 1)
        meta_data = {}
        
        # extract fbank feats
@@ -330,6 +332,7 @@
        if isinstance(encoder_out, tuple):
            encoder_out = encoder_out[0]
        
        # predictor
        predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
        pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
@@ -338,15 +341,14 @@
        if torch.max(pre_token_length) < 1:
            return []
        decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
                                                   pre_acoustic_embeds,
                                                   pre_token_length,
                                                   hw_list=self.hotword_list)
        # decoder_out, _ = decoder_outs[0], decoder_outs[1]
        _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
                                                                  pre_token_length)
        results = []
        b, n, d = decoder_out.size()
        for i in range(b):
funasr/train_utils/load_pretrained_model.py
@@ -7,7 +7,7 @@
import torch
import torch.nn
import torch.optim
import pdb
def filter_state_dict(
    dst_state: Dict[str, Union[float, torch.Tensor]],
@@ -63,6 +63,7 @@
    dst_state = obj.state_dict()
    
    print(f"ckpt: {path}")
    if oss_bucket is None:
        src_state = torch.load(path, map_location=map_location)
    else:
funasr/utils/load_utils.py
@@ -13,29 +13,25 @@
    from funasr.download.file import download_from_url
except:
    print("urllib is not installed, if you infer from url, please install it first.")
import pdb
def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
    if isinstance(data_or_path_or_list, (list, tuple)):
        if data_type is not None and isinstance(data_type, (list, tuple)):
            data_types = [data_type] * len(data_or_path_or_list)
            data_or_path_or_list_ret = [[] for d in data_type]
            for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
                for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
                    data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
                    data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
            return data_or_path_or_list_ret
        else:
            return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
    if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
        data_or_path_or_list = download_from_url(data_or_path_or_list)
    if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
        if data_type is None or data_type == "sound":
            data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
@@ -56,10 +52,22 @@
        data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
    elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
        data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
    elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark":
        data_mat = kaldiio.load_mat(data_or_path_or_list)
        if isinstance(data_mat, tuple):
            audio_fs, mat = data_mat
        else:
            mat = data_mat
        if mat.dtype == 'int16' or mat.dtype == 'int32':
            mat = mat.astype(np.float64)
            mat = mat / 32768
        if mat.ndim == 2:
            mat = mat[:, 0]
        data_or_path_or_list = mat
    else:
        pass
        # print(f"unsupport data type: {data_or_path_or_list}, return raw data")
    if audio_fs != fs and data_type != "text":
        resampler = torchaudio.transforms.Resample(audio_fs, fs)
        data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
@@ -81,8 +89,6 @@
    return array
def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
    # import pdb;
    # pdb.set_trace()
    if isinstance(data, np.ndarray):
        data = torch.from_numpy(data)
        if len(data.shape) < 2:
@@ -100,9 +106,7 @@
            data_list.append(data_i)
            data_len.append(data_i.shape[0])
        data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
    # import pdb;
    # pdb.set_trace()
    # if data_type == "sound":
    data, data_len = frontend(data, data_len, **kwargs)
    
    if isinstance(data_len, (list, tuple)):