From d2c1204d91d7c98be7998e3966bd82e22750293b Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 04 三月 2024 17:50:29 +0800
Subject: [PATCH] Revert "Dev yf" (#1418)

---
 /dev/null                                                          |  495 ---------------------------------------------
 funasr/models/seaco_paraformer/model.py                            |   18 
 funasr/auto/auto_model.py                                          |   12 
 funasr/models/contextual_paraformer/model.py                       |   39 +-
 funasr/models/conformer/encoder.py                                 |    2 
 examples/industrial_data_pretraining/contextual_paraformer/demo.py |    0 
 funasr/utils/load_utils.py                                         |   28 +-
 examples/industrial_data_pretraining/seaco_paraformer/demo.py      |   10 
 examples/industrial_data_pretraining/contextual_paraformer/demo.sh |    2 
 funasr/frontends/default.py                                        |   20 -
 funasr/train_utils/load_pretrained_model.py                        |    3 
 11 files changed, 61 insertions(+), 568 deletions(-)

diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.py b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
old mode 100755
new mode 100644
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.sh b/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
old mode 100755
new mode 100644
index 1bd4f7f..8fc66f3
--- a/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
+++ b/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
@@ -2,7 +2,7 @@
 model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
 model_revision="v2.0.4"
 
-python ../../../funasr/bin/inference.py \
+python funasr/bin/inference.py \
 +model=${model} \
 +model_revision=${model_revision} \
 +input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh b/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh
deleted file mode 100755
index 282f4f1..0000000
--- a/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-python -m funasr.bin.inference \
---config-path="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
---config-name="config.yaml" \
-++init_param="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/model.pb" \
-++tokenizer_conf.token_list="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/tokens.txt" \
-++frontend_conf.cmvn_file="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/am.mvn" \
-++input="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/asr_example_zh.wav" \
-++output_dir="./outputs/debug2" \
-++device="" \
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/path.sh b/examples/industrial_data_pretraining/contextual_paraformer/path.sh
deleted file mode 100755
index 1a6d67e..0000000
--- a/examples/industrial_data_pretraining/contextual_paraformer/path.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-export FUNASR_DIR=$PWD/../../../
-
-# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
-export PYTHONIOENCODING=UTF-8
-export PATH=$FUNASR_DIR/funasr/bin:$PATH
-export PYTHONPATH=$FUNASR_DIR/funasr/bin:$FUNASR_DIR/funasr:$FUNASR_DIR:$PYTHONPATH
diff --git a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
deleted file mode 100755
index e72d871..0000000
--- a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
+++ /dev/null
@@ -1,702 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-
-from enum import Enum
-import re, sys, unicodedata
-import codecs
-import argparse
-from tqdm import tqdm
-import os
-import pdb
-remove_tag = False
-spacelist = [" ", "\t", "\r", "\n"]
-puncts = [
-    "!",
-    ",",
-    "?",
-    "銆�",
-    "銆�",
-    "锛�",
-    "锛�",
-    "锛�",
-    "锛�",
-    "锛�",
-    "銆�",
-    "銆�",
-    "锔�",
-    "銆�",
-    "銆�",
-    "銆�",
-    "銆�",
-]
-
-
-class Code(Enum):
-    match = 1
-    substitution = 2
-    insertion = 3
-    deletion = 4
-
-
-class WordError(object):
-    def __init__(self):
-        self.errors = {
-            Code.substitution: 0,
-            Code.insertion: 0,
-            Code.deletion: 0,
-        }
-        self.ref_words = 0
-
-    def get_wer(self):
-        assert self.ref_words != 0
-        errors = (
-                self.errors[Code.substitution]
-                + self.errors[Code.insertion]
-                + self.errors[Code.deletion]
-        )
-        return 100.0 * errors / self.ref_words
-
-    def get_result_string(self):
-        return (
-            f"error_rate={self.get_wer():.4f}, "
-            f"ref_words={self.ref_words}, "
-            f"subs={self.errors[Code.substitution]}, "
-            f"ins={self.errors[Code.insertion]}, "
-            f"dels={self.errors[Code.deletion]}"
-        )
-
-
-def characterize(string):
-    res = []
-    i = 0
-    while i < len(string):
-        char = string[i]
-        if char in puncts:
-            i += 1
-            continue
-        cat1 = unicodedata.category(char)
-        # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
-        if cat1 == "Zs" or cat1 == "Cn" or char in spacelist:  # space or not assigned
-            i += 1
-            continue
-        if cat1 == "Lo":  # letter-other
-            res.append(char)
-            i += 1
-        else:
-            # some input looks like: <unk><noise>, we want to separate it to two words.
-            sep = " "
-            if char == "<":
-                sep = ">"
-            j = i + 1
-            while j < len(string):
-                c = string[j]
-                if ord(c) >= 128 or (c in spacelist) or (c == sep):
-                    break
-                j += 1
-            if j < len(string) and string[j] == ">":
-                j += 1
-            res.append(string[i:j])
-            i = j
-    return res
-
-
-def stripoff_tags(x):
-    if not x:
-        return ""
-    chars = []
-    i = 0
-    T = len(x)
-    while i < T:
-        if x[i] == "<":
-            while i < T and x[i] != ">":
-                i += 1
-            i += 1
-        else:
-            chars.append(x[i])
-            i += 1
-    return "".join(chars)
-
-
-def normalize(sentence, ignore_words, cs, split=None):
-    """sentence, ignore_words are both in unicode"""
-    new_sentence = []
-    for token in sentence:
-        x = token
-        if not cs:
-            x = x.upper()
-        if x in ignore_words:
-            continue
-        if remove_tag:
-            x = stripoff_tags(x)
-        if not x:
-            continue
-        if split and x in split:
-            new_sentence += split[x]
-        else:
-            new_sentence.append(x)
-    return new_sentence
-
-
-class Calculator:
-    def __init__(self):
-        self.data = {}
-        self.space = []
-        self.cost = {}
-        self.cost["cor"] = 0
-        self.cost["sub"] = 1
-        self.cost["del"] = 1
-        self.cost["ins"] = 1
-
-    def calculate(self, lab, rec):
-        # Initialization
-        lab.insert(0, "")
-        rec.insert(0, "")
-        while len(self.space) < len(lab):
-            self.space.append([])
-        for row in self.space:
-            for element in row:
-                element["dist"] = 0
-                element["error"] = "non"
-            while len(row) < len(rec):
-                row.append({"dist": 0, "error": "non"})
-        for i in range(len(lab)):
-            self.space[i][0]["dist"] = i
-            self.space[i][0]["error"] = "del"
-        for j in range(len(rec)):
-            self.space[0][j]["dist"] = j
-            self.space[0][j]["error"] = "ins"
-        self.space[0][0]["error"] = "non"
-        for token in lab:
-            if token not in self.data and len(token) > 0:
-                self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
-        for token in rec:
-            if token not in self.data and len(token) > 0:
-                self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
-        # Computing edit distance
-        for i, lab_token in enumerate(lab):
-            for j, rec_token in enumerate(rec):
-                if i == 0 or j == 0:
-                    continue
-                min_dist = sys.maxsize
-                min_error = "none"
-                dist = self.space[i - 1][j]["dist"] + self.cost["del"]
-                error = "del"
-                if dist < min_dist:
-                    min_dist = dist
-                    min_error = error
-                dist = self.space[i][j - 1]["dist"] + self.cost["ins"]
-                error = "ins"
-                if dist < min_dist:
-                    min_dist = dist
-                    min_error = error
-                if lab_token == rec_token.replace("<BIAS>", ""):
-                    dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"]
-                    error = "cor"
-                else:
-                    dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"]
-                    error = "sub"
-                if dist < min_dist:
-                    min_dist = dist
-                    min_error = error
-                self.space[i][j]["dist"] = min_dist
-                self.space[i][j]["error"] = min_error
-        # Tracing back
-        result = {
-            "lab": [],
-            "rec": [],
-            "code": [],
-            "all": 0,
-            "cor": 0,
-            "sub": 0,
-            "ins": 0,
-            "del": 0,
-        }
-        i = len(lab) - 1
-        j = len(rec) - 1
-        while True:
-            if self.space[i][j]["error"] == "cor":  # correct
-                if len(lab[i]) > 0:
-                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
-                    self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1
-                    result["all"] = result["all"] + 1
-                    result["cor"] = result["cor"] + 1
-                result["lab"].insert(0, lab[i])
-                result["rec"].insert(0, rec[j])
-                result["code"].insert(0, Code.match)
-                i = i - 1
-                j = j - 1
-            elif self.space[i][j]["error"] == "sub":  # substitution
-                if len(lab[i]) > 0:
-                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
-                    self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1
-                    result["all"] = result["all"] + 1
-                    result["sub"] = result["sub"] + 1
-                result["lab"].insert(0, lab[i])
-                result["rec"].insert(0, rec[j])
-                result["code"].insert(0, Code.substitution)
-                i = i - 1
-                j = j - 1
-            elif self.space[i][j]["error"] == "del":  # deletion
-                if len(lab[i]) > 0:
-                    self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
-                    self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1
-                    result["all"] = result["all"] + 1
-                    result["del"] = result["del"] + 1
-                result["lab"].insert(0, lab[i])
-                result["rec"].insert(0, "")
-                result["code"].insert(0, Code.deletion)
-                i = i - 1
-            elif self.space[i][j]["error"] == "ins":  # insertion
-                if len(rec[j]) > 0:
-                    self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1
-                    result["ins"] = result["ins"] + 1
-                result["lab"].insert(0, "")
-                result["rec"].insert(0, rec[j])
-                result["code"].insert(0, Code.insertion)
-                j = j - 1
-            elif self.space[i][j]["error"] == "non":  # starting point
-                break
-            else:  # shouldn't reach here
-                print(
-                    "this should not happen , i = {i} , j = {j} , error = {error}".format(
-                        i=i, j=j, error=self.space[i][j]["error"]
-                    )
-                )
-        return result
-
-    def overall(self):
-        result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
-        for token in self.data:
-            result["all"] = result["all"] + self.data[token]["all"]
-            result["cor"] = result["cor"] + self.data[token]["cor"]
-            result["sub"] = result["sub"] + self.data[token]["sub"]
-            result["ins"] = result["ins"] + self.data[token]["ins"]
-            result["del"] = result["del"] + self.data[token]["del"]
-        return result
-
-    def cluster(self, data):
-        result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
-        for token in data:
-            if token in self.data:
-                result["all"] = result["all"] + self.data[token]["all"]
-                result["cor"] = result["cor"] + self.data[token]["cor"]
-                result["sub"] = result["sub"] + self.data[token]["sub"]
-                result["ins"] = result["ins"] + self.data[token]["ins"]
-                result["del"] = result["del"] + self.data[token]["del"]
-        return result
-
-    def keys(self):
-        return list(self.data.keys())
-
-
-def width(string):
-    return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
-
-
-def default_cluster(word):
-    unicode_names = [unicodedata.name(char) for char in word]
-    for i in reversed(range(len(unicode_names))):
-        if unicode_names[i].startswith("DIGIT"):  # 1
-            unicode_names[i] = "Number"  # 'DIGIT'
-        elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
-            i
-        ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
-            # 鏄� / 铯�
-            unicode_names[i] = "Mandarin"  # 'CJK IDEOGRAPH'
-        elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
-            i
-        ].startswith("LATIN SMALL LETTER"):
-            # A / a
-            unicode_names[i] = "English"  # 'LATIN LETTER'
-        elif unicode_names[i].startswith("HIRAGANA LETTER"):  # 銇� 銇� 銈�
-            unicode_names[i] = "Japanese"  # 'GANA LETTER'
-        elif (
-                unicode_names[i].startswith("AMPERSAND")
-                or unicode_names[i].startswith("APOSTROPHE")
-                or unicode_names[i].startswith("COMMERCIAL AT")
-                or unicode_names[i].startswith("DEGREE CELSIUS")
-                or unicode_names[i].startswith("EQUALS SIGN")
-                or unicode_names[i].startswith("FULL STOP")
-                or unicode_names[i].startswith("HYPHEN-MINUS")
-                or unicode_names[i].startswith("LOW LINE")
-                or unicode_names[i].startswith("NUMBER SIGN")
-                or unicode_names[i].startswith("PLUS SIGN")
-                or unicode_names[i].startswith("SEMICOLON")
-        ):
-            # & / ' / @ / 鈩� / = / . / - / _ / # / + / ;
-            del unicode_names[i]
-        else:
-            return "Other"
-    if len(unicode_names) == 0:
-        return "Other"
-    if len(unicode_names) == 1:
-        return unicode_names[0]
-    for i in range(len(unicode_names) - 1):
-        if unicode_names[i] != unicode_names[i + 1]:
-            return "Other"
-    return unicode_names[0]
-
-
-def get_args():
-    parser = argparse.ArgumentParser(description="wer cal")
-    parser.add_argument("--ref", type=str, help="Text input path")
-    parser.add_argument("--ref_ocr", type=str, help="Text input path")
-    parser.add_argument("--rec_name", type=str, action="append", default=[])
-    parser.add_argument("--rec_file", type=str, action="append", default=[])
-    parser.add_argument("--verbose", type=int, default=1, help="show")
-    parser.add_argument("--char", type=bool, default=True, help="show")
-    args = parser.parse_args()
-    return args
-
-
-def main(args):
-    cluster_file = ""
-    ignore_words = set()
-    tochar = args.char
-    verbose = args.verbose
-    padding_symbol = " "
-    case_sensitive = False
-    max_words_per_line = sys.maxsize
-    split = None
-
-    if not case_sensitive:
-        ig = set([w.upper() for w in ignore_words])
-        ignore_words = ig
-
-    default_clusters = {}
-    default_words = {}
-    ref_file = args.ref
-    ref_ocr = args.ref_ocr
-    rec_files = args.rec_file
-    rec_names = args.rec_name
-    assert len(rec_files) == len(rec_names)
-
-    # load ocr
-    ref_ocr_dict = {}
-    with codecs.open(ref_ocr, "r", "utf-8") as fh:
-        for line in fh:
-            if "$" in line:
-                line = line.replace("$", " ")
-            if tochar:
-                array = characterize(line)
-            else:
-                array = line.strip().split()
-            if len(array) == 0:
-                continue
-            fid = array[0]
-            ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
-
-    if split and not case_sensitive:
-        newsplit = dict()
-        for w in split:
-            words = split[w]
-            for i in range(len(words)):
-                words[i] = words[i].upper()
-            newsplit[w.upper()] = words
-        split = newsplit
-
-    rec_sets = {}
-    calculators_dict = dict()
-    ub_wer_dict = dict()
-    hotwords_related_dict = dict()  # 璁板綍recall鐩稿叧鐨勫唴瀹�
-    for i, hyp_file in enumerate(rec_files):
-        rec_sets[rec_names[i]] = dict()
-        with codecs.open(hyp_file, "r", "utf-8") as fh:
-            for line in fh:
-                if tochar:
-                    array = characterize(line)
-                else:
-                    array = line.strip().split()
-                if len(array) == 0:
-                    continue
-                fid = array[0]
-                rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
-
-        calculators_dict[rec_names[i]] = Calculator()
-        ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
-        hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
-        # tp: 鐑瘝鍦╨abel閲岋紝鍚屾椂鍦╮ec閲�
-        # tn: 鐑瘝涓嶅湪label閲岋紝鍚屾椂涓嶅湪rec閲�
-        # fp: 鐑瘝涓嶅湪label閲岋紝浣嗘槸鍦╮ec閲�
-        # fn: 鐑瘝鍦╨abel閲岋紝浣嗘槸涓嶅湪rec閲�
-
-    # record wrong label but in ocr
-    wrong_rec_but_in_ocr_dict = {}
-    for rec_name in rec_names:
-        wrong_rec_but_in_ocr_dict[rec_name] = 0
-
-    _file_total_len = 0
-    with os.popen("cat {} | wc -l".format(ref_file)) as pipe:
-        _file_total_len = int(pipe.read().strip())
-
-    # compute error rate on the interaction of reference file and hyp file
-    for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
-        if tochar:
-            array = characterize(line)
-        else:
-            array = line.rstrip('\n').split()
-        if len(array) == 0: continue
-        fid = array[0]
-        lab = normalize(array[1:], ignore_words, case_sensitive, split)
-
-        if verbose:
-            print('\nutt: %s' % fid)
-
-        ocr_text = ref_ocr_dict[fid]
-        ocr_set = set(ocr_text)
-        print('ocr: {}'.format(" ".join(ocr_text)))
-        list_match = []  # 鎸噇abel閲岄潰鍦╫cr閲岄潰鐨勫唴瀹�
-        list_not_mathch = []
-        tmp_error = 0
-        tmp_match = 0
-        for index in range(len(lab)):
-            # text_list.append(uttlist[index+1])
-            if lab[index] not in ocr_set:
-                tmp_error += 1
-                list_not_mathch.append(lab[index])
-            else:
-                tmp_match += 1
-                list_match.append(lab[index])
-        print('label in ocr: {}'.format(" ".join(list_match)))
-
-        # for each reco file
-        base_wrong_ocr_wer = None
-        ocr_wrong_ocr_wer = None
-
-        for rec_name in rec_names:
-            rec_set = rec_sets[rec_name]
-            if fid not in rec_set:
-                continue
-            rec = rec_set[fid]
-
-            # print(rec)
-            for word in rec + lab:
-                if word not in default_words:
-                    default_cluster_name = default_cluster(word)
-                    if default_cluster_name not in default_clusters:
-                        default_clusters[default_cluster_name] = {}
-                    if word not in default_clusters[default_cluster_name]:
-                        default_clusters[default_cluster_name][word] = 1
-                    default_words[word] = default_cluster_name
-
-            result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
-            if verbose:
-                if result['all'] != 0:
-                    wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
-                else:
-                    wer = 0.0
-            print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
-            print('N=%d C=%d S=%d D=%d I=%d' %
-                  (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
-
-
-            # print(result['rec'])
-            wrong_rec_but_in_ocr = []
-            for idx in range(len(result['lab'])):
-                if result['lab'][idx] != "":
-                    if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
-                        if result['lab'][idx] in list_match:
-                            wrong_rec_but_in_ocr.append(result['lab'][idx])
-                            wrong_rec_but_in_ocr_dict[rec_name] += 1
-            print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
-
-            if rec_name == "base":
-                base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
-            if "ocr" in rec_name or "hot" in rec_name:
-                ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
-                if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
-                    print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
-                elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
-                    print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
-
-            # recall = 0
-            # false_alarm = 0
-            # for idx in range(len(result['lab'])):
-            #     if "<BIAS>" in result['rec'][idx]:
-            #         if result['rec'][idx].replace("<BIAS>", "") in list_match:
-            #             recall += 1
-            #         else:
-            #             false_alarm += 1
-            # print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format(
-            #     recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0
-            # ))
-            # tp: 鐑瘝鍦╨abel閲岋紝鍚屾椂鍦╮ec閲�
-            # tn: 鐑瘝涓嶅湪label閲岋紝鍚屾椂涓嶅湪rec閲�
-            # fp: 鐑瘝涓嶅湪label閲岋紝浣嗘槸鍦╮ec閲�
-            # fn: 鐑瘝鍦╨abel閲岋紝浣嗘槸涓嶅湪rec閲�
-            _rec_list = [word.replace("<BIAS>", "") for word in rec]
-            _label_list = [word for word in lab]
-            _tp = _tn = _fp = _fn = 0
-            hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list]
-            hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list]
-            for badhotword in hot_bad_list:
-                count = len([word for word in _rec_list if word == badhotword])
-                # print(f"bad {badhotword} count: {count}")
-                # for word in _rec_list:
-                #     if badhotword == word:
-                #         count += 1
-                if count == 0:
-                    hotwords_related_dict[rec_name]['tn'] += 1
-                    _tn += 1
-                    # fp: 0
-                else:
-                    hotwords_related_dict[rec_name]['fp'] += count
-                    _fp += count
-                    # tn: 0
-                # if badhotword in _rec_list:
-                #     hotwords_related_dict[rec_name]['fp'] += 1
-                # else:
-                #     hotwords_related_dict[rec_name]['tn'] += 1
-            for hotword in hot_true_list:
-                true_count = len([word for word in _label_list if hotword == word])
-                rec_count = len([word for word in _rec_list if hotword == word])
-                # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
-                if rec_count == true_count:
-                    hotwords_related_dict[rec_name]['tp'] += true_count
-                    _tp += true_count
-                elif rec_count > true_count:
-                    hotwords_related_dict[rec_name]['tp'] += true_count
-                    # fp: 涓嶅湪label閲岋紝浣嗘槸鍦╮ec閲�
-                    hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
-                    _tp += true_count
-                    _fp += rec_count - true_count
-                else:
-                    hotwords_related_dict[rec_name]['tp'] += rec_count
-                    # fn: 鐑瘝鍦╨abel閲岋紝浣嗘槸涓嶅湪rec閲�
-                    hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
-                    _tp += rec_count
-                    _fn += true_count - rec_count
-            print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
-                _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
-            ))
-
-            # if hotword in _rec_list:
-            #     hotwords_related_dict[rec_name]['tp'] += 1
-            # else:
-            #     hotwords_related_dict[rec_name]['fn'] += 1
-            # 璁$畻uwer, bwer, wer
-            for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]):
-                if code == Code.match:
-                    ub_wer_dict[rec_name]["wer"].ref_words += 1
-                    if lab_word in hot_true_list:
-                        # tmp_ref.append(ref_tokens[ref_idx])
-                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
-                    else:
-                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
-                elif code == Code.substitution:
-                    ub_wer_dict[rec_name]["wer"].ref_words += 1
-                    ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1
-                    if lab_word in hot_true_list:
-                        # tmp_ref.append(ref_tokens[ref_idx])
-                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
-                        ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1
-                    else:
-                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
-                        ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1
-                elif code == Code.deletion:
-                    ub_wer_dict[rec_name]["wer"].ref_words += 1
-                    ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1
-                    if lab_word in hot_true_list:
-                        # tmp_ref.append(ref_tokens[ref_idx])
-                        ub_wer_dict[rec_name]["b_wer"].ref_words += 1
-                        ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1
-                    else:
-                        ub_wer_dict[rec_name]["u_wer"].ref_words += 1
-                        ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1
-                elif code == Code.insertion:
-                    ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1
-                    if rec_word in hot_true_list:
-                        ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1
-                    else:
-                        ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
-
-            space = {}
-            space['lab'] = []
-            space['rec'] = []
-            for idx in range(len(result['lab'])):
-                len_lab = width(result['lab'][idx])
-                len_rec = width(result['rec'][idx])
-                length = max(len_lab, len_rec)
-                space['lab'].append(length - len_lab)
-                space['rec'].append(length - len_rec)
-            upper_lab = len(result['lab'])
-            upper_rec = len(result['rec'])
-            lab1, rec1 = 0, 0
-            while lab1 < upper_lab or rec1 < upper_rec:
-                if verbose > 1:
-                    print('lab(%s):' % fid.encode('utf-8'), end=' ')
-                else:
-                    print('lab:', end=' ')
-                lab2 = min(upper_lab, lab1 + max_words_per_line)
-                for idx in range(lab1, lab2):
-                    token = result['lab'][idx]
-                    print('{token}'.format(token=token), end='')
-                    for n in range(space['lab'][idx]):
-                        print(padding_symbol, end='')
-                    print(' ', end='')
-                print()
-                if verbose > 1:
-                    print('rec(%s):' % fid.encode('utf-8'), end=' ')
-                else:
-                    print('rec:', end=' ')
-
-                rec2 = min(upper_rec, rec1 + max_words_per_line)
-                for idx in range(rec1, rec2):
-                    token = result['rec'][idx]
-                    print('{token}'.format(token=token), end='')
-                    for n in range(space['rec'][idx]):
-                        print(padding_symbol, end='')
-                    print(' ', end='')
-                print()
-                # print('\n', end='\n')
-                lab1 = lab2
-                rec1 = rec2
-        print('\n', end='\n')
-        # break
-    if verbose:
-        print('===========================================================================')
-        print()
-
-    print(wrong_rec_but_in_ocr_dict)
-    for rec_name in rec_names:
-        result = calculators_dict[rec_name].overall()
-
-        if result['all'] != 0:
-            wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
-        else:
-            wer = 0.0
-        print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
-        print('N=%d C=%d S=%d D=%d I=%d' %
-              (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
-        print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
-        print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
-        print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
-
-        print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
-            hotwords_related_dict[rec_name]['tp'],
-            hotwords_related_dict[rec_name]['tn'],
-            hotwords_related_dict[rec_name]['fp'],
-            hotwords_related_dict[rec_name]['fn'],
-            sum([v for k, v in hotwords_related_dict[rec_name].items()]),
-            hotwords_related_dict[rec_name]['tp'] / (
-                    hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
-            ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
-        ))
-
-        # tp: 鐑瘝鍦╨abel閲岋紝鍚屾椂鍦╮ec閲�
-        # tn: 鐑瘝涓嶅湪label閲岋紝鍚屾椂涓嶅湪rec閲�
-        # fp: 鐑瘝涓嶅湪label閲岋紝浣嗘槸鍦╮ec閲�
-        # fn: 鐑瘝鍦╨abel閲岋紝浣嗘槸涓嶅湪rec閲�
-        if not verbose:
-            print()
-        print()
-
-
-if __name__ == "__main__":
-    args = get_args()
-    
-    # print("")
-    print(args)
-    main(args)
-
diff --git a/examples/industrial_data_pretraining/lcbnet/demo.py b/examples/industrial_data_pretraining/lcbnet/demo.py
deleted file mode 100755
index 4ca5255..0000000
--- a/examples/industrial_data_pretraining/lcbnet/demo.py
+++ /dev/null
@@ -1,13 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-from funasr import AutoModel
-
-model = AutoModel(model="iic/LCB-NET",
-                  model_revision="v1.0.0")
-
-res = model.generate(input=("https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav","https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt"),data_type=("sound", "text"))
-
-print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/lcbnet/demo.sh b/examples/industrial_data_pretraining/lcbnet/demo.sh
deleted file mode 100755
index 2f226bc..0000000
--- a/examples/industrial_data_pretraining/lcbnet/demo.sh
+++ /dev/null
@@ -1,72 +0,0 @@
-file_dir="/home/yf352572/.cache/modelscope/hub/iic/LCB-NET/"
-CUDA_VISIBLE_DEVICES="0,1"
-inference_device="cuda"
-
-if [ ${inference_device} == "cuda" ]; then
-    nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
-else
-    inference_batch_size=1
-    CUDA_VISIBLE_DEVICES=""
-    for JOB in $(seq ${nj}); do
-        CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
-    done
-fi
-
-inference_dir="outputs/slidespeech_dev"
-_logdir="${inference_dir}/logdir"
-echo "inference_dir: ${inference_dir}"
-
-mkdir -p "${_logdir}"
-key_file1=${file_dir}/dev/wav.scp
-key_file2=${file_dir}/dev/ocr.txt
-split_scps1=
-split_scps2=
-for JOB in $(seq "${nj}"); do
-    split_scps1+=" ${_logdir}/wav.${JOB}.scp"
-    split_scps2+=" ${_logdir}/ocr.${JOB}.txt"
-done
-utils/split_scp.pl "${key_file1}" ${split_scps1}
-utils/split_scp.pl "${key_file2}" ${split_scps2}
-
-gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
-for JOB in $(seq ${nj}); do
-    {
-        id=$((JOB-1))
-        gpuid=${gpuid_list_array[$id]}
-
-        export CUDA_VISIBLE_DEVICES=${gpuid}
-
-        python -m funasr.bin.inference \
-        --config-path=${file_dir} \
-        --config-name="config.yaml" \
-        ++init_param=${file_dir}/model.pt \
-        ++tokenizer_conf.token_list=${file_dir}/tokens.txt \
-        ++input=[${_logdir}/wav.${JOB}.scp,${_logdir}/ocr.${JOB}.txt] \
-        +data_type='["kaldi_ark", "text"]' \
-        ++tokenizer_conf.bpemodel=${file_dir}/bpe.pt \
-        ++normalize_conf.stats_file=${file_dir}/am.mvn \
-        ++output_dir="${inference_dir}/${JOB}" \
-        ++device="${inference_device}" \
-        ++ncpu=1 \
-        ++disable_log=true  &> ${_logdir}/log.${JOB}.txt
-
-    }&
-done
-wait
-
-
-mkdir -p ${inference_dir}/1best_recog
-
-for JOB in $(seq "${nj}"); do
-   cat "${inference_dir}/${JOB}/1best_recog/token" >> "${inference_dir}/1best_recog/token"
-done  
-
-echo "Computing WER ..."
-sed -e 's/ /\t/' -e 's/ //g' -e 's/鈻�/ /g' -e 's/\t /\t/'  ${inference_dir}/1best_recog/token > ${inference_dir}/1best_recog/token.proc
-cp  ${file_dir}/dev/text ${inference_dir}/1best_recog/token.ref
-cp  ${file_dir}/dev/ocr.list ${inference_dir}/1best_recog/ocr.list
-python utils/compute_wer.py ${inference_dir}/1best_recog/token.ref ${inference_dir}/1best_recog/token.proc ${inference_dir}/1best_recog/token.cer
-tail -n 3 ${inference_dir}/1best_recog/token.cer
-
-./run_bwer_recall.sh  ${inference_dir}/1best_recog/
-tail -n 6 ${inference_dir}/1best_recog/BWER-UWER.results |head -n 5
diff --git a/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh b/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh
deleted file mode 100755
index 7d6b6ff..0000000
--- a/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-#now_result_name=asr_conformer_acc1_lr002_warm20000/decode_asr_asr_model_valid.acc.ave
-#hotword_type=ocr_1ngram_top10_hotwords_list
-hot_exp_suf=$1
-
-
-python compute_wer_details.py --v 1 \
-   --ref ${hot_exp_suf}/token.ref \
-   --ref_ocr ${hot_exp_suf}/ocr.list  \
-   --rec_name base \
-   --rec_file ${hot_exp_suf}/token.proc \
-   > ${hot_exp_suf}/BWER-UWER.results
diff --git a/examples/industrial_data_pretraining/lcbnet/utils b/examples/industrial_data_pretraining/lcbnet/utils
deleted file mode 120000
index be5e5a3..0000000
--- a/examples/industrial_data_pretraining/lcbnet/utils
+++ /dev/null
@@ -1 +0,0 @@
-../../aishell/paraformer/utils
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index 551dd8b..a44c649 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -7,10 +7,10 @@
 
 model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
                   model_revision="v2.0.4",
-                  # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
-                  # vad_model_revision="v2.0.4",
-                  # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
-                  # punc_model_revision="v2.0.4",
+                  vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+                  vad_model_revision="v2.0.4",
+                  punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+                  punc_model_revision="v2.0.4",
                   # spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
                   # spk_model_revision="v2.0.2",
                   )
@@ -43,4 +43,4 @@
 wav_file = os.path.join(model.model_path, "example/asr_example.wav")
 speech, sample_rate = soundfile.read(wav_file)
 res = model.generate(input=[speech], batch_size_s=300, is_final=True)
-'''
+'''
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index ec3c3f3..921ede8 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -28,7 +28,7 @@
     from funasr.models.campplus.cluster_backend import ClusterBackend
 except:
     print("If you want to use the speaker diarization, please `pip install hdbscan`")
-import pdb
+
 
 def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
     """
@@ -46,7 +46,6 @@
     chars = string.ascii_letters + string.digits
     if isinstance(data_in, str) and data_in.startswith('http'): # url
         data_in = download_from_url(data_in)
-
     if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
         _, file_extension = os.path.splitext(data_in)
         file_extension = file_extension.lower()
@@ -147,7 +146,7 @@
             kwargs = download_model(**kwargs)
         
         set_all_random_seed(kwargs.get("seed", 0))
-
+        
         device = kwargs.get("device", "cuda")
         if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
             device = "cpu"
@@ -169,6 +168,7 @@
             vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
         else:
             vocab_size = -1
+        
         # build frontend
         frontend = kwargs.get("frontend", None)
         kwargs["input_size"] = None
@@ -181,6 +181,7 @@
         # build model
         model_class = tables.model_classes.get(kwargs["model"])
         model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
+        
         model.to(device)
         
         # init_param
@@ -223,9 +224,9 @@
         batch_size = kwargs.get("batch_size", 1)
         # if kwargs.get("device", "cpu") == "cpu":
         #     batch_size = 1
-
+        
         key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)
-
+        
         speed_stats = {}
         asr_result_list = []
         num_samples = len(data_list)
@@ -238,7 +239,6 @@
             data_batch = data_list[beg_idx:end_idx]
             key_batch = key_list[beg_idx:end_idx]
             batch = {"data_in": data_batch, "key": key_batch}
-
             if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank
                 batch["data_in"] = data_batch[0]
                 batch["data_lengths"] = input_len
diff --git a/funasr/frontends/default.py b/funasr/frontends/default.py
index c4bdbd7..8ac1ca8 100644
--- a/funasr/frontends/default.py
+++ b/funasr/frontends/default.py
@@ -3,6 +3,7 @@
 from typing import Tuple
 from typing import Union
 import logging
+import humanfriendly
 import numpy as np
 import torch
 import torch.nn as nn
@@ -15,10 +16,8 @@
 from funasr.frontends.utils.stft import Stft
 from funasr.frontends.utils.frontend import Frontend
 from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.register import tables
 
 
-@tables.register("frontend_classes", "DefaultFrontend")
 class DefaultFrontend(nn.Module):
     """Conventional frontend structure for ASR.
     Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@@ -26,7 +25,7 @@
 
     def __init__(
             self,
-            fs: int = 16000,
+            fs: Union[int, str] = 16000,
             n_fft: int = 512,
             win_length: int = None,
             hop_length: int = 128,
@@ -41,14 +40,14 @@
             frontend_conf: Optional[dict] = None,
             apply_stft: bool = True,
             use_channel: int = None,
-            **kwargs,
     ):
         super().__init__()
+        if isinstance(fs, str):
+            fs = humanfriendly.parse_size(fs)
 
         # Deepcopy (In general, dict shouldn't be used as default arg)
         frontend_conf = copy.deepcopy(frontend_conf)
         self.hop_length = hop_length
-        self.fs = fs
 
         if apply_stft:
             self.stft = Stft(
@@ -85,12 +84,8 @@
         return self.n_mels
 
     def forward(
-            self, input: torch.Tensor, input_lengths:  Union[torch.Tensor, list]
+            self, input: torch.Tensor, input_lengths: torch.Tensor
     ) -> Tuple[torch.Tensor, torch.Tensor]:
-        if isinstance(input_lengths, list):
-            input_lengths = torch.tensor(input_lengths)
-        if  input.dtype == torch.float64:
-            input = input.float()
         # 1. Domain-conversion: e.g. Stft: time -> time-freq
         if self.stft is not None:
             input_stft, feats_lens = self._compute_stft(input, input_lengths)
@@ -150,7 +145,7 @@
 
     def __init__(
             self,
-            fs: int = 16000,
+            fs: Union[int, str] = 16000,
             n_fft: int = 512,
             win_length: int = None,
             hop_length: int = None,
@@ -173,6 +168,9 @@
             mc: bool = True
     ):
         super().__init__()
+        if isinstance(fs, str):
+            fs = humanfriendly.parse_size(fs)
+
         # Deepcopy (In general, dict shouldn't be used as default arg)
         frontend_conf = copy.deepcopy(frontend_conf)
         if win_length is None and hop_length is None:
diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py
index be973c6..1d252c2 100644
--- a/funasr/models/conformer/encoder.py
+++ b/funasr/models/conformer/encoder.py
@@ -47,7 +47,7 @@
 from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
 from funasr.models.transformer.utils.subsampling import StreamingConvInput
 from funasr.register import tables
-import pdb
+
 
 class ConvolutionModule(nn.Module):
     """ConvolutionModule in Conformer model.
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 7d6f729..49868a8 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -29,7 +29,7 @@
 from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-import pdb
+
 
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
@@ -62,6 +62,7 @@
         crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
         crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
         bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
+
 
         if bias_encoder_type == 'lstm':
             self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
@@ -102,16 +103,17 @@
             text_lengths = text_lengths[:, 0]
         if len(speech_lengths.size()) > 1:
             speech_lengths = speech_lengths[:, 0]
-
+        
         batch_size = speech.shape[0]
 
         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
         dha_pad = kwargs.get("dha_pad")
-
+        
         # 1. Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
 
+        
         loss_ctc, cer_ctc = None, None
         
         stats = dict()
@@ -126,11 +128,12 @@
             stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
             stats["cer_ctc"] = cer_ctc
         
+
         # 2b. Attention decoder branch
         loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
             encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
         )
-
+        
         # 3. CTC-Att loss definition
         if self.ctc_weight == 0.0:
             loss = loss_att + loss_pre * self.predictor_weight
@@ -168,24 +171,22 @@
     ):
         encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
             encoder_out.device)
-
         if self.predictor_bias == 1:
             _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
             ys_pad_lens = ys_pad_lens + self.predictor_bias
-
         pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
                                                                      ignore_id=self.ignore_id)
+        
         # -1. bias encoder
         if self.use_decoder_embedding:
             hw_embed = self.decoder.embed(hotword_pad)
         else:
             hw_embed = self.bias_embed(hotword_pad)
-
         hw_embed, (_, _) = self.bias_encoder(hw_embed)
         _ind = np.arange(0, hotword_pad.shape[0]).tolist()
         selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
         contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
-
+        
         # 0. sampler
         decoder_out_1st = None
         if self.sampling_ratio > 0.0:
@@ -194,7 +195,7 @@
                                                            pre_acoustic_embeds, contextual_info)
         else:
             sematic_embeds = pre_acoustic_embeds
-
+        
         # 1. Forward decoder
         decoder_outs = self.decoder(
             encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
@@ -210,7 +211,7 @@
             loss_ideal = None
         '''
         loss_ideal = None
-
+        
         if decoder_out_1st is None:
             decoder_out_1st = decoder_out
         # 2. Compute attention loss
@@ -287,11 +288,10 @@
                                                                enforce_sorted=False)
             _, (h_n, _) = self.bias_encoder(hw_embed)
             hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
-
+        
         decoder_outs = self.decoder(
             encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
         )
-
         decoder_out = decoder_outs[0]
         decoder_out = torch.log_softmax(decoder_out, dim=-1)
         return decoder_out, ys_pad_lens
@@ -305,42 +305,38 @@
                  **kwargs,
                  ):
         # init beamsearch
-
         is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
         is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
         if self.beam_search is None and (is_use_lm or is_use_ctc):
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)
-
+        
         meta_data = {}
         
         # extract fbank feats
         time1 = time.perf_counter()
-
         audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
-
         time2 = time.perf_counter()
         meta_data["load_data"] = f"{time2 - time1:0.3f}"
-
         speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
                                                frontend=frontend)
         time3 = time.perf_counter()
         meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
         meta_data[
             "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-
+        
         speech = speech.to(device=kwargs["device"])
         speech_lengths = speech_lengths.to(device=kwargs["device"])
 
         # hotword
         self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
-
+        
         # Encoder
         encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
-
+        
         # predictor
         predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
         pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
@@ -348,7 +344,8 @@
         pre_token_length = pre_token_length.round().long()
         if torch.max(pre_token_length) < 1:
             return []
-        
+
+
         decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
                                                                  pre_acoustic_embeds,
                                                                  pre_token_length,
diff --git a/funasr/models/lcbnet/__init__.py b/funasr/models/lcbnet/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models/lcbnet/__init__.py
+++ /dev/null
diff --git a/funasr/models/lcbnet/attention.py b/funasr/models/lcbnet/attention.py
deleted file mode 100644
index 8e8c594..0000000
--- a/funasr/models/lcbnet/attention.py
+++ /dev/null
@@ -1,112 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-# Copyright 2024 yufan
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
-"""Multi-Head Attention Return Weight layer definition."""
-
-import math
-
-import torch
-from torch import nn
-
-class MultiHeadedAttentionReturnWeight(nn.Module):
-    """Multi-Head Attention layer.
-
-    Args:
-        n_head (int): The number of heads.
-        n_feat (int): The number of features.
-        dropout_rate (float): Dropout rate.
-
-    """
-
-    def __init__(self, n_head, n_feat, dropout_rate):
-        """Construct an MultiHeadedAttentionReturnWeight object."""
-        super(MultiHeadedAttentionReturnWeight, self).__init__()
-        assert n_feat % n_head == 0
-        # We assume d_v always equals d_k
-        self.d_k = n_feat // n_head
-        self.h = n_head
-        self.linear_q = nn.Linear(n_feat, n_feat)
-        self.linear_k = nn.Linear(n_feat, n_feat)
-        self.linear_v = nn.Linear(n_feat, n_feat)
-        self.linear_out = nn.Linear(n_feat, n_feat)
-        self.attn = None
-        self.dropout = nn.Dropout(p=dropout_rate)
-
-    def forward_qkv(self, query, key, value):
-        """Transform query, key and value.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-
-        Returns:
-            torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
-            torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
-            torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-
-        """
-        n_batch = query.size(0)
-        q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
-        k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
-        v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
-        q = q.transpose(1, 2)  # (batch, head, time1, d_k)
-        k = k.transpose(1, 2)  # (batch, head, time2, d_k)
-        v = v.transpose(1, 2)  # (batch, head, time2, d_k)
-
-        return q, k, v
-
-    def forward_attention(self, value, scores, mask):
-        """Compute attention context vector.
-
-        Args:
-            value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
-            scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
-            mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Transformed value (#batch, time1, d_model)
-                weighted by the attention score (#batch, time1, time2).
-
-        """
-        n_batch = value.size(0)
-        if mask is not None:
-            mask = mask.unsqueeze(1).eq(0)  # (batch, 1, *, time2)
-            min_value = torch.finfo(scores.dtype).min
-            scores = scores.masked_fill(mask, min_value)
-            self.attn = torch.softmax(scores, dim=-1).masked_fill(
-                mask, 0.0
-            )  # (batch, head, time1, time2)
-        else:
-            self.attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
-
-        p_attn = self.dropout(self.attn)
-        x = torch.matmul(p_attn, value)  # (batch, head, time1, d_k)
-        x = (
-            x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
-        )  # (batch, time1, d_model)
-
-        return self.linear_out(x), self.attn  # (batch, time1, d_model)
-
-    def forward(self, query, key, value, mask):
-        """Compute scaled dot product attention.
-
-        Args:
-            query (torch.Tensor): Query tensor (#batch, time1, size).
-            key (torch.Tensor): Key tensor (#batch, time2, size).
-            value (torch.Tensor): Value tensor (#batch, time2, size).
-            mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
-                (#batch, time1, time2).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time1, d_model).
-
-        """
-        q, k, v = self.forward_qkv(query, key, value)
-        scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
-        return self.forward_attention(v, scores, mask)
-
-
diff --git a/funasr/models/lcbnet/encoder.py b/funasr/models/lcbnet/encoder.py
deleted file mode 100644
index c65823c..0000000
--- a/funasr/models/lcbnet/encoder.py
+++ /dev/null
@@ -1,392 +0,0 @@
-# Copyright 2019 Shigeki Karita
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
-"""Transformer encoder definition."""
-
-from typing import List
-from typing import Optional
-from typing import Tuple
-
-import torch
-from torch import nn
-import logging
-
-from funasr.models.transformer.attention import MultiHeadedAttention
-from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight
-from funasr.models.transformer.embedding import PositionalEncoding
-from funasr.models.transformer.layer_norm import LayerNorm
-
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
-from funasr.models.transformer.utils.repeat import repeat
-from funasr.register import tables
-
-class EncoderLayer(nn.Module):
-    """Encoder layer module.
-
-    Args:
-        size (int): Input dimension.
-        self_attn (torch.nn.Module): Self-attention module instance.
-            `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
-            can be used as the argument.
-        feed_forward (torch.nn.Module): Feed-forward module instance.
-            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
-            can be used as the argument.
-        dropout_rate (float): Dropout rate.
-        normalize_before (bool): Whether to use layer_norm before the first block.
-        concat_after (bool): Whether to concat attention layer's input and output.
-            if True, additional linear will be applied.
-            i.e. x -> x + linear(concat(x, att(x)))
-            if False, no additional linear will be applied. i.e. x -> x + att(x)
-        stochastic_depth_rate (float): Proability to skip this layer.
-            During training, the layer may skip residual computation and return input
-            as-is with given probability.
-    """
-
-    def __init__(
-            self,
-            size,
-            self_attn,
-            feed_forward,
-            dropout_rate,
-            normalize_before=True,
-            concat_after=False,
-            stochastic_depth_rate=0.0,
-    ):
-        """Construct an EncoderLayer object."""
-        super(EncoderLayer, self).__init__()
-        self.self_attn = self_attn
-        self.feed_forward = feed_forward
-        self.norm1 = LayerNorm(size)
-        self.norm2 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.size = size
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        if self.concat_after:
-            self.concat_linear = nn.Linear(size + size, size)
-        self.stochastic_depth_rate = stochastic_depth_rate
-
-    def forward(self, x, mask, cache=None):
-        """Compute encoded features.
-
-        Args:
-            x_input (torch.Tensor): Input tensor (#batch, time, size).
-            mask (torch.Tensor): Mask tensor for the input (#batch, time).
-            cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-
-        Returns:
-            torch.Tensor: Output tensor (#batch, time, size).
-            torch.Tensor: Mask tensor (#batch, time).
-
-        """
-        skip_layer = False
-        # with stochastic depth, residual connection `x + f(x)` becomes
-        # `x <- x + 1 / (1 - p) * f(x)` at training time.
-        stoch_layer_coeff = 1.0
-        if self.training and self.stochastic_depth_rate > 0:
-            skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
-            stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
-
-        if skip_layer:
-            if cache is not None:
-                x = torch.cat([cache, x], dim=1)
-            return x, mask
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm1(x)
-
-        if cache is None:
-            x_q = x
-        else:
-            assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
-            x_q = x[:, -1:, :]
-            residual = residual[:, -1:, :]
-            mask = None if mask is None else mask[:, -1:, :]
-
-        if self.concat_after:
-            x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
-            x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
-        else:
-            x = residual + stoch_layer_coeff * self.dropout(
-                self.self_attn(x_q, x, x, mask)
-            )
-        if not self.normalize_before:
-            x = self.norm1(x)
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm2(x)
-        x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
-        if not self.normalize_before:
-            x = self.norm2(x)
-
-        if cache is not None:
-            x = torch.cat([cache, x], dim=1)
-
-        return x, mask
-
-@tables.register("encoder_classes", "TransformerTextEncoder")
-class TransformerTextEncoder(nn.Module):
-    """Transformer text encoder module.
-
-    Args:
-        input_size: input dim
-        output_size: dimension of attention
-        attention_heads: the number of heads of multi head attention
-        linear_units: the number of units of position-wise feed forward
-        num_blocks: the number of decoder blocks
-        dropout_rate: dropout rate
-        attention_dropout_rate: dropout rate in attention
-        positional_dropout_rate: dropout rate after adding positional encoding
-        input_layer: input layer type
-        pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
-        normalize_before: whether to use layer_norm before the first block
-        concat_after: whether to concat attention layer's input and output
-            if True, additional linear will be applied.
-            i.e. x -> x + linear(concat(x, att(x)))
-            if False, no additional linear will be applied.
-            i.e. x -> x + att(x)
-        positionwise_layer_type: linear of conv1d
-        positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
-        padding_idx: padding_idx for input_layer=embed
-    """
-
-    def __init__(
-            self,
-            input_size: int,
-            output_size: int = 256,
-            attention_heads: int = 4,
-            linear_units: int = 2048,
-            num_blocks: int = 6,
-            dropout_rate: float = 0.1,
-            positional_dropout_rate: float = 0.1,
-            attention_dropout_rate: float = 0.0,
-            pos_enc_class=PositionalEncoding,
-            normalize_before: bool = True,
-            concat_after: bool = False,
-    ):
-        super().__init__()
-        self._output_size = output_size
-
-        self.embed = torch.nn.Sequential(
-            torch.nn.Embedding(input_size, output_size),
-            pos_enc_class(output_size, positional_dropout_rate),
-        )
-
-        self.normalize_before = normalize_before
-
-        positionwise_layer = PositionwiseFeedForward
-        positionwise_layer_args = (
-            output_size,
-            linear_units,
-            dropout_rate,
-        )
-        self.encoders = repeat(
-            num_blocks,
-            lambda lnum: EncoderLayer(
-                output_size,
-                MultiHeadedAttention(
-                    attention_heads, output_size, attention_dropout_rate
-                ),
-                positionwise_layer(*positionwise_layer_args),
-                dropout_rate,
-                normalize_before,
-                concat_after,
-            ),
-        )
-        if self.normalize_before:
-            self.after_norm = LayerNorm(output_size)
-
-    def output_size(self) -> int:
-        return self._output_size
-
-    def forward(
-            self,
-            xs_pad: torch.Tensor,
-            ilens: torch.Tensor,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        """Embed positions in tensor.
-
-        Args:
-            xs_pad: input tensor (B, L, D)
-            ilens: input length (B)
-        Returns:
-            position embedded tensor and mask
-        """
-        masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
-        xs_pad = self.embed(xs_pad)
-
-        xs_pad, masks = self.encoders(xs_pad, masks)
-
-        if self.normalize_before:
-            xs_pad = self.after_norm(xs_pad)
-
-        olens = masks.squeeze(1).sum(1)
-        return xs_pad, olens, None
-
-
-
-
-@tables.register("encoder_classes", "FusionSANEncoder")
-class SelfSrcAttention(nn.Module):
-    """Single decoder layer module.
-
-    Args:
-        size (int): Input dimension.
-        self_attn (torch.nn.Module): Self-attention module instance.
-            `MultiHeadedAttention` instance can be used as the argument.
-        src_attn (torch.nn.Module): Self-attention module instance.
-            `MultiHeadedAttention` instance can be used as the argument.
-        feed_forward (torch.nn.Module): Feed-forward module instance.
-            `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
-            can be used as the argument.
-        dropout_rate (float): Dropout rate.
-        normalize_before (bool): Whether to use layer_norm before the first block.
-        concat_after (bool): Whether to concat attention layer's input and output.
-            if True, additional linear will be applied.
-            i.e. x -> x + linear(concat(x, att(x)))
-            if False, no additional linear will be applied. i.e. x -> x + att(x)
-
-
-    """
-    def __init__(
-            self,
-            size,
-            attention_heads,
-            attention_dim,
-            linear_units,
-            self_attention_dropout_rate,
-            src_attention_dropout_rate,
-            positional_dropout_rate,
-            dropout_rate,
-            normalize_before=True,
-            concat_after=False,
-    ):
-        """Construct an SelfSrcAttention object."""
-        super(SelfSrcAttention, self).__init__()
-        self.size = size
-        self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate)
-        self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate)
-        self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate)
-        self.norm1 = LayerNorm(size)
-        self.norm2 = LayerNorm(size)
-        self.norm3 = LayerNorm(size)
-        self.dropout = nn.Dropout(dropout_rate)
-        self.normalize_before = normalize_before
-        self.concat_after = concat_after
-        if self.concat_after:
-            self.concat_linear1 = nn.Linear(size + size, size)
-            self.concat_linear2 = nn.Linear(size + size, size)
-
-    def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
-        """Compute decoded features.
-
-        Args:
-            tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
-            tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
-            memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
-            memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
-            cache (List[torch.Tensor]): List of cached tensors.
-                Each tensor shape should be (#batch, maxlen_out - 1, size).
-
-        Returns:
-            torch.Tensor: Output tensor(#batch, maxlen_out, size).
-            torch.Tensor: Mask for output tensor (#batch, maxlen_out).
-            torch.Tensor: Encoded memory (#batch, maxlen_in, size).
-            torch.Tensor: Encoded memory mask (#batch, maxlen_in).
-
-        """
-        residual = tgt
-        if self.normalize_before:
-            tgt = self.norm1(tgt)
-
-        if cache is None:
-            tgt_q = tgt
-            tgt_q_mask = tgt_mask
-        else:
-            # compute only the last frame query keeping dim: max_time_out -> 1
-            assert cache.shape == (
-                tgt.shape[0],
-                tgt.shape[1] - 1,
-                self.size,
-            ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
-            tgt_q = tgt[:, -1:, :]
-            residual = residual[:, -1:, :]
-            tgt_q_mask = None
-            if tgt_mask is not None:
-                tgt_q_mask = tgt_mask[:, -1:, :]
-
-        if self.concat_after:
-            tgt_concat = torch.cat(
-                (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
-            )
-            x = residual + self.concat_linear1(tgt_concat)
-        else:
-            x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
-        if not self.normalize_before:
-            x = self.norm1(x)
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm2(x)
-        if self.concat_after:
-            x_concat = torch.cat(
-                (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
-            )
-            x = residual + self.concat_linear2(x_concat)
-        else:
-            x, score = self.src_attn(x, memory, memory, memory_mask)
-            x = residual + self.dropout(x)
-        if not self.normalize_before:
-            x = self.norm2(x)
-
-        residual = x
-        if self.normalize_before:
-            x = self.norm3(x)
-        x = residual + self.dropout(self.feed_forward(x))
-        if not self.normalize_before:
-            x = self.norm3(x)
-
-        if cache is not None:
-            x = torch.cat([cache, x], dim=1)
-
-        return x, tgt_mask, memory, memory_mask
-
-
-@tables.register("encoder_classes", "ConvBiasPredictor")
-class ConvPredictor(nn.Module):
-    def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048):
-        super().__init__()
-        self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
-        self.norm1 = LayerNorm(size)
-        self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate)
-        self.norm2 = LayerNorm(size)
-        self.pad = nn.ConstantPad1d((l_order, r_order), 0)
-        self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size)
-        self.output_linear = nn.Linear(size, 1)
-
-
-    def forward(self, text_enc, asr_enc):
-        # stage1 cross-attention
-        residual = text_enc
-        text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None)
-        
-        # stage2 FFN
-        residual = text_enc
-        text_enc = self.norm1(text_enc)
-        text_enc = residual + self.feed_forward(text_enc)
-        
-        # stage Conv predictor
-        text_enc = self.norm2(text_enc)
-        context = text_enc.transpose(1, 2)
-        queries = self.pad(context)
-        memory = self.conv1d(queries)
-        output = memory + context
-        output = output.transpose(1, 2)
-        output = torch.relu(output)
-        output = self.output_linear(output)
-        if output.dim()==3:
-          output = output.squeeze(2)
-        return output
diff --git a/funasr/models/lcbnet/model.py b/funasr/models/lcbnet/model.py
deleted file mode 100644
index 3ac319c..0000000
--- a/funasr/models/lcbnet/model.py
+++ /dev/null
@@ -1,495 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-#  MIT License  (https://opensource.org/licenses/MIT)
-
-import logging
-from typing import Union, Dict, List, Tuple, Optional
-
-import time
-import torch
-import torch.nn as nn
-from torch.cuda.amp import autocast
-
-from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
-from funasr.models.ctc.ctc import CTC
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.metrics.compute_acc import th_accuracy
-# from funasr.models.e2e_asr_common import ErrorCalculator
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.register import tables
-
-import pdb
-@tables.register("model_classes", "LCBNet")
-class LCBNet(nn.Module):
-    """
-    Author: Speech Lab of DAMO Academy, Alibaba Group
-    LCB-NET: LONG-CONTEXT BIASING FOR AUDIO-VISUAL SPEECH RECOGNITION
-    https://arxiv.org/abs/2401.06390
-    """
-    
-    def __init__(
-        self,
-        specaug: str = None,
-        specaug_conf: dict = None,
-        normalize: str = None,
-        normalize_conf: dict = None,
-        encoder: str = None,
-        encoder_conf: dict = None,
-        decoder: str = None,
-        decoder_conf: dict = None,
-        text_encoder: str = None,
-        text_encoder_conf: dict = None,
-        bias_predictor: str = None,
-        bias_predictor_conf: dict = None,
-        fusion_encoder: str = None,
-        fusion_encoder_conf: dict = None,
-        ctc: str = None,
-        ctc_conf: dict = None,
-        ctc_weight: float = 0.5,
-        interctc_weight: float = 0.0,
-        select_num: int = 2,
-        select_length: int = 3,
-        insert_blank: bool = True,
-        input_size: int = 80,
-        vocab_size: int = -1,
-        ignore_id: int = -1,
-        blank_id: int = 0,
-        sos: int = 1,
-        eos: int = 2,
-        lsm_weight: float = 0.0,
-        length_normalized_loss: bool = False,
-        report_cer: bool = True,
-        report_wer: bool = True,
-        sym_space: str = "<space>",
-        sym_blank: str = "<blank>",
-        # extract_feats_in_collect_stats: bool = True,
-        share_embedding: bool = False,
-        # preencoder: Optional[AbsPreEncoder] = None,
-        # postencoder: Optional[AbsPostEncoder] = None,
-        **kwargs,
-    ):
-
-        super().__init__()
-
-        if specaug is not None:
-            specaug_class = tables.specaug_classes.get(specaug)
-            specaug = specaug_class(**specaug_conf)
-        if normalize is not None:
-            normalize_class = tables.normalize_classes.get(normalize)
-            normalize = normalize_class(**normalize_conf)
-        encoder_class = tables.encoder_classes.get(encoder)
-        encoder = encoder_class(input_size=input_size, **encoder_conf)
-        encoder_output_size = encoder.output_size()
-
-        # lcbnet modules: text encoder, fusion encoder and bias predictor
-        text_encoder_class = tables.encoder_classes.get(text_encoder)
-        text_encoder = text_encoder_class(input_size=vocab_size, **text_encoder_conf)
-        fusion_encoder_class = tables.encoder_classes.get(fusion_encoder)
-        fusion_encoder = fusion_encoder_class(**fusion_encoder_conf)
-        bias_predictor_class = tables.encoder_classes.get(bias_predictor)
-        bias_predictor = bias_predictor_class(**bias_predictor_conf)
-
-
-        if decoder is not None:
-            decoder_class = tables.decoder_classes.get(decoder)
-            decoder = decoder_class(
-                vocab_size=vocab_size,
-                encoder_output_size=encoder_output_size,
-                **decoder_conf,
-            )
-        if ctc_weight > 0.0:
-            
-            if ctc_conf is None:
-                ctc_conf = {}
-            
-            ctc = CTC(
-                odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
-            )
-    
-        self.blank_id = blank_id
-        self.sos = vocab_size - 1
-        self.eos = vocab_size - 1
-        self.vocab_size = vocab_size
-        self.ignore_id = ignore_id
-        self.ctc_weight = ctc_weight
-        self.specaug = specaug
-        self.normalize = normalize
-        self.encoder = encoder
-        # lcbnet
-        self.text_encoder = text_encoder
-        self.fusion_encoder = fusion_encoder
-        self.bias_predictor = bias_predictor
-        self.select_num = select_num
-        self.select_length = select_length
-        self.insert_blank = insert_blank
-
-        if not hasattr(self.encoder, "interctc_use_conditioning"):
-            self.encoder.interctc_use_conditioning = False
-        if self.encoder.interctc_use_conditioning:
-            self.encoder.conditioning_layer = torch.nn.Linear(
-                vocab_size, self.encoder.output_size()
-            )
-        self.interctc_weight = interctc_weight
-
-        # self.error_calculator = None
-        if ctc_weight == 1.0:
-            self.decoder = None
-        else:
-            self.decoder = decoder
-        
-        self.criterion_att = LabelSmoothingLoss(
-            size=vocab_size,
-            padding_idx=ignore_id,
-            smoothing=lsm_weight,
-            normalize_length=length_normalized_loss,
-        )
-        #
-        # if report_cer or report_wer:
-        #     self.error_calculator = ErrorCalculator(
-        #         token_list, sym_space, sym_blank, report_cer, report_wer
-        #     )
-        #
-        self.error_calculator = None
-        if ctc_weight == 0.0:
-            self.ctc = None
-        else:
-            self.ctc = ctc
-            
-        self.share_embedding = share_embedding
-        if self.share_embedding:
-            self.decoder.embed = None
-        
-        self.length_normalized_loss = length_normalized_loss
-        self.beam_search = None
-    
-    def forward(
-        self,
-        speech: torch.Tensor,
-        speech_lengths: torch.Tensor,
-        text: torch.Tensor,
-        text_lengths: torch.Tensor,
-        **kwargs,
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        """Encoder + Decoder + Calc loss
-        Args:
-                speech: (Batch, Length, ...)
-                speech_lengths: (Batch, )
-                text: (Batch, Length)
-                text_lengths: (Batch,)
-        """
-
-        if len(text_lengths.size()) > 1:
-            text_lengths = text_lengths[:, 0]
-        if len(speech_lengths.size()) > 1:
-            speech_lengths = speech_lengths[:, 0]
-        
-        batch_size = speech.shape[0]
-        
-        # 1. Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-        intermediate_outs = None
-        if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
-            encoder_out = encoder_out[0]
-        
-        loss_att, acc_att, cer_att, wer_att = None, None, None, None
-        loss_ctc, cer_ctc = None, None
-        stats = dict()
-        
-        # decoder: CTC branch
-        if self.ctc_weight != 0.0:
-            loss_ctc, cer_ctc = self._calc_ctc_loss(
-                encoder_out, encoder_out_lens, text, text_lengths
-            )
-            
-            # Collect CTC branch stats
-            stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
-            stats["cer_ctc"] = cer_ctc
-        
-        # Intermediate CTC (optional)
-        loss_interctc = 0.0
-        if self.interctc_weight != 0.0 and intermediate_outs is not None:
-            for layer_idx, intermediate_out in intermediate_outs:
-                # we assume intermediate_out has the same length & padding
-                # as those of encoder_out
-                loss_ic, cer_ic = self._calc_ctc_loss(
-                    intermediate_out, encoder_out_lens, text, text_lengths
-                )
-                loss_interctc = loss_interctc + loss_ic
-                
-                # Collect Intermedaite CTC stats
-                stats["loss_interctc_layer{}".format(layer_idx)] = (
-                    loss_ic.detach() if loss_ic is not None else None
-                )
-                stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-            
-            loss_interctc = loss_interctc / len(intermediate_outs)
-            
-            # calculate whole encoder loss
-            loss_ctc = (
-                           1 - self.interctc_weight
-                       ) * loss_ctc + self.interctc_weight * loss_interctc
-        
-        # decoder: Attention decoder branch
-        loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
-            encoder_out, encoder_out_lens, text, text_lengths
-        )
-        
-        # 3. CTC-Att loss definition
-        if self.ctc_weight == 0.0:
-            loss = loss_att
-        elif self.ctc_weight == 1.0:
-            loss = loss_ctc
-        else:
-            loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
-        
-        # Collect Attn branch stats
-        stats["loss_att"] = loss_att.detach() if loss_att is not None else None
-        stats["acc"] = acc_att
-        stats["cer"] = cer_att
-        stats["wer"] = wer_att
-        
-        # Collect total loss stats
-        stats["loss"] = torch.clone(loss.detach())
-        
-        # force_gatherable: to-device and to-tensor if scalar for DataParallel
-        if self.length_normalized_loss:
-            batch_size = int((text_lengths + 1).sum())
-        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
-        return loss, stats, weight
-    
-
-    def encode(
-        self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
-    ) -> Tuple[torch.Tensor, torch.Tensor]:
-        """Frontend + Encoder. Note that this method is used by asr_inference.py
-        Args:
-                speech: (Batch, Length, ...)
-                speech_lengths: (Batch, )
-                ind: int
-        """
-        with autocast(False):
-            # Data augmentation
-            if self.specaug is not None and self.training:
-                speech, speech_lengths = self.specaug(speech, speech_lengths)
-            # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
-            if self.normalize is not None:
-                speech, speech_lengths = self.normalize(speech, speech_lengths)
-        # Forward encoder
-        # feats: (Batch, Length, Dim)
-        # -> encoder_out: (Batch, Length2, Dim2)
-        if self.encoder.interctc_use_conditioning:
-            encoder_out, encoder_out_lens, _ = self.encoder(
-                speech, speech_lengths, ctc=self.ctc
-            )
-        else:
-            encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
-        intermediate_outs = None
-        if isinstance(encoder_out, tuple):
-            intermediate_outs = encoder_out[1]
-            encoder_out = encoder_out[0]
-        
-        if intermediate_outs is not None:
-            return (encoder_out, intermediate_outs), encoder_out_lens
-        return encoder_out, encoder_out_lens
-    
-    def _calc_att_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
-        ys_in_lens = ys_pad_lens + 1
-        
-        # 1. Forward decoder
-        decoder_out, _ = self.decoder(
-            encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
-        )
-        
-        # 2. Compute attention loss
-        loss_att = self.criterion_att(decoder_out, ys_out_pad)
-        acc_att = th_accuracy(
-            decoder_out.view(-1, self.vocab_size),
-            ys_out_pad,
-            ignore_label=self.ignore_id,
-        )
-        
-        # Compute cer/wer using attention-decoder
-        if self.training or self.error_calculator is None:
-            cer_att, wer_att = None, None
-        else:
-            ys_hat = decoder_out.argmax(dim=-1)
-            cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-        
-        return loss_att, acc_att, cer_att, wer_att
-    
-    def _calc_ctc_loss(
-        self,
-        encoder_out: torch.Tensor,
-        encoder_out_lens: torch.Tensor,
-        ys_pad: torch.Tensor,
-        ys_pad_lens: torch.Tensor,
-    ):
-        # Calc CTC loss
-        loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-        
-        # Calc CER using CTC
-        cer_ctc = None
-        if not self.training and self.error_calculator is not None:
-            ys_hat = self.ctc.argmax(encoder_out).data
-            cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
-        return loss_ctc, cer_ctc
-    
-    def init_beam_search(self,
-                         **kwargs,
-                         ):
-        from funasr.models.transformer.search import BeamSearch
-        from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
-        from funasr.models.transformer.scorers.length_bonus import LengthBonus
-    
-        # 1. Build ASR model
-        scorers = {}
-        
-        if self.ctc != None:
-            ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
-            scorers.update(
-                ctc=ctc
-            )
-        token_list = kwargs.get("token_list")
-        scorers.update(
-            decoder=self.decoder,
-            length_bonus=LengthBonus(len(token_list)),
-        )
-
-        
-        # 3. Build ngram model
-        # ngram is not supported now
-        ngram = None
-        scorers["ngram"] = ngram
-        
-        weights = dict(
-            decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.3),
-            ctc=kwargs.get("decoding_ctc_weight", 0.3),
-            lm=kwargs.get("lm_weight", 0.0),
-            ngram=kwargs.get("ngram_weight", 0.0),
-            length_bonus=kwargs.get("penalty", 0.0),
-        )
-        beam_search = BeamSearch(
-            beam_size=kwargs.get("beam_size", 20),
-            weights=weights,
-            scorers=scorers,
-            sos=self.sos,
-            eos=self.eos,
-            vocab_size=len(token_list),
-            token_list=token_list,
-            pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
-        )
-
-        self.beam_search = beam_search
-        
-    def inference(self,
-             data_in,
-             data_lengths=None,
-             key: list=None,
-             tokenizer=None,
-             frontend=None,
-             **kwargs,
-             ):
-        
-        if kwargs.get("batch_size", 1) > 1:
-            raise NotImplementedError("batch decoding is not implemented")
-        
-        # init beamsearch
-        if self.beam_search is None:
-            logging.info("enable beam_search")
-            self.init_beam_search(**kwargs)
-            self.nbest = kwargs.get("nbest", 1)
-
-        meta_data = {}
-        if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank":  # fbank
-            speech, speech_lengths = data_in, data_lengths
-            if len(speech.shape) < 3:
-                speech = speech[None, :, :]
-            if speech_lengths is None:
-                speech_lengths = speech.shape[1]
-        else:
-            # extract fbank feats
-            time1 = time.perf_counter()
-            sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
-                                                            data_type=kwargs.get("data_type", "sound"),
-                                                            tokenizer=tokenizer)
-            time2 = time.perf_counter()
-            meta_data["load_data"] = f"{time2 - time1:0.3f}"
-            audio_sample_list = sample_list[0]
-            if len(sample_list) >1:
-                ocr_sample_list = sample_list[1]
-            else:
-                ocr_sample_list = [[294, 0]]
-            speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
-                                                   frontend=frontend)
-            time3 = time.perf_counter()
-            meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
-            frame_shift = 10 
-            meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift / 1000
-
-        speech = speech.to(device=kwargs["device"])
-        speech_lengths = speech_lengths.to(device=kwargs["device"])
-        # Encoder
-        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-        if isinstance(encoder_out, tuple):
-            encoder_out = encoder_out[0]
-
-        ocr_list_new = [[x + 1 if x != 0 else x for x in sublist] for sublist in ocr_sample_list]
-        ocr = torch.tensor(ocr_list_new).to(device=kwargs["device"])
-        ocr_lengths = ocr.new_full([1], dtype=torch.long, fill_value=ocr.size(1)).to(device=kwargs["device"])
-        ocr, ocr_lens, _ = self.text_encoder(ocr, ocr_lengths)
-        fusion_out, _, _, _ = self.fusion_encoder(encoder_out,None, ocr, None)
-        encoder_out = encoder_out + fusion_out
-        # c. Passed the encoder result and the beam search
-        nbest_hyps = self.beam_search(
-            x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
-        )
-        
-        nbest_hyps = nbest_hyps[: self.nbest]
-
-        results = []
-        b, n, d = encoder_out.size()
-        for i in range(b):
-
-            for nbest_idx, hyp in enumerate(nbest_hyps):
-                ibest_writer = None
-                if kwargs.get("output_dir") is not None:
-                    if not hasattr(self, "writer"):
-                        self.writer = DatadirWriter(kwargs.get("output_dir"))
-                    ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
-                    
-                # remove sos/eos and get results
-                last_pos = -1
-                if isinstance(hyp.yseq, list):
-                    token_int = hyp.yseq[1:last_pos]
-                else:
-                    token_int = hyp.yseq[1:last_pos].tolist()
-                    
-                # remove blank symbol id, which is assumed to be 0
-                token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-                
-                # Change integer-ids to tokens
-                token = tokenizer.ids2tokens(token_int)
-                text = tokenizer.tokens2text(token)
-
-                text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-                result_i = {"key": key[i], "token": token, "text": text_postprocessed}
-                results.append(result_i)
-                
-                if ibest_writer is not None:
-                    ibest_writer["token"][key[i]] = " ".join(token)
-                    ibest_writer["text"][key[i]] = text_postprocessed
-        
-        return results, meta_data
-
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index a8b1f1f..20b0cc8 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -30,7 +30,7 @@
 from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
 from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
 
-import pdb
+
 if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
     from torch.cuda.amp import autocast
 else:
@@ -128,7 +128,7 @@
         hotword_pad = kwargs.get("hotword_pad")
         hotword_lengths = kwargs.get("hotword_lengths")
         dha_pad = kwargs.get("dha_pad")
-        
+
         batch_size = speech.shape[0]
         # for data-parallel
         text = text[:, : text_lengths.max()]
@@ -209,20 +209,17 @@
                                nfilter=50,
                                seaco_weight=1.0):
         # decoder forward
-
         decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
-
         decoder_pred = torch.log_softmax(decoder_out, dim=-1)
         if hw_list is not None:
             hw_lengths = [len(i) for i in hw_list]
             hw_list_ = [torch.Tensor(i).long() for i in hw_list]
             hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device)
             selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device))
-
             contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
             num_hot_word = contextual_info.shape[1]
             _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
-
+            
             # ASF Core
             if nfilter > 0 and nfilter < num_hot_word:
                 hotword_scores = self.seaco_decoder.forward_asf6(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
@@ -242,7 +239,7 @@
             cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
             dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
             merged = self._merge(cif_attended, dec_attended)
-
+            
             dha_output = self.hotword_output_layer(merged)  # remove the last token in loss calculation
             dha_pred = torch.log_softmax(dha_output, dim=-1)
             def _merge_res(dec_output, dha_output):
@@ -256,8 +253,8 @@
                 # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask)
                 logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
                 return logits
-
             merged_pred = _merge_res(decoder_pred, dha_pred)
+            # import pdb; pdb.set_trace()
             return merged_pred
         else:
             return decoder_pred
@@ -307,6 +304,7 @@
             logging.info("enable beam_search")
             self.init_beam_search(**kwargs)
             self.nbest = kwargs.get("nbest", 1)
+        
         meta_data = {}
         
         # extract fbank feats
@@ -332,7 +330,6 @@
         if isinstance(encoder_out, tuple):
             encoder_out = encoder_out[0]
         
-
         # predictor
         predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
         pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
@@ -341,14 +338,15 @@
         if torch.max(pre_token_length) < 1:
             return []
 
+
         decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
                                                    pre_acoustic_embeds,
                                                    pre_token_length,
                                                    hw_list=self.hotword_list)
-
         # decoder_out, _ = decoder_outs[0], decoder_outs[1]
         _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
                                                                   pre_token_length)
+        
         results = []
         b, n, d = decoder_out.size()
         for i in range(b):
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 0c46449..ea23725 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -7,7 +7,7 @@
 import torch
 import torch.nn
 import torch.optim
-import pdb
+
 
 def filter_state_dict(
 	dst_state: Dict[str, Union[float, torch.Tensor]],
@@ -63,7 +63,6 @@
 	dst_state = obj.state_dict()
 	
 	print(f"ckpt: {path}")
-
 	if oss_bucket is None:
 		src_state = torch.load(path, map_location=map_location)
 	else:
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index 84c38f9..7748172 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -13,25 +13,29 @@
     from funasr.download.file import download_from_url
 except:
     print("urllib is not installed, if you infer from url, please install it first.")
-import pdb
+
 
 
 def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
     if isinstance(data_or_path_or_list, (list, tuple)):
         if data_type is not None and isinstance(data_type, (list, tuple)):
+
             data_types = [data_type] * len(data_or_path_or_list)
             data_or_path_or_list_ret = [[] for d in data_type]
             for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
+                
                 for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
+                    
                     data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
                     data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
 
             return data_or_path_or_list_ret
         else:
             return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
+    
     if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
         data_or_path_or_list = download_from_url(data_or_path_or_list)
-
+    
     if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
         if data_type is None or data_type == "sound":
             data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
@@ -52,22 +56,10 @@
         data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
     elif isinstance(data_or_path_or_list, np.ndarray):  # audio sample point
         data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze()  # [n_samples,]
-    elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark":
-        data_mat = kaldiio.load_mat(data_or_path_or_list) 
-        if isinstance(data_mat, tuple):
-            audio_fs, mat = data_mat
-        else:
-            mat = data_mat
-        if mat.dtype == 'int16' or mat.dtype == 'int32':
-            mat = mat.astype(np.float64)
-            mat = mat / 32768
-        if mat.ndim ==2:
-            mat = mat[:,0]
-        data_or_path_or_list = mat
     else:
         pass
         # print(f"unsupport data type: {data_or_path_or_list}, return raw data")
-
+        
     if audio_fs != fs and data_type != "text":
         resampler = torchaudio.transforms.Resample(audio_fs, fs)
         data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
@@ -89,6 +81,8 @@
     return array
 
 def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
+    # import pdb;
+    # pdb.set_trace()
     if isinstance(data, np.ndarray):
         data = torch.from_numpy(data)
         if len(data.shape) < 2:
@@ -106,7 +100,9 @@
             data_list.append(data_i)
             data_len.append(data_i.shape[0])
         data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
-
+    # import pdb;
+    # pdb.set_trace()
+    # if data_type == "sound":
     data, data_len = frontend(data, data_len, **kwargs)
     
     if isinstance(data_len, (list, tuple)):

--
Gitblit v1.9.1