From 920331972a136834a560d78917de60f6c6623d96 Mon Sep 17 00:00:00 2001
From: 语帆 <yf352572@alibaba-inc.com>
Date: 星期一, 04 三月 2024 17:47:25 +0800
Subject: [PATCH] commit

---
 examples/industrial_data_pretraining/lcbnet/compute_wer_details.py  |  702 +++++++++++++++++++++
 funasr/auto/auto_model.py                                           |   12 
 funasr/models/contextual_paraformer/model.py                        |   39 
 examples/industrial_data_pretraining/contextual_paraformer/path.sh  |    6 
 examples/industrial_data_pretraining/contextual_paraformer/demo.py  |    0 
 funasr/models/lcbnet/model.py                                       |  495 +++++++++++++++
 funasr/utils/load_utils.py                                          |   28 
 examples/industrial_data_pretraining/lcbnet/demo.py                 |   13 
 funasr/models/lcbnet/attention.py                                   |  112 +++
 examples/industrial_data_pretraining/lcbnet/utils                   |    1 
 funasr/frontends/default.py                                         |   20 
 funasr/train_utils/load_pretrained_model.py                         |    3 
 funasr/models/lcbnet/__init__.py                                    |    0 
 examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh      |   11 
 funasr/models/seaco_paraformer/model.py                             |   18 
 examples/industrial_data_pretraining/contextual_paraformer/demo2.sh |    9 
 funasr/models/conformer/encoder.py                                  |    2 
 funasr/models/lcbnet/encoder.py                                     |  392 +++++++++++
 examples/industrial_data_pretraining/seaco_paraformer/demo.py       |   10 
 examples/industrial_data_pretraining/contextual_paraformer/demo.sh  |    2 
 examples/industrial_data_pretraining/lcbnet/demo.sh                 |   72 ++
 21 files changed, 1,886 insertions(+), 61 deletions(-)

diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.py b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
old mode 100644
new mode 100755
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.sh b/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
old mode 100644
new mode 100755
index 8fc66f3..1bd4f7f
--- a/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
+++ b/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
@@ -2,7 +2,7 @@
 model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
 model_revision="v2.0.4"
 
-python funasr/bin/inference.py \
+python ../../../funasr/bin/inference.py \
 +model=${model} \
 +model_revision=${model_revision} \
 +input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh b/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh
new file mode 100755
index 0000000..282f4f1
--- /dev/null
+++ b/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh
@@ -0,0 +1,9 @@
+python -m funasr.bin.inference \
+--config-path="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
+--config-name="config.yaml" \
+++init_param="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/model.pb" \
+++tokenizer_conf.token_list="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/tokens.txt" \
+++frontend_conf.cmvn_file="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/am.mvn" \
+++input="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/asr_example_zh.wav" \
+++output_dir="./outputs/debug2" \
+++device="" \
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/path.sh b/examples/industrial_data_pretraining/contextual_paraformer/path.sh
new file mode 100755
index 0000000..1a6d67e
--- /dev/null
+++ b/examples/industrial_data_pretraining/contextual_paraformer/path.sh
@@ -0,0 +1,6 @@
+export FUNASR_DIR=$PWD/../../../
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
+export PYTHONPATH=$FUNASR_DIR/funasr/bin:$FUNASR_DIR/funasr:$FUNASR_DIR:$PYTHONPATH
diff --git a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
new file mode 100755
index 0000000..e72d871
--- /dev/null
+++ b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
@@ -0,0 +1,702 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+
+from enum import Enum
+import re, sys, unicodedata
+import codecs
+import argparse
+from tqdm import tqdm
+import os
+import pdb
+remove_tag = False
+spacelist = [" ", "\t", "\r", "\n"]
+puncts = [
+    "!",
+    ",",
+    "?",
+    "銆�",
+    "銆�",
+    "锛�",
+    "锛�",
+    "锛�",
+    "锛�",
+    "锛�",
+    "銆�",
+    "銆�",
+    "锔�",
+    "銆�",
+    "銆�",
+    "銆�",
+    "銆�",
+]
+
+
+class Code(Enum):
+    match = 1
+    substitution = 2
+    insertion = 3
+    deletion = 4
+
+
+class WordError(object):
+    def __init__(self):
+        self.errors = {
+            Code.substitution: 0,
+            Code.insertion: 0,
+            Code.deletion: 0,
+        }
+        self.ref_words = 0
+
+    def get_wer(self):
+        assert self.ref_words != 0
+        errors = (
+                self.errors[Code.substitution]
+                + self.errors[Code.insertion]
+                + self.errors[Code.deletion]
+        )
+        return 100.0 * errors / self.ref_words
+
+    def get_result_string(self):
+        return (
+            f"error_rate={self.get_wer():.4f}, "
+            f"ref_words={self.ref_words}, "
+            f"subs={self.errors[Code.substitution]}, "
+            f"ins={self.errors[Code.insertion]}, "
+            f"dels={self.errors[Code.deletion]}"
+        )
+
+
+def characterize(string):
+    res = []
+    i = 0
+    while i < len(string):
+        char = string[i]
+        if char in puncts:
+            i += 1
+            continue
+        cat1 = unicodedata.category(char)
+        # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
+        if cat1 == "Zs" or cat1 == "Cn" or char in spacelist:  # space or not assigned
+            i += 1
+            continue
+        if cat1 == "Lo":  # letter-other
+            res.append(char)
+            i += 1
+        else:
+            # some input looks like: <unk><noise>, we want to separate it to two words.
+            sep = " "
+            if char == "<":
+                sep = ">"
+            j = i + 1
+            while j < len(string):
+                c = string[j]
+                if ord(c) >= 128 or (c in spacelist) or (c == sep):
+                    break
+                j += 1
+            if j < len(string) and string[j] == ">":
+                j += 1
+            res.append(string[i:j])
+            i = j
+    return res
+
+
+def stripoff_tags(x):
+    if not x:
+        return ""
+    chars = []
+    i = 0
+    T = len(x)
+    while i < T:
+        if x[i] == "<":
+            while i < T and x[i] != ">":
+                i += 1
+            i += 1
+        else:
+            chars.append(x[i])
+            i += 1
+    return "".join(chars)
+
+
+def normalize(sentence, ignore_words, cs, split=None):
+    """sentence, ignore_words are both in unicode"""
+    new_sentence = []
+    for token in sentence:
+        x = token
+        if not cs:
+            x = x.upper()
+        if x in ignore_words:
+            continue
+        if remove_tag:
+            x = stripoff_tags(x)
+        if not x:
+            continue
+        if split and x in split:
+            new_sentence += split[x]
+        else:
+            new_sentence.append(x)
+    return new_sentence
+
+
+class Calculator:
+    def __init__(self):
+        self.data = {}
+        self.space = []
+        self.cost = {}
+        self.cost["cor"] = 0
+        self.cost["sub"] = 1
+        self.cost["del"] = 1
+        self.cost["ins"] = 1
+
+    def calculate(self, lab, rec):
+        # Initialization
+        lab.insert(0, "")
+        rec.insert(0, "")
+        while len(self.space) < len(lab):
+            self.space.append([])
+        for row in self.space:
+            for element in row:
+                element["dist"] = 0
+                element["error"] = "non"
+            while len(row) < len(rec):
+                row.append({"dist": 0, "error": "non"})
+        for i in range(len(lab)):
+            self.space[i][0]["dist"] = i
+            self.space[i][0]["error"] = "del"
+        for j in range(len(rec)):
+            self.space[0][j]["dist"] = j
+            self.space[0][j]["error"] = "ins"
+        self.space[0][0]["error"] = "non"
+        for token in lab:
+            if token not in self.data and len(token) > 0:
+                self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
+        for token in rec:
+            if token not in self.data and len(token) > 0:
+                self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
+        # Computing edit distance
+        for i, lab_token in enumerate(lab):
+            for j, rec_token in enumerate(rec):
+                if i == 0 or j == 0:
+                    continue
+                min_dist = sys.maxsize
+                min_error = "none"
+                dist = self.space[i - 1][j]["dist"] + self.cost["del"]
+                error = "del"
+                if dist < min_dist:
+                    min_dist = dist
+                    min_error = error
+                dist = self.space[i][j - 1]["dist"] + self.cost["ins"]
+                error = "ins"
+                if dist < min_dist:
+                    min_dist = dist
+                    min_error = error
+                if lab_token == rec_token.replace("<BIAS>", ""):
+                    dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"]
+                    error = "cor"
+                else:
+                    dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"]
+                    error = "sub"
+                if dist < min_dist:
+                    min_dist = dist
+                    min_error = error
+                self.space[i][j]["dist"] = min_dist
+                self.space[i][j]["error"] = min_error
+        # Tracing back
+        result = {
+            "lab": [],
+            "rec": [],
+            "code": [],
+            "all": 0,
+            "cor": 0,
+            "sub": 0,
+            "ins": 0,
+            "del": 0,
+        }
+        i = len(lab) - 1
+        j = len(rec) - 1
+        while True:
+            if self.space[i][j]["error"] == "cor":  # correct
+                if len(lab[i]) > 0:
+                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
+                    self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1
+                    result["all"] = result["all"] + 1
+                    result["cor"] = result["cor"] + 1
+                result["lab"].insert(0, lab[i])
+                result["rec"].insert(0, rec[j])
+                result["code"].insert(0, Code.match)
+                i = i - 1
+                j = j - 1
+            elif self.space[i][j]["error"] == "sub":  # substitution
+                if len(lab[i]) > 0:
+                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
+                    self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1
+                    result["all"] = result["all"] + 1
+                    result["sub"] = result["sub"] + 1
+                result["lab"].insert(0, lab[i])
+                result["rec"].insert(0, rec[j])
+                result["code"].insert(0, Code.substitution)
+                i = i - 1
+                j = j - 1
+            elif self.space[i][j]["error"] == "del":  # deletion
+                if len(lab[i]) > 0:
+                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
+                    self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1
+                    result["all"] = result["all"] + 1
+                    result["del"] = result["del"] + 1
+                result["lab"].insert(0, lab[i])
+                result["rec"].insert(0, "")
+                result["code"].insert(0, Code.deletion)
+                i = i - 1
+            elif self.space[i][j]["error"] == "ins":  # insertion
+                if len(rec[j]) > 0:
+                    self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1
+                    result["ins"] = result["ins"] + 1
+                result["lab"].insert(0, "")
+                result["rec"].insert(0, rec[j])
+                result["code"].insert(0, Code.insertion)
+                j = j - 1
+            elif self.space[i][j]["error"] == "non":  # starting point
+                break
+            else:  # shouldn't reach here
+                print(
+                    "this should not happen , i = {i} , j = {j} , error = {error}".format(
+                        i=i, j=j, error=self.space[i][j]["error"]
+                    )
+                )
+        return result
+
+    def overall(self):
+        result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
+        for token in self.data:
+            result["all"] = result["all"] + self.data[token]["all"]
+            result["cor"] = result["cor"] + self.data[token]["cor"]
+            result["sub"] = result["sub"] + self.data[token]["sub"]
+            result["ins"] = result["ins"] + self.data[token]["ins"]
+            result["del"] = result["del"] + self.data[token]["del"]
+        return result
+
+    def cluster(self, data):
+        result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
+        for token in data:
+            if token in self.data:
+                result["all"] = result["all"] + self.data[token]["all"]
+                result["cor"] = result["cor"] + self.data[token]["cor"]
+                result["sub"] = result["sub"] + self.data[token]["sub"]
+                result["ins"] = result["ins"] + self.data[token]["ins"]
+                result["del"] = result["del"] + self.data[token]["del"]
+        return result
+
+    def keys(self):
+        return list(self.data.keys())
+
+
+def width(string):
+    return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
+
+
+def default_cluster(word):
+    unicode_names = [unicodedata.name(char) for char in word]
+    for i in reversed(range(len(unicode_names))):
+        if unicode_names[i].startswith("DIGIT"):  # 1
+            unicode_names[i] = "Number"  # 'DIGIT'
+        elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
+            i
+        ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
+            # 鏄� / 铯�
+            unicode_names[i] = "Mandarin"  # 'CJK IDEOGRAPH'
+        elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
+            i
+        ].startswith("LATIN SMALL LETTER"):
+            # A / a
+            unicode_names[i] = "English"  # 'LATIN LETTER'
+        elif unicode_names[i].startswith("HIRAGANA LETTER"):  # 銇� 銇� 銈�
+            unicode_names[i] = "Japanese"  # 'GANA LETTER'
+        elif (
+                unicode_names[i].startswith("AMPERSAND")
+                or unicode_names[i].startswith("APOSTROPHE")
+                or unicode_names[i].startswith("COMMERCIAL AT")
+                or unicode_names[i].startswith("DEGREE CELSIUS")
+                or unicode_names[i].startswith("EQUALS SIGN")
+                or unicode_names[i].startswith("FULL STOP")
+                or unicode_names[i].startswith("HYPHEN-MINUS")
+                or unicode_names[i].startswith("LOW LINE")
+                or unicode_names[i].startswith("NUMBER SIGN")
+                or unicode_names[i].startswith("PLUS SIGN")
+                or unicode_names[i].startswith("SEMICOLON")
+        ):
+            # & / ' / @ / 鈩� / = / . / - / _ / # / + / ;
+            del unicode_names[i]
+        else:
+            return "Other"
+    if len(unicode_names) == 0:
+        return "Other"
+    if len(unicode_names) == 1:
+        return unicode_names[0]
+    for i in range(len(unicode_names) - 1):
+        if unicode_names[i] != unicode_names[i + 1]:
+            return "Other"
+    return unicode_names[0]
+
+
+def get_args():
+    parser = argparse.ArgumentParser(description="wer cal")
+    parser.add_argument("--ref", type=str, help="Text input path")
+    parser.add_argument("--ref_ocr", type=str, help="Text input path")
+    parser.add_argument("--rec_name", type=str, action="append", default=[])
+    parser.add_argument("--rec_file", type=str, action="append", default=[])
+    parser.add_argument("--verbose", type=int, default=1, help="show")
+    parser.add_argument("--char", type=bool, default=True, help="show")
+    args = parser.parse_args()
+    return args
+
+
+def main(args):
+    cluster_file = ""
+    ignore_words = set()
+    tochar = args.char
+    verbose = args.verbose
+    padding_symbol = " "
+    case_sensitive = False
+    max_words_per_line = sys.maxsize
+    split = None
+
+    if not case_sensitive:
+        ig = set([w.upper() for w in ignore_words])
+        ignore_words = ig
+
+    default_clusters = {}
+    default_words = {}
+    ref_file = args.ref
+    ref_ocr = args.ref_ocr
+    rec_files = args.rec_file
+    rec_names = args.rec_name
+    assert len(rec_files) == len(rec_names)
+
+    # load ocr
+    ref_ocr_dict = {}
+    with codecs.open(ref_ocr, "r", "utf-8") as fh:
+        for line in fh:
+            if "$" in line:
+                line = line.replace("$", " ")
+            if tochar:
+                array = characterize(line)
+            else:
+                array = line.strip().split()
+            if len(array) == 0:
+                continue
+            fid = array[0]
+            ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
+
+    if split and not case_sensitive:
+        newsplit = dict()
+        for w in split:
+            words = split[w]
+            for i in range(len(words)):
+                words[i] = words[i].upper()
+            newsplit[w.upper()] = words
+        split = newsplit
+
+    rec_sets = {}
+    calculators_dict = dict()
+    ub_wer_dict = dict()
+    hotwords_related_dict = dict()  # 璁板綍recall鐩稿叧鐨勫唴瀹�
+    for i, hyp_file in enumerate(rec_files):
+        rec_sets[rec_names[i]] = dict()
+        with codecs.open(hyp_file, "r", "utf-8") as fh:
+            for line in fh:
+                if tochar:
+                    array = characterize(line)
+                else:
+                    array = line.strip().split()
+                if len(array) == 0:
+                    continue
+                fid = array[0]
+                rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
+
+        calculators_dict[rec_names[i]] = Calculator()
+        ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
+        hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
+        # tp: 鐑瘝鍦╨abel閲岋紝鍚屾椂鍦╮ec閲�
+        # tn: 鐑瘝涓嶅湪label閲岋紝鍚屾椂涓嶅湪rec閲�
+        # fp: 鐑瘝涓嶅湪label閲岋紝浣嗘槸鍦╮ec閲�
+        # fn: 鐑瘝鍦╨abel閲岋紝浣嗘槸涓嶅湪rec閲�
+
+    # record wrong label but in ocr
+    wrong_rec_but_in_ocr_dict = {}
+    for rec_name in rec_names:
+        wrong_rec_but_in_ocr_dict[rec_name] = 0
+
+    _file_total_len = 0
+    with os.popen("cat {} | wc -l".format(ref_file)) as pipe:
+        _file_total_len = int(pipe.read().strip())
+
+    # compute error rate on the interaction of reference file and hyp file
+    for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
+        if tochar:
+            array = characterize(line)
+        else:
+            array = line.rstrip('\n').split()
+        if len(array) == 0: continue
+        fid = array[0]
+        lab = normalize(array[1:], ignore_words, case_sensitive, split)
+
+        if verbose:
+            print('\nutt: %s' % fid)
+
+        ocr_text = ref_ocr_dict[fid]
+        ocr_set = set(ocr_text)
+        print('ocr: {}'.format(" ".join(ocr_text)))
+        list_match = []  # 鎸噇abel閲岄潰鍦╫cr閲岄潰鐨勫唴瀹�
+        list_not_mathch = []
+        tmp_error = 0
+        tmp_match = 0
+        for index in range(len(lab)):
+            # text_list.append(uttlist[index+1])
+            if lab[index] not in ocr_set:
+                tmp_error += 1
+                list_not_mathch.append(lab[index])
+            else:
+                tmp_match += 1
+                list_match.append(lab[index])
+        print('label in ocr: {}'.format(" ".join(list_match)))
+
+        # for each reco file
+        base_wrong_ocr_wer = None
+        ocr_wrong_ocr_wer = None
+
+        for rec_name in rec_names:
+            rec_set = rec_sets[rec_name]
+            if fid not in rec_set:
+                continue
+            rec = rec_set[fid]
+
+            # print(rec)
+            for word in rec + lab:
+                if word not in default_words:
+                    default_cluster_name = default_cluster(word)
+                    if default_cluster_name not in default_clusters:
+                        default_clusters[default_cluster_name] = {}
+                    if word not in default_clusters[default_cluster_name]:
+                        default_clusters[default_cluster_name][word] = 1
+                    default_words[word] = default_cluster_name
+
+            result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
+            if verbose:
+                if result['all'] != 0:
+                    wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+                else:
+                    wer = 0.0
+            print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
+            print('N=%d C=%d S=%d D=%d I=%d' %
+                  (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+
+
+            # print(result['rec'])
+            wrong_rec_but_in_ocr = []
+            for idx in range(len(result['lab'])):
+                if result['lab'][idx] != "":
+                    if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
+                        if result['lab'][idx] in list_match:
+                            wrong_rec_but_in_ocr.append(result['lab'][idx])
+                            wrong_rec_but_in_ocr_dict[rec_name] += 1
+            print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
+
+            if rec_name == "base":
+                base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
+            if "ocr" in rec_name or "hot" in rec_name:
+                ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
+                if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
+                    print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
+                elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
+                    print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
+
+            # recall = 0
+            # false_alarm = 0
+            # for idx in range(len(result['lab'])):
+            #     if "<BIAS>" in result['rec'][idx]:
+            #         if result['rec'][idx].replace("<BIAS>", "") in list_match:
+            #             recall += 1
+            #         else:
+            #             false_alarm += 1
+            # print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format(
+            #     recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0
+            # ))
+            # tp: 鐑瘝鍦╨abel閲岋紝鍚屾椂鍦╮ec閲�
+            # tn: 鐑瘝涓嶅湪label閲岋紝鍚屾椂涓嶅湪rec閲�
+            # fp: 鐑瘝涓嶅湪label閲岋紝浣嗘槸鍦╮ec閲�
+            # fn: 鐑瘝鍦╨abel閲岋紝浣嗘槸涓嶅湪rec閲�
+            _rec_list = [word.replace("<BIAS>", "") for word in rec]
+            _label_list = [word for word in lab]
+            _tp = _tn = _fp = _fn = 0
+            hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list]
+            hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list]
+            for badhotword in hot_bad_list:
+                count = len([word for word in _rec_list if word == badhotword])
+                # print(f"bad {badhotword} count: {count}")
+                # for word in _rec_list:
+                #     if badhotword == word:
+                #         count += 1
+                if count == 0:
+                    hotwords_related_dict[rec_name]['tn'] += 1
+                    _tn += 1
+                    # fp: 0
+                else:
+                    hotwords_related_dict[rec_name]['fp'] += count
+                    _fp += count
+                    # tn: 0
+                # if badhotword in _rec_list:
+                #     hotwords_related_dict[rec_name]['fp'] += 1
+                # else:
+                #     hotwords_related_dict[rec_name]['tn'] += 1
+            for hotword in hot_true_list:
+                true_count = len([word for word in _label_list if hotword == word])
+                rec_count = len([word for word in _rec_list if hotword == word])
+                # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
+                if rec_count == true_count:
+                    hotwords_related_dict[rec_name]['tp'] += true_count
+                    _tp += true_count
+                elif rec_count > true_count:
+                    hotwords_related_dict[rec_name]['tp'] += true_count
+                    # fp: 涓嶅湪label閲岋紝浣嗘槸鍦╮ec閲�
+                    hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
+                    _tp += true_count
+                    _fp += rec_count - true_count
+                else:
+                    hotwords_related_dict[rec_name]['tp'] += rec_count
+                    # fn: 鐑瘝鍦╨abel閲岋紝浣嗘槸涓嶅湪rec閲�
+                    hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
+                    _tp += rec_count
+                    _fn += true_count - rec_count
+            print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
+                _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
+            ))
+
+            # if hotword in _rec_list:
+            #     hotwords_related_dict[rec_name]['tp'] += 1
+            # else:
+            #     hotwords_related_dict[rec_name]['fn'] += 1
+            # 璁$畻uwer, bwer, wer
+            for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]):
+                if code == Code.match:
+                    ub_wer_dict[rec_name]["wer"].ref_words += 1
+                    if lab_word in hot_true_list:
+                        # tmp_ref.append(ref_tokens[ref_idx])
+                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
+                    else:
+                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
+                elif code == Code.substitution:
+                    ub_wer_dict[rec_name]["wer"].ref_words += 1
+                    ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1
+                    if lab_word in hot_true_list:
+                        # tmp_ref.append(ref_tokens[ref_idx])
+                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
+                        ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1
+                    else:
+                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
+                        ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1
+                elif code == Code.deletion:
+                    ub_wer_dict[rec_name]["wer"].ref_words += 1
+                    ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1
+                    if lab_word in hot_true_list:
+                        # tmp_ref.append(ref_tokens[ref_idx])
+                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
+                        ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1
+                    else:
+                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
+                        ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1
+                elif code == Code.insertion:
+                    ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1
+                    if rec_word in hot_true_list:
+                        ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1
+                    else:
+                        ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
+
+            space = {}
+            space['lab'] = []
+            space['rec'] = []
+            for idx in range(len(result['lab'])):
+                len_lab = width(result['lab'][idx])
+                len_rec = width(result['rec'][idx])
+                length = max(len_lab, len_rec)
+                space['lab'].append(length - len_lab)
+                space['rec'].append(length - len_rec)
+            upper_lab = len(result['lab'])
+            upper_rec = len(result['rec'])
+            lab1, rec1 = 0, 0
+            while lab1 < upper_lab or rec1 < upper_rec:
+                if verbose > 1:
+                    print('lab(%s):' % fid.encode('utf-8'), end=' ')
+                else:
+                    print('lab:', end=' ')
+                lab2 = min(upper_lab, lab1 + max_words_per_line)
+                for idx in range(lab1, lab2):
+                    token = result['lab'][idx]
+                    print('{token}'.format(token=token), end='')
+                    for n in range(space['lab'][idx]):
+                        print(padding_symbol, end='')
+                    print(' ', end='')
+                print()
+                if verbose > 1:
+                    print('rec(%s):' % fid.encode('utf-8'), end=' ')
+                else:
+                    print('rec:', end=' ')
+
+                rec2 = min(upper_rec, rec1 + max_words_per_line)
+                for idx in range(rec1, rec2):
+                    token = result['rec'][idx]
+                    print('{token}'.format(token=token), end='')
+                    for n in range(space['rec'][idx]):
+                        print(padding_symbol, end='')
+                    print(' ', end='')
+                print()
+                # print('\n', end='\n')
+                lab1 = lab2
+                rec1 = rec2
+        print('\n', end='\n')
+        # break
+    if verbose:
+        print('===========================================================================')
+        print()
+
+    print(wrong_rec_but_in_ocr_dict)
+    for rec_name in rec_names:
+        result = calculators_dict[rec_name].overall()
+
+        if result['all'] != 0:
+            wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+        else:
+            wer = 0.0
+        print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
+        print('N=%d C=%d S=%d D=%d I=%d' %
+              (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+        print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
+        print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
+        print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
+
+        print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
+            hotwords_related_dict[rec_name]['tp'],
+            hotwords_related_dict[rec_name]['tn'],
+            hotwords_related_dict[rec_name]['fp'],
+            hotwords_related_dict[rec_name]['fn'],
+            sum([v for k, v in hotwords_related_dict[rec_name].items()]),
+            hotwords_related_dict[rec_name]['tp'] / (
+                    hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
+            ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
+        ))
+
+        # tp: 鐑瘝鍦╨abel閲岋紝鍚屾椂鍦╮ec閲�
+        # tn: 鐑瘝涓嶅湪label閲岋紝鍚屾椂涓嶅湪rec閲�
+        # fp: 鐑瘝涓嶅湪label閲岋紝浣嗘槸鍦╮ec閲�
+        # fn: 鐑瘝鍦╨abel閲岋紝浣嗘槸涓嶅湪rec閲�
+        if not verbose:
+            print()
+        print()
+
+
+if __name__ == "__main__":
+    args = get_args()
+    
+    # print("")
+    print(args)
+    main(args)
+
diff --git a/examples/industrial_data_pretraining/lcbnet/demo.py b/examples/industrial_data_pretraining/lcbnet/demo.py
new file mode 100755
index 0000000..4ca5255
--- /dev/null
+++ b/examples/industrial_data_pretraining/lcbnet/demo.py
@@ -0,0 +1,13 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+model = AutoModel(model="iic/LCB-NET",
+                  model_revision="v1.0.0")
+
+res = model.generate(input=("https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav","https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt"),data_type=("sound", "text"))
+
+print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/lcbnet/demo.sh b/examples/industrial_data_pretraining/lcbnet/demo.sh
new file mode 100755
index 0000000..2f226bc
--- /dev/null
+++ b/examples/industrial_data_pretraining/lcbnet/demo.sh
@@ -0,0 +1,72 @@
+file_dir="/home/yf352572/.cache/modelscope/hub/iic/LCB-NET/"
+CUDA_VISIBLE_DEVICES="0,1"
+inference_device="cuda"
+
+if [ ${inference_device} == "cuda" ]; then
+    nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+else
+    inference_batch_size=1
+    CUDA_VISIBLE_DEVICES=""
+    for JOB in $(seq ${nj}); do
+        CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+    done
+fi
+
+inference_dir="outputs/slidespeech_dev"
+_logdir="${inference_dir}/logdir"
+echo "inference_dir: ${inference_dir}"
+
+mkdir -p "${_logdir}"
+key_file1=${file_dir}/dev/wav.scp
+key_file2=${file_dir}/dev/ocr.txt
+split_scps1=
+split_scps2=
+for JOB in $(seq "${nj}"); do
+    split_scps1+=" ${_logdir}/wav.${JOB}.scp"
+    split_scps2+=" ${_logdir}/ocr.${JOB}.txt"
+done
+utils/split_scp.pl "${key_file1}" ${split_scps1}
+utils/split_scp.pl "${key_file2}" ${split_scps2}
+
+gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+for JOB in $(seq ${nj}); do
+    {
+        id=$((JOB-1))
+        gpuid=${gpuid_list_array[$id]}
+
+        export CUDA_VISIBLE_DEVICES=${gpuid}
+
+        python -m funasr.bin.inference \
+        --config-path=${file_dir} \
+        --config-name="config.yaml" \
+        ++init_param=${file_dir}/model.pt \
+        ++tokenizer_conf.token_list=${file_dir}/tokens.txt \
+        ++input=[${_logdir}/wav.${JOB}.scp,${_logdir}/ocr.${JOB}.txt] \
+        +data_type='["kaldi_ark", "text"]' \
+        ++tokenizer_conf.bpemodel=${file_dir}/bpe.pt \
+        ++normalize_conf.stats_file=${file_dir}/am.mvn \
+        ++output_dir="${inference_dir}/${JOB}" \
+        ++device="${inference_device}" \
+        ++ncpu=1 \
+        ++disable_log=true  &> ${_logdir}/log.${JOB}.txt
+
+    }&
+done
+wait
+
+
+mkdir -p ${inference_dir}/1best_recog
+
+for JOB in $(seq "${nj}"); do
+   cat "${inference_dir}/${JOB}/1best_recog/token" >> "${inference_dir}/1best_recog/token"
+done  
+
+echo "Computing WER ..."
+sed -e 's/ /\t/' -e 's/ //g' -e 's/鈻�/ /g' -e 's/\t /\t/'  ${inference_dir}/1best_recog/token > ${inference_dir}/1best_recog/token.proc
+cp  ${file_dir}/dev/text ${inference_dir}/1best_recog/token.ref
+cp  ${file_dir}/dev/ocr.list ${inference_dir}/1best_recog/ocr.list
+python utils/compute_wer.py ${inference_dir}/1best_recog/token.ref ${inference_dir}/1best_recog/token.proc ${inference_dir}/1best_recog/token.cer
+tail -n 3 ${inference_dir}/1best_recog/token.cer
+
+./run_bwer_recall.sh  ${inference_dir}/1best_recog/
+tail -n 6 ${inference_dir}/1best_recog/BWER-UWER.results |head -n 5
diff --git a/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh b/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh
new file mode 100755
index 0000000..7d6b6ff
--- /dev/null
+++ b/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh
@@ -0,0 +1,11 @@
+#now_result_name=asr_conformer_acc1_lr002_warm20000/decode_asr_asr_model_valid.acc.ave
+#hotword_type=ocr_1ngram_top10_hotwords_list
+hot_exp_suf=$1
+
+
+python compute_wer_details.py --v 1 \
+   --ref ${hot_exp_suf}/token.ref \
+   --ref_ocr ${hot_exp_suf}/ocr.list  \
+   --rec_name base \
+   --rec_file ${hot_exp_suf}/token.proc \
+   > ${hot_exp_suf}/BWER-UWER.results
diff --git a/examples/industrial_data_pretraining/lcbnet/utils b/examples/industrial_data_pretraining/lcbnet/utils
new file mode 120000
index 0000000..be5e5a3
--- /dev/null
+++ b/examples/industrial_data_pretraining/lcbnet/utils
@@ -0,0 +1 @@
+../../aishell/paraformer/utils
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index a44c649..551dd8b 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -7,10 +7,10 @@
 
 model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4",
-                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  vad_model_revision="v2.0.4",
-                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  punc_model_revision="v2.0.4",
+                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  # vad_model_revision="v2.0.4",
+                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  # punc_model_revision="v2.0.4",
                   # spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
                   # spk_model_revision="v2.0.2",
                   )
@@ -43,4 +43,4 @@
 wav_file = os.path.join(model.model_path, "example/asr_example.wav")
 speech, sample_rate = soundfile.read(wav_file)
 res = model.generate(input=[speech], batch_size_s=300, is_final=True)
-'''
\ No newline at end of file
+'''
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 921ede8..ec3c3f3 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -28,7 +28,7 @@
     from funasr.models.campplus.cluster_backend import ClusterBackend
 except:
     print("If you want to use the speaker diarization, please `pip install hdbscan`")
-
+import pdb
 
 def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
     """
@@ -46,6 +46,7 @@
     chars = string.ascii_letters + string.digits
     if isinstance(data_in, str) and data_in.startswith('http'): # url
         data_in = download_from_url(data_in)
+
     if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
         _, file_extension = os.path.splitext(data_in)
         file_extension = file_extension.lower()
@@ -146,7 +147,7 @@
             kwargs = download_model(**kwargs)
         
         set_all_random_seed(kwargs.get("seed", 0))
-        
+
         device = kwargs.get("device", "cuda")
         if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
             device = "cpu"
@@ -168,7 +169,6 @@
             vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
         else:
             vocab_size = -1
-        
         # build frontend
         frontend = kwargs.get("frontend", None)
         kwargs["input_size"] = None
@@ -181,7 +181,6 @@
         # build model
         model_class = tables.model_classes.get(kwargs["model"])
         model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
-        
         model.to(device)
         
         # init_param
@@ -224,9 +223,9 @@
         batch_size = kwargs.get("batch_size", 1)
         # if kwargs.get("device", "cpu") == "cpu":
         #     batch_size = 1
-        
+
         key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)
-        
+
         speed_stats = {}
         asr_result_list = []
         num_samples = len(data_list)
@@ -239,6 +238,7 @@
             data_batch = data_list[beg_idx:end_idx]
             key_batch = key_list[beg_idx:end_idx]
             batch = {"data_in": data_batch, "key": key_batch}
+
             if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank
                 batch["data_in"] = data_batch[0]
                 batch["data_lengths"] = input_len
diff --git a/funasr/frontends/default.py b/funasr/frontends/default.py
index 8ac1ca8..c4bdbd7 100644
--- a/funasr/frontends/default.py
+++ b/funasr/frontends/default.py
@@ -3,7 +3,6 @@
 from typing import Tuple
 from typing import Union
 import logging
-import humanfriendly
 import numpy as np
 import torch
 import torch.nn as nn
@@ -16,8 +15,10 @@
 from funasr.frontends.utils.stft import Stft
 from funasr.frontends.utils.frontend import Frontend
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.register import tables
 
 
+@tables.register("frontend_classes", "DefaultFrontend")
 class DefaultFrontend(nn.Module):
     """Conventional frontend structure for ASR.
     Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@@ -25,7 +26,7 @@
 
     def __init__(
             self,
-            fs: Union[int, str] = 16000,
+            fs: int = 16000,
             n_fft: int = 512,
             win_length: int = None,
             hop_length: int = 128,
@@ -40,14 +41,14 @@
             frontend_conf: Optional[dict] = None,
             apply_stft: bool = True,
             use_channel: int = None,
+            **kwargs,
     ):
         super().__init__()
-        if isinstance(fs, str):
-            fs = humanfriendly.parse_size(fs)
 
         # Deepcopy (In general, dict shouldn't be used as default arg)
         frontend_conf = copy.deepcopy(frontend_conf)
         self.hop_length = hop_length
+        self.fs = fs
 
         if apply_stft:
             self.stft = Stft(
@@ -84,8 +85,12 @@
         return self.n_mels
 
     def forward(
-            self, input: torch.Tensor, input_lengths: torch.Tensor
+            self, input: torch.Tensor, input_lengths:  Union[torch.Tensor, list]
     ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if isinstance(input_lengths, list):
+            input_lengths = torch.tensor(input_lengths)
+        if  input.dtype == torch.float64:
+            input = input.float()
         # 1. Domain-conversion: e.g. Stft: time -> time-freq
         if self.stft is not None:
             input_stft, feats_lens = self._compute_stft(input, input_lengths)
@@ -145,7 +150,7 @@
 
     def __init__(
             self,
-            fs: Union[int, str] = 16000,
+            fs: int = 16000,
             n_fft: int = 512,
             win_length: int = None,
             hop_length: int = None,
@@ -168,9 +173,6 @@
             mc: bool = True
     ):
         super().__init__()
-        if isinstance(fs, str):
-            fs = humanfriendly.parse_size(fs)
-
         # Deepcopy (In general, dict shouldn't be used as default arg)
         frontend_conf = copy.deepcopy(frontend_conf)
         if win_length is None and hop_length is None:
diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py
index 1d252c2..be973c6 100644
--- a/funasr/models/conformer/encoder.py
+++ b/funasr/models/conformer/encoder.py
@@ -47,7 +47,7 @@
 from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
 from funasr.models.transformer.utils.subsampling import StreamingConvInput
 from funasr.register import tables
-
+import pdb
 
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 49868a8..7d6f729 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -29,7 +29,7 @@
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
+import pdb
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -62,7 +62,6 @@
         crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
         crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
         bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
-
 
         if bias_encoder_type == 'lstm':
             self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
@@ -103,17 +102,16 @@
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
             speech_lengths = speech_lengths[:, 0]
-        
+
         batch_size = speech.shape[0]
 
         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
         dha_pad = kwargs.get("dha_pad")
-        
+
         # 1. Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
-        
         loss_ctc, cer_ctc = None, None
         
         stats = dict()
@@ -128,12 +126,11 @@
             stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
             stats["cer_ctc"] = cer_ctc
         
-
         # 2b. Attention decoder branch
         loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
             encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
         )
-        
+
         # 3. CTC-Att loss definition
         if self.ctc_weight == 0.0:
             loss = loss_att + loss_pre * self.predictor_weight
@@ -171,22 +168,24 @@
     ):
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
+
         if self.predictor_bias == 1:
             _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
             ys_pad_lens = ys_pad_lens + self.predictor_bias
+
         pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                      ignore_id=self.ignore_id)
-        
         # -1. bias encoder
         if self.use_decoder_embedding:
             hw_embed = self.decoder.embed(hotword_pad)
         else:
             hw_embed = self.bias_embed(hotword_pad)
+
         hw_embed, (_, _) = self.bias_encoder(hw_embed)
         _ind = np.arange(0, hotword_pad.shape[0]).tolist()
         selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
         contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
-        
+
         # 0. sampler
         decoder_out_1st = None
         if self.sampling_ratio > 0.0:
@@ -195,7 +194,7 @@
                                                            pre_acoustic_embeds, contextual_info)
         else:
             sematic_embeds = pre_acoustic_embeds
-        
+
         # 1. Forward decoder
         decoder_outs = self.decoder(
             encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
@@ -211,7 +210,7 @@
             loss_ideal = None
         '''
         loss_ideal = None
-        
+
         if decoder_out_1st is None:
             decoder_out_1st = decoder_out
         # 2. Compute attention loss
@@ -288,10 +287,11 @@
                                                                enforce_sorted=False)
             _, (h_n, _) = self.bias_encoder(hw_embed)
             hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
-        
+
         decoder_outs = self.decoder(
             encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
         )
+
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens
@@ -305,38 +305,42 @@
                  **kwargs,
                  ):
         # init beamsearch
+
         is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
         is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
         if self.beam_search is None and (is_use_lm or is_use_ctc):
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)
-        
+
         meta_data = {}
         
         # extract fbank feats
         time1 = time.perf_counter()
+
         audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+
         time2 = time.perf_counter()
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
+
         speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                frontend=frontend)
         time3 = time.perf_counter()
         meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
         meta_data[
             "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-        
+
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
 
         # hotword
         self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
-        
+
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
-        
+
         # predictor
         predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
@@ -344,8 +348,7 @@
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
-
-
+        
         decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
                                                                  pre_acoustic_embeds,
                                                                  pre_token_length,
diff --git a/funasr/models/lcbnet/__init__.py b/funasr/models/lcbnet/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/lcbnet/__init__.py
diff --git a/funasr/models/lcbnet/attention.py b/funasr/models/lcbnet/attention.py
new file mode 100644
index 0000000..8e8c594
--- /dev/null
+++ b/funasr/models/lcbnet/attention.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2024 yufan
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention Return Weight layer definition."""
+
+import math
+
+import torch
+from torch import nn
+
+class MultiHeadedAttentionReturnWeight(nn.Module):
+    """Multi-Head Attention layer.
+
+    Args:
+        n_head (int): The number of heads.
+        n_feat (int): The number of features.
+        dropout_rate (float): Dropout rate.
+
+    """
+
+    def __init__(self, n_head, n_feat, dropout_rate):
+        """Construct an MultiHeadedAttentionReturnWeight object."""
+        super(MultiHeadedAttentionReturnWeight, self).__init__()
+        assert n_feat % n_head == 0
+        # We assume d_v always equals d_k
+        self.d_k = n_feat // n_head
+        self.h = n_head
+        self.linear_q = nn.Linear(n_feat, n_feat)
+        self.linear_k = nn.Linear(n_feat, n_feat)
+        self.linear_v = nn.Linear(n_feat, n_feat)
+        self.linear_out = nn.Linear(n_feat, n_feat)
+        self.attn = None
+        self.dropout = nn.Dropout(p=dropout_rate)
+
+    def forward_qkv(self, query, key, value):
+        """Transform query, key and value.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+
+        Returns:
+            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+        """
+        n_batch = query.size(0)
+        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
+        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
+        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
+
+        return q, k, v
+
+    def forward_attention(self, value, scores, mask):
+        """Compute attention context vector.
+
+        Args:
+            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Transformed value (#batch, time1, d_model)
+                weighted by the attention score (#batch, time1, time2).
+
+        """
+        n_batch = value.size(0)
+        if mask is not None:
+            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
+            min_value = torch.finfo(scores.dtype).min
+            scores = scores.masked_fill(mask, min_value)
+            self.attn = torch.softmax(scores, dim=-1).masked_fill(
+                mask, 0.0
+            )  # (batch, head, time1, time2)
+        else:
+            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+
+        p_attn = self.dropout(self.attn)
+        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
+        x = (
+            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+        )  # (batch, time1, d_model)
+
+        return self.linear_out(x), self.attn  # (batch, time1, d_model)
+
+    def forward(self, query, key, value, mask):
+        """Compute scaled dot product attention.
+
+        Args:
+            query (torch.Tensor): Query tensor (#batch, time1, size).
+            key (torch.Tensor): Key tensor (#batch, time2, size).
+            value (torch.Tensor): Value tensor (#batch, time2, size).
+            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+                (#batch, time1, time2).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time1, d_model).
+
+        """
+        q, k, v = self.forward_qkv(query, key, value)
+        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+        return self.forward_attention(v, scores, mask)
+
+
diff --git a/funasr/models/lcbnet/encoder.py b/funasr/models/lcbnet/encoder.py
new file mode 100644
index 0000000..c65823c
--- /dev/null
+++ b/funasr/models/lcbnet/encoder.py
@@ -0,0 +1,392 @@
+# Copyright 2019 Shigeki Karita
+#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Transformer encoder definition."""
+
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import torch
+from torch import nn
+import logging
+
+from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.layer_norm import LayerNorm
+
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.register import tables
+
+class EncoderLayer(nn.Module):
+    """Encoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+            can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+        stochastic_depth_rate (float): Proability to skip this layer.
+            During training, the layer may skip residual computation and return input
+            as-is with given probability.
+    """
+
+    def __init__(
+            self,
+            size,
+            self_attn,
+            feed_forward,
+            dropout_rate,
+            normalize_before=True,
+            concat_after=False,
+            stochastic_depth_rate=0.0,
+    ):
+        """Construct an EncoderLayer object."""
+        super(EncoderLayer, self).__init__()
+        self.self_attn = self_attn
+        self.feed_forward = feed_forward
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.size = size
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear = nn.Linear(size + size, size)
+        self.stochastic_depth_rate = stochastic_depth_rate
+
+    def forward(self, x, mask, cache=None):
+        """Compute encoded features.
+
+        Args:
+            x_input (torch.Tensor): Input tensor (#batch, time, size).
+            mask (torch.Tensor): Mask tensor for the input (#batch, time).
+            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor (#batch, time, size).
+            torch.Tensor: Mask tensor (#batch, time).
+
+        """
+        skip_layer = False
+        # with stochastic depth, residual connection `x + f(x)` becomes
+        # `x <- x + 1 / (1 - p) * f(x)` at training time.
+        stoch_layer_coeff = 1.0
+        if self.training and self.stochastic_depth_rate > 0:
+            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+        if skip_layer:
+            if cache is not None:
+                x = torch.cat([cache, x], dim=1)
+            return x, mask
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm1(x)
+
+        if cache is None:
+            x_q = x
+        else:
+            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+            x_q = x[:, -1:, :]
+            residual = residual[:, -1:, :]
+            mask = None if mask is None else mask[:, -1:, :]
+
+        if self.concat_after:
+            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
+            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+        else:
+            x = residual + stoch_layer_coeff * self.dropout(
+                self.self_attn(x_q, x, x, mask)
+            )
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        return x, mask
+
+@tables.register("encoder_classes", "TransformerTextEncoder")
+class TransformerTextEncoder(nn.Module):
+    """Transformer text encoder module.
+
+    Args:
+        input_size: input dim
+        output_size: dimension of attention
+        attention_heads: the number of heads of multi head attention
+        linear_units: the number of units of position-wise feed forward
+        num_blocks: the number of decoder blocks
+        dropout_rate: dropout rate
+        attention_dropout_rate: dropout rate in attention
+        positional_dropout_rate: dropout rate after adding positional encoding
+        input_layer: input layer type
+        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+        normalize_before: whether to use layer_norm before the first block
+        concat_after: whether to concat attention layer's input and output
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied.
+            i.e. x -> x + att(x)
+        positionwise_layer_type: linear of conv1d
+        positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
+        padding_idx: padding_idx for input_layer=embed
+    """
+
+    def __init__(
+            self,
+            input_size: int,
+            output_size: int = 256,
+            attention_heads: int = 4,
+            linear_units: int = 2048,
+            num_blocks: int = 6,
+            dropout_rate: float = 0.1,
+            positional_dropout_rate: float = 0.1,
+            attention_dropout_rate: float = 0.0,
+            pos_enc_class=PositionalEncoding,
+            normalize_before: bool = True,
+            concat_after: bool = False,
+    ):
+        super().__init__()
+        self._output_size = output_size
+
+        self.embed = torch.nn.Sequential(
+            torch.nn.Embedding(input_size, output_size),
+            pos_enc_class(output_size, positional_dropout_rate),
+        )
+
+        self.normalize_before = normalize_before
+
+        positionwise_layer = PositionwiseFeedForward
+        positionwise_layer_args = (
+            output_size,
+            linear_units,
+            dropout_rate,
+        )
+        self.encoders = repeat(
+            num_blocks,
+            lambda lnum: EncoderLayer(
+                output_size,
+                MultiHeadedAttention(
+                    attention_heads, output_size, attention_dropout_rate
+                ),
+                positionwise_layer(*positionwise_layer_args),
+                dropout_rate,
+                normalize_before,
+                concat_after,
+            ),
+        )
+        if self.normalize_before:
+            self.after_norm = LayerNorm(output_size)
+
+    def output_size(self) -> int:
+        return self._output_size
+
+    def forward(
+            self,
+            xs_pad: torch.Tensor,
+            ilens: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+        """Embed positions in tensor.
+
+        Args:
+            xs_pad: input tensor (B, L, D)
+            ilens: input length (B)
+        Returns:
+            position embedded tensor and mask
+        """
+        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+        xs_pad = self.embed(xs_pad)
+
+        xs_pad, masks = self.encoders(xs_pad, masks)
+
+        if self.normalize_before:
+            xs_pad = self.after_norm(xs_pad)
+
+        olens = masks.squeeze(1).sum(1)
+        return xs_pad, olens, None
+
+
+
+
+@tables.register("encoder_classes", "FusionSANEncoder")
+class SelfSrcAttention(nn.Module):
+    """Single decoder layer module.
+
+    Args:
+        size (int): Input dimension.
+        self_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        src_attn (torch.nn.Module): Self-attention module instance.
+            `MultiHeadedAttention` instance can be used as the argument.
+        feed_forward (torch.nn.Module): Feed-forward module instance.
+            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+            can be used as the argument.
+        dropout_rate (float): Dropout rate.
+        normalize_before (bool): Whether to use layer_norm before the first block.
+        concat_after (bool): Whether to concat attention layer's input and output.
+            if True, additional linear will be applied.
+            i.e. x -> x + linear(concat(x, att(x)))
+            if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+    """
+    def __init__(
+            self,
+            size,
+            attention_heads,
+            attention_dim,
+            linear_units,
+            self_attention_dropout_rate,
+            src_attention_dropout_rate,
+            positional_dropout_rate,
+            dropout_rate,
+            normalize_before=True,
+            concat_after=False,
+    ):
+        """Construct an SelfSrcAttention object."""
+        super(SelfSrcAttention, self).__init__()
+        self.size = size
+        self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate)
+        self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate)
+        self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate)
+        self.norm1 = LayerNorm(size)
+        self.norm2 = LayerNorm(size)
+        self.norm3 = LayerNorm(size)
+        self.dropout = nn.Dropout(dropout_rate)
+        self.normalize_before = normalize_before
+        self.concat_after = concat_after
+        if self.concat_after:
+            self.concat_linear1 = nn.Linear(size + size, size)
+            self.concat_linear2 = nn.Linear(size + size, size)
+
+    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
+        """Compute decoded features.
+
+        Args:
+            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+            cache (List[torch.Tensor]): List of cached tensors.
+                Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+        Returns:
+            torch.Tensor: Output tensor(#batch, maxlen_out, size).
+            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+        """
+        residual = tgt
+        if self.normalize_before:
+            tgt = self.norm1(tgt)
+
+        if cache is None:
+            tgt_q = tgt
+            tgt_q_mask = tgt_mask
+        else:
+            # compute only the last frame query keeping dim: max_time_out -> 1
+            assert cache.shape == (
+                tgt.shape[0],
+                tgt.shape[1] - 1,
+                self.size,
+            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+            tgt_q = tgt[:, -1:, :]
+            residual = residual[:, -1:, :]
+            tgt_q_mask = None
+            if tgt_mask is not None:
+                tgt_q_mask = tgt_mask[:, -1:, :]
+
+        if self.concat_after:
+            tgt_concat = torch.cat(
+                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
+            )
+            x = residual + self.concat_linear1(tgt_concat)
+        else:
+            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+        if not self.normalize_before:
+            x = self.norm1(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm2(x)
+        if self.concat_after:
+            x_concat = torch.cat(
+                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
+            )
+            x = residual + self.concat_linear2(x_concat)
+        else:
+            x, score = self.src_attn(x, memory, memory, memory_mask)
+            x = residual + self.dropout(x)
+        if not self.normalize_before:
+            x = self.norm2(x)
+
+        residual = x
+        if self.normalize_before:
+            x = self.norm3(x)
+        x = residual + self.dropout(self.feed_forward(x))
+        if not self.normalize_before:
+            x = self.norm3(x)
+
+        if cache is not None:
+            x = torch.cat([cache, x], dim=1)
+
+        return x, tgt_mask, memory, memory_mask
+
+
+@tables.register("encoder_classes", "ConvBiasPredictor")
+class ConvPredictor(nn.Module):
+    def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048):
+        super().__init__()
+        self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
+        self.norm1 = LayerNorm(size)
+        self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate)
+        self.norm2 = LayerNorm(size)
+        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
+        self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size)
+        self.output_linear = nn.Linear(size, 1)
+
+
+    def forward(self, text_enc, asr_enc):
+        # stage1 cross-attention
+        residual = text_enc
+        text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None)
+        
+        # stage2 FFN
+        residual = text_enc
+        text_enc = self.norm1(text_enc)
+        text_enc = residual + self.feed_forward(text_enc)
+        
+        # stage Conv predictor
+        text_enc = self.norm2(text_enc)
+        context = text_enc.transpose(1, 2)
+        queries = self.pad(context)
+        memory = self.conv1d(queries)
+        output = memory + context
+        output = output.transpose(1, 2)
+        output = torch.relu(output)
+        output = self.output_linear(output)
+        if output.dim()==3:
+          output = output.squeeze(2)
+        return output
diff --git a/funasr/models/lcbnet/model.py b/funasr/models/lcbnet/model.py
new file mode 100644
index 0000000..3ac319c
--- /dev/null
+++ b/funasr/models/lcbnet/model.py
@@ -0,0 +1,495 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+#  MIT License  (https://opensource.org/licenses/MIT)
+
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import time
+import torch
+import torch.nn as nn
+from torch.cuda.amp import autocast
+
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy
+# from funasr.models.e2e_asr_common import ErrorCalculator
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+
+import pdb
+@tables.register("model_classes", "LCBNet")
+class LCBNet(nn.Module):
+    """
+    Author: Speech Lab of DAMO Academy, Alibaba Group
+    LCB-NET: LONG-CONTEXT BIASING FOR AUDIO-VISUAL SPEECH RECOGNITION
+    https://arxiv.org/abs/2401.06390
+    """
+    
+    def __init__(
+        self,
+        specaug: str = None,
+        specaug_conf: dict = None,
+        normalize: str = None,
+        normalize_conf: dict = None,
+        encoder: str = None,
+        encoder_conf: dict = None,
+        decoder: str = None,
+        decoder_conf: dict = None,
+        text_encoder: str = None,
+        text_encoder_conf: dict = None,
+        bias_predictor: str = None,
+        bias_predictor_conf: dict = None,
+        fusion_encoder: str = None,
+        fusion_encoder_conf: dict = None,
+        ctc: str = None,
+        ctc_conf: dict = None,
+        ctc_weight: float = 0.5,
+        interctc_weight: float = 0.0,
+        select_num: int = 2,
+        select_length: int = 3,
+        insert_blank: bool = True,
+        input_size: int = 80,
+        vocab_size: int = -1,
+        ignore_id: int = -1,
+        blank_id: int = 0,
+        sos: int = 1,
+        eos: int = 2,
+        lsm_weight: float = 0.0,
+        length_normalized_loss: bool = False,
+        report_cer: bool = True,
+        report_wer: bool = True,
+        sym_space: str = "<space>",
+        sym_blank: str = "<blank>",
+        # extract_feats_in_collect_stats: bool = True,
+        share_embedding: bool = False,
+        # preencoder: Optional[AbsPreEncoder] = None,
+        # postencoder: Optional[AbsPostEncoder] = None,
+        **kwargs,
+    ):
+
+        super().__init__()
+
+        if specaug is not None:
+            specaug_class = tables.specaug_classes.get(specaug)
+            specaug = specaug_class(**specaug_conf)
+        if normalize is not None:
+            normalize_class = tables.normalize_classes.get(normalize)
+            normalize = normalize_class(**normalize_conf)
+        encoder_class = tables.encoder_classes.get(encoder)
+        encoder = encoder_class(input_size=input_size, **encoder_conf)
+        encoder_output_size = encoder.output_size()
+
+        # lcbnet modules: text encoder, fusion encoder and bias predictor
+        text_encoder_class = tables.encoder_classes.get(text_encoder)
+        text_encoder = text_encoder_class(input_size=vocab_size, **text_encoder_conf)
+        fusion_encoder_class = tables.encoder_classes.get(fusion_encoder)
+        fusion_encoder = fusion_encoder_class(**fusion_encoder_conf)
+        bias_predictor_class = tables.encoder_classes.get(bias_predictor)
+        bias_predictor = bias_predictor_class(**bias_predictor_conf)
+
+
+        if decoder is not None:
+            decoder_class = tables.decoder_classes.get(decoder)
+            decoder = decoder_class(
+                vocab_size=vocab_size,
+                encoder_output_size=encoder_output_size,
+                **decoder_conf,
+            )
+        if ctc_weight > 0.0:
+            
+            if ctc_conf is None:
+                ctc_conf = {}
+            
+            ctc = CTC(
+                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+            )
+    
+        self.blank_id = blank_id
+        self.sos = vocab_size - 1
+        self.eos = vocab_size - 1
+        self.vocab_size = vocab_size
+        self.ignore_id = ignore_id
+        self.ctc_weight = ctc_weight
+        self.specaug = specaug
+        self.normalize = normalize
+        self.encoder = encoder
+        # lcbnet
+        self.text_encoder = text_encoder
+        self.fusion_encoder = fusion_encoder
+        self.bias_predictor = bias_predictor
+        self.select_num = select_num
+        self.select_length = select_length
+        self.insert_blank = insert_blank
+
+        if not hasattr(self.encoder, "interctc_use_conditioning"):
+            self.encoder.interctc_use_conditioning = False
+        if self.encoder.interctc_use_conditioning:
+            self.encoder.conditioning_layer = torch.nn.Linear(
+                vocab_size, self.encoder.output_size()
+            )
+        self.interctc_weight = interctc_weight
+
+        # self.error_calculator = None
+        if ctc_weight == 1.0:
+            self.decoder = None
+        else:
+            self.decoder = decoder
+        
+        self.criterion_att = LabelSmoothingLoss(
+            size=vocab_size,
+            padding_idx=ignore_id,
+            smoothing=lsm_weight,
+            normalize_length=length_normalized_loss,
+        )
+        #
+        # if report_cer or report_wer:
+        #     self.error_calculator = ErrorCalculator(
+        #         token_list, sym_space, sym_blank, report_cer, report_wer
+        #     )
+        #
+        self.error_calculator = None
+        if ctc_weight == 0.0:
+            self.ctc = None
+        else:
+            self.ctc = ctc
+            
+        self.share_embedding = share_embedding
+        if self.share_embedding:
+            self.decoder.embed = None
+        
+        self.length_normalized_loss = length_normalized_loss
+        self.beam_search = None
+    
+    def forward(
+        self,
+        speech: torch.Tensor,
+        speech_lengths: torch.Tensor,
+        text: torch.Tensor,
+        text_lengths: torch.Tensor,
+        **kwargs,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Encoder + Decoder + Calc loss
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                text: (Batch, Length)
+                text_lengths: (Batch,)
+        """
+
+        if len(text_lengths.size()) > 1:
+            text_lengths = text_lengths[:, 0]
+        if len(speech_lengths.size()) > 1:
+            speech_lengths = speech_lengths[:, 0]
+        
+        batch_size = speech.shape[0]
+        
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+        
+        loss_att, acc_att, cer_att, wer_att = None, None, None, None
+        loss_ctc, cer_ctc = None, None
+        stats = dict()
+        
+        # decoder: CTC branch
+        if self.ctc_weight != 0.0:
+            loss_ctc, cer_ctc = self._calc_ctc_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+            
+            # Collect CTC branch stats
+            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+            stats["cer_ctc"] = cer_ctc
+        
+        # Intermediate CTC (optional)
+        loss_interctc = 0.0
+        if self.interctc_weight != 0.0 and intermediate_outs is not None:
+            for layer_idx, intermediate_out in intermediate_outs:
+                # we assume intermediate_out has the same length & padding
+                # as those of encoder_out
+                loss_ic, cer_ic = self._calc_ctc_loss(
+                    intermediate_out, encoder_out_lens, text, text_lengths
+                )
+                loss_interctc = loss_interctc + loss_ic
+                
+                # Collect Intermedaite CTC stats
+                stats["loss_interctc_layer{}".format(layer_idx)] = (
+                    loss_ic.detach() if loss_ic is not None else None
+                )
+                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+            
+            loss_interctc = loss_interctc / len(intermediate_outs)
+            
+            # calculate whole encoder loss
+            loss_ctc = (
+                           1 - self.interctc_weight
+                       ) * loss_ctc + self.interctc_weight * loss_interctc
+        
+        # decoder: Attention decoder branch
+        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+            encoder_out, encoder_out_lens, text, text_lengths
+        )
+        
+        # 3. CTC-Att loss definition
+        if self.ctc_weight == 0.0:
+            loss = loss_att
+        elif self.ctc_weight == 1.0:
+            loss = loss_ctc
+        else:
+            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+        
+        # Collect Attn branch stats
+        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+        stats["acc"] = acc_att
+        stats["cer"] = cer_att
+        stats["wer"] = wer_att
+        
+        # Collect total loss stats
+        stats["loss"] = torch.clone(loss.detach())
+        
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        if self.length_normalized_loss:
+            batch_size = int((text_lengths + 1).sum())
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+    
+
+    def encode(
+        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder. Note that this method is used by asr_inference.py
+        Args:
+                speech: (Batch, Length, ...)
+                speech_lengths: (Batch, )
+                ind: int
+        """
+        with autocast(False):
+            # Data augmentation
+            if self.specaug is not None and self.training:
+                speech, speech_lengths = self.specaug(speech, speech_lengths)
+            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                speech, speech_lengths = self.normalize(speech, speech_lengths)
+        # Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        if self.encoder.interctc_use_conditioning:
+            encoder_out, encoder_out_lens, _ = self.encoder(
+                speech, speech_lengths, ctc=self.ctc
+            )
+        else:
+            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+        intermediate_outs = None
+        if isinstance(encoder_out, tuple):
+            intermediate_outs = encoder_out[1]
+            encoder_out = encoder_out[0]
+        
+        if intermediate_outs is not None:
+            return (encoder_out, intermediate_outs), encoder_out_lens
+        return encoder_out, encoder_out_lens
+    
+    def _calc_att_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+        ys_in_lens = ys_pad_lens + 1
+        
+        # 1. Forward decoder
+        decoder_out, _ = self.decoder(
+            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+        )
+        
+        # 2. Compute attention loss
+        loss_att = self.criterion_att(decoder_out, ys_out_pad)
+        acc_att = th_accuracy(
+            decoder_out.view(-1, self.vocab_size),
+            ys_out_pad,
+            ignore_label=self.ignore_id,
+        )
+        
+        # Compute cer/wer using attention-decoder
+        if self.training or self.error_calculator is None:
+            cer_att, wer_att = None, None
+        else:
+            ys_hat = decoder_out.argmax(dim=-1)
+            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+        
+        return loss_att, acc_att, cer_att, wer_att
+    
+    def _calc_ctc_loss(
+        self,
+        encoder_out: torch.Tensor,
+        encoder_out_lens: torch.Tensor,
+        ys_pad: torch.Tensor,
+        ys_pad_lens: torch.Tensor,
+    ):
+        # Calc CTC loss
+        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+        
+        # Calc CER using CTC
+        cer_ctc = None
+        if not self.training and self.error_calculator is not None:
+            ys_hat = self.ctc.argmax(encoder_out).data
+            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+        return loss_ctc, cer_ctc
+    
+    def init_beam_search(self,
+                         **kwargs,
+                         ):
+        from funasr.models.transformer.search import BeamSearch
+        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+        from funasr.models.transformer.scorers.length_bonus import LengthBonus
+    
+        # 1. Build ASR model
+        scorers = {}
+        
+        if self.ctc != None:
+            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+            scorers.update(
+                ctc=ctc
+            )
+        token_list = kwargs.get("token_list")
+        scorers.update(
+            decoder=self.decoder,
+            length_bonus=LengthBonus(len(token_list)),
+        )
+
+        
+        # 3. Build ngram model
+        # ngram is not supported now
+        ngram = None
+        scorers["ngram"] = ngram
+        
+        weights = dict(
+            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.3),
+            ctc=kwargs.get("decoding_ctc_weight", 0.3),
+            lm=kwargs.get("lm_weight", 0.0),
+            ngram=kwargs.get("ngram_weight", 0.0),
+            length_bonus=kwargs.get("penalty", 0.0),
+        )
+        beam_search = BeamSearch(
+            beam_size=kwargs.get("beam_size", 20),
+            weights=weights,
+            scorers=scorers,
+            sos=self.sos,
+            eos=self.eos,
+            vocab_size=len(token_list),
+            token_list=token_list,
+            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+        )
+
+        self.beam_search = beam_search
+        
+    def inference(self,
+             data_in,
+             data_lengths=None,
+             key: list=None,
+             tokenizer=None,
+             frontend=None,
+             **kwargs,
+             ):
+        
+        if kwargs.get("batch_size", 1) > 1:
+            raise NotImplementedError("batch decoding is not implemented")
+        
+        # init beamsearch
+        if self.beam_search is None:
+            logging.info("enable beam_search")
+            self.init_beam_search(**kwargs)
+            self.nbest = kwargs.get("nbest", 1)
+
+        meta_data = {}
+        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
+            speech, speech_lengths = data_in, data_lengths
+            if len(speech.shape) < 3:
+                speech = speech[None, :, :]
+            if speech_lengths is None:
+                speech_lengths = speech.shape[1]
+        else:
+            # extract fbank feats
+            time1 = time.perf_counter()
+            sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+                                                            data_type=kwargs.get("data_type", "sound"),
+                                                            tokenizer=tokenizer)
+            time2 = time.perf_counter()
+            meta_data["load_data"] = f"{time2 - time1:0.3f}"
+            audio_sample_list = sample_list[0]
+            if len(sample_list) >1:
+                ocr_sample_list = sample_list[1]
+            else:
+                ocr_sample_list = [[294, 0]]
+            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+                                                   frontend=frontend)
+            time3 = time.perf_counter()
+            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+            frame_shift = 10 
+            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift / 1000
+
+        speech = speech.to(device=kwargs["device"])
+        speech_lengths = speech_lengths.to(device=kwargs["device"])
+        # Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+        if isinstance(encoder_out, tuple):
+            encoder_out = encoder_out[0]
+
+        ocr_list_new = [[x + 1 if x != 0 else x for x in sublist] for sublist in ocr_sample_list]
+        ocr = torch.tensor(ocr_list_new).to(device=kwargs["device"])
+        ocr_lengths = ocr.new_full([1], dtype=torch.long, fill_value=ocr.size(1)).to(device=kwargs["device"])
+        ocr, ocr_lens, _ = self.text_encoder(ocr, ocr_lengths)
+        fusion_out, _, _, _ = self.fusion_encoder(encoder_out,None, ocr, None)
+        encoder_out = encoder_out + fusion_out
+        # c. Passed the encoder result and the beam search
+        nbest_hyps = self.beam_search(
+            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+        )
+        
+        nbest_hyps = nbest_hyps[: self.nbest]
+
+        results = []
+        b, n, d = encoder_out.size()
+        for i in range(b):
+
+            for nbest_idx, hyp in enumerate(nbest_hyps):
+                ibest_writer = None
+                if kwargs.get("output_dir") is not None:
+                    if not hasattr(self, "writer"):
+                        self.writer = DatadirWriter(kwargs.get("output_dir"))
+                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+                    
+                # remove sos/eos and get results
+                last_pos = -1
+                if isinstance(hyp.yseq, list):
+                    token_int = hyp.yseq[1:last_pos]
+                else:
+                    token_int = hyp.yseq[1:last_pos].tolist()
+                    
+                # remove blank symbol id, which is assumed to be 0
+                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+                
+                # Change integer-ids to tokens
+                token = tokenizer.ids2tokens(token_int)
+                text = tokenizer.tokens2text(token)
+
+                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
+                results.append(result_i)
+                
+                if ibest_writer is not None:
+                    ibest_writer["token"][key[i]] = " ".join(token)
+                    ibest_writer["text"][key[i]] = text_postprocessed
+        
+        return results, meta_data
+
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 20b0cc8..a8b1f1f 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -30,7 +30,7 @@
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
-
+import pdb
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
 else:
@@ -128,7 +128,7 @@
         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
         dha_pad = kwargs.get("dha_pad")
-
+        
         batch_size = speech.shape[0]
         # for data-parallel
         text = text[:, : text_lengths.max()]
@@ -209,17 +209,20 @@
                                nfilter=50,
                                seaco_weight=1.0):
         # decoder forward
+
         decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
+
         decoder_pred = torch.log_softmax(decoder_out, dim=-1)
         if hw_list is not None:
             hw_lengths = [len(i) for i in hw_list]
             hw_list_ = [torch.Tensor(i).long() for i in hw_list]
             hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device)
             selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device))
+
             contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
             num_hot_word = contextual_info.shape[1]
             _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
-            
+
             # ASF Core
             if nfilter > 0 and nfilter < num_hot_word:
                 hotword_scores = self.seaco_decoder.forward_asf6(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
@@ -239,7 +242,7 @@
             cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
             dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
             merged = self._merge(cif_attended, dec_attended)
-            
+
             dha_output = self.hotword_output_layer(merged)  # remove the last token in loss calculation
             dha_pred = torch.log_softmax(dha_output, dim=-1)
             def _merge_res(dec_output, dha_output):
@@ -253,8 +256,8 @@
                 # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask)
                 logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
                 return logits
+
             merged_pred = _merge_res(decoder_pred, dha_pred)
-            # import pdb; pdb.set_trace()
             return merged_pred
         else:
             return decoder_pred
@@ -304,7 +307,6 @@
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)
-        
         meta_data = {}
         
         # extract fbank feats
@@ -330,6 +332,7 @@
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
         
+
         # predictor
         predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
         pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
@@ -338,15 +341,14 @@
         if torch.max(pre_token_length) < 1:
             return []
 
-
         decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
                                                    pre_acoustic_embeds,
                                                    pre_token_length,
                                                    hw_list=self.hotword_list)
+
         # decoder_out, _ = decoder_outs[0], decoder_outs[1]
         _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
                                                                   pre_token_length)
-        
         results = []
         b, n, d = decoder_out.size()
         for i in range(b):
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index ea23725..0c46449 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -7,7 +7,7 @@
 import torch
 import torch.nn
 import torch.optim
-
+import pdb
 
 def filter_state_dict(
 	dst_state: Dict[str, Union[float, torch.Tensor]],
@@ -63,6 +63,7 @@
 	dst_state = obj.state_dict()
 	
 	print(f"ckpt: {path}")
+
 	if oss_bucket is None:
 		src_state = torch.load(path, map_location=map_location)
 	else:
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index 7748172..84c38f9 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -13,29 +13,25 @@
     from funasr.download.file import download_from_url
 except:
     print("urllib is not installed, if you infer from url, please install it first.")
-
+import pdb
 
 
 def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
     if isinstance(data_or_path_or_list, (list, tuple)):
         if data_type is not None and isinstance(data_type, (list, tuple)):
-
             data_types = [data_type] * len(data_or_path_or_list)
             data_or_path_or_list_ret = [[] for d in data_type]
             for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
-                
                 for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
-                    
                     data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
                     data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
 
             return data_or_path_or_list_ret
         else:
             return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
-    
     if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
         data_or_path_or_list = download_from_url(data_or_path_or_list)
-    
+
     if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
         if data_type is None or data_type == "sound":
             data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
@@ -56,10 +52,22 @@
         data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
     elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
         data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
+    elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark":
+        data_mat = kaldiio.load_mat(data_or_path_or_list) 
+        if isinstance(data_mat, tuple):
+            audio_fs, mat = data_mat
+        else:
+            mat = data_mat
+        if mat.dtype == 'int16' or mat.dtype == 'int32':
+            mat = mat.astype(np.float64)
+            mat = mat / 32768
+        if mat.ndim ==2:
+            mat = mat[:,0]
+        data_or_path_or_list = mat
     else:
         pass
         # print(f"unsupport data type: {data_or_path_or_list}, return raw data")
-        
+
     if audio_fs != fs and data_type != "text":
         resampler = torchaudio.transforms.Resample(audio_fs, fs)
         data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
@@ -81,8 +89,6 @@
     return array
 
 def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
-    # import pdb;
-    # pdb.set_trace()
     if isinstance(data, np.ndarray):
         data = torch.from_numpy(data)
         if len(data.shape) < 2:
@@ -100,9 +106,7 @@
             data_list.append(data_i)
             data_len.append(data_i.shape[0])
         data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
-    # import pdb;
-    # pdb.set_trace()
-    # if data_type == "sound":
+
     data, data_len = frontend(data, data_len, **kwargs)
     
     if isinstance(data_len, (list, tuple)):

--
Gitblit v1.9.1