From 920331972a136834a560d78917de60f6c6623d96 Mon Sep 17 00:00:00 2001
From: 语帆 <yf352572@alibaba-inc.com>
Date: Mon, 4 Mar 2024 17:47:25 +0800
Subject: [PATCH] Add LCB-NET model, demos, and detailed WER/recall scoring
---
examples/industrial_data_pretraining/lcbnet/compute_wer_details.py | 702 +++++++++++++++++++++
funasr/auto/auto_model.py | 12
funasr/models/contextual_paraformer/model.py | 39
examples/industrial_data_pretraining/contextual_paraformer/path.sh | 6
examples/industrial_data_pretraining/contextual_paraformer/demo.py | 0
funasr/models/lcbnet/model.py | 495 +++++++++++++++
funasr/utils/load_utils.py | 28
examples/industrial_data_pretraining/lcbnet/demo.py | 13
funasr/models/lcbnet/attention.py | 112 +++
examples/industrial_data_pretraining/lcbnet/utils | 1
funasr/frontends/default.py | 20
funasr/train_utils/load_pretrained_model.py | 3
funasr/models/lcbnet/__init__.py | 0
examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh | 11
funasr/models/seaco_paraformer/model.py | 18
examples/industrial_data_pretraining/contextual_paraformer/demo2.sh | 9
funasr/models/conformer/encoder.py | 2
funasr/models/lcbnet/encoder.py | 392 +++++++++++
examples/industrial_data_pretraining/seaco_paraformer/demo.py | 10
examples/industrial_data_pretraining/contextual_paraformer/demo.sh | 2
examples/industrial_data_pretraining/lcbnet/demo.sh | 72 ++
21 files changed, 1886 insertions(+), 61 deletions(-)
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.py b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
old mode 100644
new mode 100755
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.sh b/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
old mode 100644
new mode 100755
index 8fc66f3..1bd4f7f
--- a/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
+++ b/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
@@ -2,7 +2,7 @@
model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
model_revision="v2.0.4"
-python funasr/bin/inference.py \
+python ../../../funasr/bin/inference.py \
+model=${model} \
+model_revision=${model_revision} \
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh b/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh
new file mode 100755
index 0000000..282f4f1
--- /dev/null
+++ b/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh
@@ -0,0 +1,9 @@
+python -m funasr.bin.inference \
+--config-path="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
+--config-name="config.yaml" \
+++init_param="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/model.pb" \
+++tokenizer_conf.token_list="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/tokens.txt" \
+++frontend_conf.cmvn_file="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/am.mvn" \
+++input="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/asr_example_zh.wav" \
+++output_dir="./outputs/debug2" \
+++device=""
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/path.sh b/examples/industrial_data_pretraining/contextual_paraformer/path.sh
new file mode 100755
index 0000000..1a6d67e
--- /dev/null
+++ b/examples/industrial_data_pretraining/contextual_paraformer/path.sh
@@ -0,0 +1,6 @@
+export FUNASR_DIR=$PWD/../../../
+
+# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
+export PYTHONIOENCODING=UTF-8
+export PATH=$FUNASR_DIR/funasr/bin:$PATH
+export PYTHONPATH=$FUNASR_DIR/funasr/bin:$FUNASR_DIR/funasr:$FUNASR_DIR:$PYTHONPATH
diff --git a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
new file mode 100755
index 0000000..e72d871
--- /dev/null
+++ b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
@@ -0,0 +1,702 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+
+from enum import Enum
+import sys
+import unicodedata
+import codecs
+import argparse
+import os
+from tqdm import tqdm
+remove_tag = False
+spacelist = [" ", "\t", "\r", "\n"]
+puncts = [
+ "!",
+ ",",
+ "?",
+ "銆�",
+ "銆�",
+ "锛�",
+ "锛�",
+ "锛�",
+ "锛�",
+ "锛�",
+ "銆�",
+ "銆�",
+ "锔�",
+ "銆�",
+ "銆�",
+ "銆�",
+ "銆�",
+]
+
+
+class Code(Enum):
+ match = 1
+ substitution = 2
+ insertion = 3
+ deletion = 4
+
+
+class WordError(object):
+ def __init__(self):
+ self.errors = {
+ Code.substitution: 0,
+ Code.insertion: 0,
+ Code.deletion: 0,
+ }
+ self.ref_words = 0
+
+ def get_wer(self):
+ assert self.ref_words != 0
+ errors = (
+ self.errors[Code.substitution]
+ + self.errors[Code.insertion]
+ + self.errors[Code.deletion]
+ )
+ return 100.0 * errors / self.ref_words
+
+ def get_result_string(self):
+ return (
+ f"error_rate={self.get_wer():.4f}, "
+ f"ref_words={self.ref_words}, "
+ f"subs={self.errors[Code.substitution]}, "
+ f"ins={self.errors[Code.insertion]}, "
+ f"dels={self.errors[Code.deletion]}"
+ )
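+
+# Example (sketch, hypothetical numbers): 2 substitutions against a 4-word
+# reference give 100.0 * 2 / 4 == 50.0:
+#   we = WordError(); we.ref_words = 4
+#   we.errors[Code.substitution] = 2
+#   we.get_wer()  # -> 50.0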
+
+
+def characterize(string):
+ res = []
+ i = 0
+ while i < len(string):
+ char = string[i]
+ if char in puncts:
+ i += 1
+ continue
+ cat1 = unicodedata.category(char)
+ # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
+ if cat1 == "Zs" or cat1 == "Cn" or char in spacelist: # space or not assigned
+ i += 1
+ continue
+ if cat1 == "Lo": # letter-other
+ res.append(char)
+ i += 1
+ else:
+ # some input looks like: <unk><noise>; we want to separate it into two words.
+ sep = " "
+ if char == "<":
+ sep = ">"
+ j = i + 1
+ while j < len(string):
+ c = string[j]
+ if ord(c) >= 128 or (c in spacelist) or (c == sep):
+ break
+ j += 1
+ if j < len(string) and string[j] == ">":
+ j += 1
+ res.append(string[i:j])
+ i = j
+ return res
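+# Example (sketch): tags like <unk> are kept as single tokens while CJK
+# characters are split:
+#   characterize("你好 <unk>ok") -> ["你", "好", "<unk>", "ok"]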
+
+
+def stripoff_tags(x):
+ if not x:
+ return ""
+ chars = []
+ i = 0
+ T = len(x)
+ while i < T:
+ if x[i] == "<":
+ while i < T and x[i] != ">":
+ i += 1
+ i += 1
+ else:
+ chars.append(x[i])
+ i += 1
+ return "".join(chars)
+
+
+def normalize(sentence, ignore_words, cs, split=None):
+ """sentence, ignore_words are both in unicode"""
+ new_sentence = []
+ for token in sentence:
+ x = token
+ if not cs:
+ x = x.upper()
+ if x in ignore_words:
+ continue
+ if remove_tag:
+ x = stripoff_tags(x)
+ if not x:
+ continue
+ if split and x in split:
+ new_sentence += split[x]
+ else:
+ new_sentence.append(x)
+ return new_sentence
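+# Example (sketch): with cs=False every token is upper-cased before the
+# ignore check, so normalize(["a", "b"], {"B"}, cs=False) -> ["A"]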
+
+
+class Calculator:
+ def __init__(self):
+ self.data = {}
+ self.space = []
+ self.cost = {}
+ self.cost["cor"] = 0
+ self.cost["sub"] = 1
+ self.cost["del"] = 1
+ self.cost["ins"] = 1
+
+ def calculate(self, lab, rec):
+ # Initialization
+ lab.insert(0, "")
+ rec.insert(0, "")
+ while len(self.space) < len(lab):
+ self.space.append([])
+ for row in self.space:
+ for element in row:
+ element["dist"] = 0
+ element["error"] = "non"
+ while len(row) < len(rec):
+ row.append({"dist": 0, "error": "non"})
+ for i in range(len(lab)):
+ self.space[i][0]["dist"] = i
+ self.space[i][0]["error"] = "del"
+ for j in range(len(rec)):
+ self.space[0][j]["dist"] = j
+ self.space[0][j]["error"] = "ins"
+ self.space[0][0]["error"] = "non"
+ for token in lab:
+ if token not in self.data and len(token) > 0:
+ self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
+ for token in rec:
+ if token not in self.data and len(token) > 0:
+ self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
+ # Computing edit distance
+ for i, lab_token in enumerate(lab):
+ for j, rec_token in enumerate(rec):
+ if i == 0 or j == 0:
+ continue
+ min_dist = sys.maxsize
+ min_error = "none"
+ dist = self.space[i - 1][j]["dist"] + self.cost["del"]
+ error = "del"
+ if dist < min_dist:
+ min_dist = dist
+ min_error = error
+ dist = self.space[i][j - 1]["dist"] + self.cost["ins"]
+ error = "ins"
+ if dist < min_dist:
+ min_dist = dist
+ min_error = error
+ if lab_token == rec_token.replace("<BIAS>", ""):
+ dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"]
+ error = "cor"
+ else:
+ dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"]
+ error = "sub"
+ if dist < min_dist:
+ min_dist = dist
+ min_error = error
+ self.space[i][j]["dist"] = min_dist
+ self.space[i][j]["error"] = min_error
+ # Tracing back
+ result = {
+ "lab": [],
+ "rec": [],
+ "code": [],
+ "all": 0,
+ "cor": 0,
+ "sub": 0,
+ "ins": 0,
+ "del": 0,
+ }
+ i = len(lab) - 1
+ j = len(rec) - 1
+ while True:
+ if self.space[i][j]["error"] == "cor": # correct
+ if len(lab[i]) > 0:
+ self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
+ self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1
+ result["all"] = result["all"] + 1
+ result["cor"] = result["cor"] + 1
+ result["lab"].insert(0, lab[i])
+ result["rec"].insert(0, rec[j])
+ result["code"].insert(0, Code.match)
+ i = i - 1
+ j = j - 1
+ elif self.space[i][j]["error"] == "sub": # substitution
+ if len(lab[i]) > 0:
+ self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
+ self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1
+ result["all"] = result["all"] + 1
+ result["sub"] = result["sub"] + 1
+ result["lab"].insert(0, lab[i])
+ result["rec"].insert(0, rec[j])
+ result["code"].insert(0, Code.substitution)
+ i = i - 1
+ j = j - 1
+ elif self.space[i][j]["error"] == "del": # deletion
+ if len(lab[i]) > 0:
+ self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
+ self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1
+ result["all"] = result["all"] + 1
+ result["del"] = result["del"] + 1
+ result["lab"].insert(0, lab[i])
+ result["rec"].insert(0, "")
+ result["code"].insert(0, Code.deletion)
+ i = i - 1
+ elif self.space[i][j]["error"] == "ins": # insertion
+ if len(rec[j]) > 0:
+ self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1
+ result["ins"] = result["ins"] + 1
+ result["lab"].insert(0, "")
+ result["rec"].insert(0, rec[j])
+ result["code"].insert(0, Code.insertion)
+ j = j - 1
+ elif self.space[i][j]["error"] == "non": # starting point
+ break
+ else: # shouldn't reach here
+ print(
+ "this should not happen , i = {i} , j = {j} , error = {error}".format(
+ i=i, j=j, error=self.space[i][j]["error"]
+ )
+ )
+ return result
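+
+ # Example (sketch): Calculator().calculate(["a", "b"], ["a", "c"]) returns
+ # {'all': 2, 'cor': 1, 'sub': 1, 'ins': 0, 'del': 0,
+ #  'lab': ['a', 'b'], 'rec': ['a', 'c'],
+ #  'code': [Code.match, Code.substitution]}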
+
+ def overall(self):
+ result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
+ for token in self.data:
+ result["all"] = result["all"] + self.data[token]["all"]
+ result["cor"] = result["cor"] + self.data[token]["cor"]
+ result["sub"] = result["sub"] + self.data[token]["sub"]
+ result["ins"] = result["ins"] + self.data[token]["ins"]
+ result["del"] = result["del"] + self.data[token]["del"]
+ return result
+
+ def cluster(self, data):
+ result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
+ for token in data:
+ if token in self.data:
+ result["all"] = result["all"] + self.data[token]["all"]
+ result["cor"] = result["cor"] + self.data[token]["cor"]
+ result["sub"] = result["sub"] + self.data[token]["sub"]
+ result["ins"] = result["ins"] + self.data[token]["ins"]
+ result["del"] = result["del"] + self.data[token]["del"]
+ return result
+
+ def keys(self):
+ return list(self.data.keys())
+
+
+def width(string):
+ return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
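+# Example (sketch): fullwidth characters count as two display columns, so
+# width("ab你") == 1 + 1 + 2 == 4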
+
+
+def default_cluster(word):
+ unicode_names = [unicodedata.name(char) for char in word]
+ for i in reversed(range(len(unicode_names))):
+ if unicode_names[i].startswith("DIGIT"): # 1
+ unicode_names[i] = "Number" # 'DIGIT'
+ elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
+ i
+ ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
+ # 鏄� / 铯�
+ unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH'
+ elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
+ i
+ ].startswith("LATIN SMALL LETTER"):
+ # A / a
+ unicode_names[i] = "English" # 'LATIN LETTER'
+ elif unicode_names[i].startswith("HIRAGANA LETTER"): # 銇� 銇� 銈�
+ unicode_names[i] = "Japanese" # 'GANA LETTER'
+ elif (
+ unicode_names[i].startswith("AMPERSAND")
+ or unicode_names[i].startswith("APOSTROPHE")
+ or unicode_names[i].startswith("COMMERCIAL AT")
+ or unicode_names[i].startswith("DEGREE CELSIUS")
+ or unicode_names[i].startswith("EQUALS SIGN")
+ or unicode_names[i].startswith("FULL STOP")
+ or unicode_names[i].startswith("HYPHEN-MINUS")
+ or unicode_names[i].startswith("LOW LINE")
+ or unicode_names[i].startswith("NUMBER SIGN")
+ or unicode_names[i].startswith("PLUS SIGN")
+ or unicode_names[i].startswith("SEMICOLON")
+ ):
+ # & / ' / @ / 鈩� / = / . / - / _ / # / + / ;
+ del unicode_names[i]
+ else:
+ return "Other"
+ if len(unicode_names) == 0:
+ return "Other"
+ if len(unicode_names) == 1:
+ return unicode_names[0]
+ for i in range(len(unicode_names) - 1):
+ if unicode_names[i] != unicode_names[i + 1]:
+ return "Other"
+ return unicode_names[0]
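+# Example (sketch): default_cluster("你") -> "Mandarin",
+# default_cluster("ok") -> "English"; a mixed-script word such as
+# "你a" maps to "Other"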
+
+
+def get_args():
+ parser = argparse.ArgumentParser(description="compute detailed WER, B-WER/U-WER and hotword recall")
+ parser.add_argument("--ref", type=str, help="reference text path")
+ parser.add_argument("--ref_ocr", type=str, help="reference OCR text path")
+ parser.add_argument("--rec_name", type=str, action="append", default=[])
+ parser.add_argument("--rec_file", type=str, action="append", default=[])
+ parser.add_argument("--verbose", type=int, default=1, help="verbosity level (0/1/2)")
+ parser.add_argument("--char", type=bool, default=True, help="split text into characters")
+ args = parser.parse_args()
+ return args
+
+
+def main(args):
+ cluster_file = ""
+ ignore_words = set()
+ tochar = args.char
+ verbose = args.verbose
+ padding_symbol = " "
+ case_sensitive = False
+ max_words_per_line = sys.maxsize
+ split = None
+
+ if not case_sensitive:
+ ig = set([w.upper() for w in ignore_words])
+ ignore_words = ig
+
+ default_clusters = {}
+ default_words = {}
+ ref_file = args.ref
+ ref_ocr = args.ref_ocr
+ rec_files = args.rec_file
+ rec_names = args.rec_name
+ assert len(rec_files) == len(rec_names)
+
+ # load ocr
+ ref_ocr_dict = {}
+ with codecs.open(ref_ocr, "r", "utf-8") as fh:
+ for line in fh:
+ if "$" in line:
+ line = line.replace("$", " ")
+ if tochar:
+ array = characterize(line)
+ else:
+ array = line.strip().split()
+ if len(array) == 0:
+ continue
+ fid = array[0]
+ ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
+
+ if split and not case_sensitive:
+ newsplit = dict()
+ for w in split:
+ words = split[w]
+ for i in range(len(words)):
+ words[i] = words[i].upper()
+ newsplit[w.upper()] = words
+ split = newsplit
+
+ rec_sets = {}
+ calculators_dict = dict()
+ ub_wer_dict = dict()
+ hotwords_related_dict = dict()  # recall-related statistics per rec_name
+ for i, hyp_file in enumerate(rec_files):
+ rec_sets[rec_names[i]] = dict()
+ with codecs.open(hyp_file, "r", "utf-8") as fh:
+ for line in fh:
+ if tochar:
+ array = characterize(line)
+ else:
+ array = line.strip().split()
+ if len(array) == 0:
+ continue
+ fid = array[0]
+ rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
+
+ calculators_dict[rec_names[i]] = Calculator()
+ ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
+ hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
+ # tp: hotword is in the label and also in the rec
+ # tn: hotword is not in the label and not in the rec
+ # fp: hotword is not in the label but appears in the rec
+ # fn: hotword is in the label but missing from the rec
+
+ # record wrong label but in ocr
+ wrong_rec_but_in_ocr_dict = {}
+ for rec_name in rec_names:
+ wrong_rec_but_in_ocr_dict[rec_name] = 0
+
+ _file_total_len = 0
+ with os.popen("cat {} | wc -l".format(ref_file)) as pipe:
+ _file_total_len = int(pipe.read().strip())
+
+ # compute error rate on the interaction of reference file and hyp file
+ for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
+ if tochar:
+ array = characterize(line)
+ else:
+ array = line.rstrip('\n').split()
+ if len(array) == 0: continue
+ fid = array[0]
+ lab = normalize(array[1:], ignore_words, case_sensitive, split)
+
+ if verbose:
+ print('\nutt: %s' % fid)
+
+ ocr_text = ref_ocr_dict[fid]
+ ocr_set = set(ocr_text)
+ print('ocr: {}'.format(" ".join(ocr_text)))
+ list_match = []  # label tokens that also appear in the OCR text
+ list_not_match = []
+ tmp_error = 0
+ tmp_match = 0
+ for index in range(len(lab)):
+ # text_list.append(uttlist[index+1])
+ if lab[index] not in ocr_set:
+ tmp_error += 1
+ list_not_match.append(lab[index])
+ else:
+ tmp_match += 1
+ list_match.append(lab[index])
+ print('label in ocr: {}'.format(" ".join(list_match)))
+
+ # for each reco file
+ base_wrong_ocr_wer = None
+ ocr_wrong_ocr_wer = None
+
+ for rec_name in rec_names:
+ rec_set = rec_sets[rec_name]
+ if fid not in rec_set:
+ continue
+ rec = rec_set[fid]
+
+ # print(rec)
+ for word in rec + lab:
+ if word not in default_words:
+ default_cluster_name = default_cluster(word)
+ if default_cluster_name not in default_clusters:
+ default_clusters[default_cluster_name] = {}
+ if word not in default_clusters[default_cluster_name]:
+ default_clusters[default_cluster_name][word] = 1
+ default_words[word] = default_cluster_name
+
+ result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
+ if verbose:
+ if result['all'] != 0:
+ wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+ else:
+ wer = 0.0
+ print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
+ print('N=%d C=%d S=%d D=%d I=%d' %
+ (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+
+
+ # print(result['rec'])
+ wrong_rec_but_in_ocr = []
+ for idx in range(len(result['lab'])):
+ if result['lab'][idx] != "":
+ if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
+ if result['lab'][idx] in list_match:
+ wrong_rec_but_in_ocr.append(result['lab'][idx])
+ wrong_rec_but_in_ocr_dict[rec_name] += 1
+ print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
+
+ if rec_name == "base":
+ base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
+ if "ocr" in rec_name or "hot" in rec_name:
+ ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
+ if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
+ print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
+ elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
+ print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
+
+ # recall = 0
+ # false_alarm = 0
+ # for idx in range(len(result['lab'])):
+ # if "<BIAS>" in result['rec'][idx]:
+ # if result['rec'][idx].replace("<BIAS>", "") in list_match:
+ # recall += 1
+ # else:
+ # false_alarm += 1
+ # print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format(
+ # recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0
+ # ))
+ # tp: hotword is in the label and also in the rec
+ # tn: hotword is not in the label and not in the rec
+ # fp: hotword is not in the label but appears in the rec
+ # fn: hotword is in the label but missing from the rec
+ _rec_list = [word.replace("<BIAS>", "") for word in rec]
+ _label_list = [word for word in lab]
+ _tp = _tn = _fp = _fn = 0
+ hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list]
+ hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list]
+ for badhotword in hot_bad_list:
+ count = len([word for word in _rec_list if word == badhotword])
+ # print(f"bad {badhotword} count: {count}")
+ # for word in _rec_list:
+ # if badhotword == word:
+ # count += 1
+ if count == 0:
+ hotwords_related_dict[rec_name]['tn'] += 1
+ _tn += 1
+ # fp: 0
+ else:
+ hotwords_related_dict[rec_name]['fp'] += count
+ _fp += count
+ # tn: 0
+ # if badhotword in _rec_list:
+ # hotwords_related_dict[rec_name]['fp'] += 1
+ # else:
+ # hotwords_related_dict[rec_name]['tn'] += 1
+ for hotword in hot_true_list:
+ true_count = len([word for word in _label_list if hotword == word])
+ rec_count = len([word for word in _rec_list if hotword == word])
+ # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
+ if rec_count == true_count:
+ hotwords_related_dict[rec_name]['tp'] += true_count
+ _tp += true_count
+ elif rec_count > true_count:
+ hotwords_related_dict[rec_name]['tp'] += true_count
+ # fp: not in the label, but appears in the rec
+ hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
+ _tp += true_count
+ _fp += rec_count - true_count
+ else:
+ hotwords_related_dict[rec_name]['tp'] += rec_count
+ # fn: hotword in the label, but missing from the rec
+ hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
+ _tp += rec_count
+ _fn += true_count - rec_count
+ print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
+ _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
+ ))
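+ # Worked example (sketch): with ocr=["A", "B"], label=["A"], rec=["A", "A"]:
+ # "B" never appears in rec, so tn += 1; "A" is recognized twice against one
+ # reference occurrence, so tp += 1 and fp += 1, giving recall = 1/(1+0) = 100%.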
+
+ # if hotword in _rec_list:
+ # hotwords_related_dict[rec_name]['tp'] += 1
+ # else:
+ # hotwords_related_dict[rec_name]['fn'] += 1
+ # compute u_wer, b_wer and overall wer
+ for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]):
+ if code == Code.match:
+ ub_wer_dict[rec_name]["wer"].ref_words += 1
+ if lab_word in hot_true_list:
+ # tmp_ref.append(ref_tokens[ref_idx])
+ ub_wer_dict[rec_name]["b_wer"].ref_words += 1
+ else:
+ ub_wer_dict[rec_name]["u_wer"].ref_words += 1
+ elif code == Code.substitution:
+ ub_wer_dict[rec_name]["wer"].ref_words += 1
+ ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1
+ if lab_word in hot_true_list:
+ # tmp_ref.append(ref_tokens[ref_idx])
+ ub_wer_dict[rec_name]["b_wer"].ref_words += 1
+ ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1
+ else:
+ ub_wer_dict[rec_name]["u_wer"].ref_words += 1
+ ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1
+ elif code == Code.deletion:
+ ub_wer_dict[rec_name]["wer"].ref_words += 1
+ ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1
+ if lab_word in hot_true_list:
+ # tmp_ref.append(ref_tokens[ref_idx])
+ ub_wer_dict[rec_name]["b_wer"].ref_words += 1
+ ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1
+ else:
+ ub_wer_dict[rec_name]["u_wer"].ref_words += 1
+ ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1
+ elif code == Code.insertion:
+ ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1
+ if rec_word in hot_true_list:
+ ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1
+ else:
+ ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
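+ # B-WER accumulates errors on biased (hotword) reference words and U-WER on
+ # the remaining words; insertions are attributed by whether the inserted
+ # rec word is itself a hotword.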
+
+ space = {}
+ space['lab'] = []
+ space['rec'] = []
+ for idx in range(len(result['lab'])):
+ len_lab = width(result['lab'][idx])
+ len_rec = width(result['rec'][idx])
+ length = max(len_lab, len_rec)
+ space['lab'].append(length - len_lab)
+ space['rec'].append(length - len_rec)
+ upper_lab = len(result['lab'])
+ upper_rec = len(result['rec'])
+ lab1, rec1 = 0, 0
+ while lab1 < upper_lab or rec1 < upper_rec:
+ if verbose > 1:
+ print('lab(%s):' % fid.encode('utf-8'), end=' ')
+ else:
+ print('lab:', end=' ')
+ lab2 = min(upper_lab, lab1 + max_words_per_line)
+ for idx in range(lab1, lab2):
+ token = result['lab'][idx]
+ print('{token}'.format(token=token), end='')
+ for n in range(space['lab'][idx]):
+ print(padding_symbol, end='')
+ print(' ', end='')
+ print()
+ if verbose > 1:
+ print('rec(%s):' % fid.encode('utf-8'), end=' ')
+ else:
+ print('rec:', end=' ')
+
+ rec2 = min(upper_rec, rec1 + max_words_per_line)
+ for idx in range(rec1, rec2):
+ token = result['rec'][idx]
+ print('{token}'.format(token=token), end='')
+ for n in range(space['rec'][idx]):
+ print(padding_symbol, end='')
+ print(' ', end='')
+ print()
+ # print('\n', end='\n')
+ lab1 = lab2
+ rec1 = rec2
+ print('\n', end='\n')
+ # break
+ if verbose:
+ print('===========================================================================')
+ print()
+
+ print(wrong_rec_but_in_ocr_dict)
+ for rec_name in rec_names:
+ result = calculators_dict[rec_name].overall()
+
+ if result['all'] != 0:
+ wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
+ else:
+ wer = 0.0
+ print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
+ print('N=%d C=%d S=%d D=%d I=%d' %
+ (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
+ print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
+ print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
+ print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
+
+ print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
+ hotwords_related_dict[rec_name]['tp'],
+ hotwords_related_dict[rec_name]['tn'],
+ hotwords_related_dict[rec_name]['fp'],
+ hotwords_related_dict[rec_name]['fn'],
+ sum([v for k, v in hotwords_related_dict[rec_name].items()]),
+ hotwords_related_dict[rec_name]['tp'] / (
+ hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
+ ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
+ ))
+
+ # tp: hotword is in the label and also in the rec
+ # tn: hotword is not in the label and not in the rec
+ # fp: hotword is not in the label but appears in the rec
+ # fn: hotword is in the label but missing from the rec
+ if not verbose:
+ print()
+ print()
+
+
+if __name__ == "__main__":
+ args = get_args()
+
+ # print("")
+ print(args)
+ main(args)
+
diff --git a/examples/industrial_data_pretraining/lcbnet/demo.py b/examples/industrial_data_pretraining/lcbnet/demo.py
new file mode 100755
index 0000000..4ca5255
--- /dev/null
+++ b/examples/industrial_data_pretraining/lcbnet/demo.py
@@ -0,0 +1,13 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+from funasr import AutoModel
+
+model = AutoModel(model="iic/LCB-NET",
+ model_revision="v1.0.0")
+
+res = model.generate(input=("https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav", "https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt"), data_type=("sound", "text"))
+
+print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/lcbnet/demo.sh b/examples/industrial_data_pretraining/lcbnet/demo.sh
new file mode 100755
index 0000000..2f226bc
--- /dev/null
+++ b/examples/industrial_data_pretraining/lcbnet/demo.sh
@@ -0,0 +1,72 @@
+file_dir="/home/yf352572/.cache/modelscope/hub/iic/LCB-NET/"
+CUDA_VISIBLE_DEVICES="0,1"
+inference_device="cuda"
+
+if [ ${inference_device} == "cuda" ]; then
+ nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
+else
+ nj=1  # nj is otherwise unset on this CPU path; default to a single job
+ inference_batch_size=1
+ CUDA_VISIBLE_DEVICES=""
+ for JOB in $(seq ${nj}); do
+ CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
+ done
+fi
+
+inference_dir="outputs/slidespeech_dev"
+_logdir="${inference_dir}/logdir"
+echo "inference_dir: ${inference_dir}"
+
+mkdir -p "${_logdir}"
+key_file1=${file_dir}/dev/wav.scp
+key_file2=${file_dir}/dev/ocr.txt
+split_scps1=
+split_scps2=
+for JOB in $(seq "${nj}"); do
+ split_scps1+=" ${_logdir}/wav.${JOB}.scp"
+ split_scps2+=" ${_logdir}/ocr.${JOB}.txt"
+done
+utils/split_scp.pl "${key_file1}" ${split_scps1}
+utils/split_scp.pl "${key_file2}" ${split_scps2}
+
+gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
+for JOB in $(seq ${nj}); do
+ {
+ id=$((JOB-1))
+ gpuid=${gpuid_list_array[$id]}
+
+ export CUDA_VISIBLE_DEVICES=${gpuid}
+
+ python -m funasr.bin.inference \
+ --config-path=${file_dir} \
+ --config-name="config.yaml" \
+ ++init_param=${file_dir}/model.pt \
+ ++tokenizer_conf.token_list=${file_dir}/tokens.txt \
+ ++input=[${_logdir}/wav.${JOB}.scp,${_logdir}/ocr.${JOB}.txt] \
+ +data_type='["kaldi_ark", "text"]' \
+ ++tokenizer_conf.bpemodel=${file_dir}/bpe.pt \
+ ++normalize_conf.stats_file=${file_dir}/am.mvn \
+ ++output_dir="${inference_dir}/${JOB}" \
+ ++device="${inference_device}" \
+ ++ncpu=1 \
+ ++disable_log=true &> ${_logdir}/log.${JOB}.txt
+
+ }&
+done
+wait
+
+
+mkdir -p ${inference_dir}/1best_recog
+
+for JOB in $(seq "${nj}"); do
+ cat "${inference_dir}/${JOB}/1best_recog/token" >> "${inference_dir}/1best_recog/token"
+done
+
+echo "Computing WER ..."
+sed -e 's/ /\t/' -e 's/ //g' -e 's/▁/ /g' -e 's/\t /\t/' ${inference_dir}/1best_recog/token > ${inference_dir}/1best_recog/token.proc
+cp ${file_dir}/dev/text ${inference_dir}/1best_recog/token.ref
+cp ${file_dir}/dev/ocr.list ${inference_dir}/1best_recog/ocr.list
+python utils/compute_wer.py ${inference_dir}/1best_recog/token.ref ${inference_dir}/1best_recog/token.proc ${inference_dir}/1best_recog/token.cer
+tail -n 3 ${inference_dir}/1best_recog/token.cer
+
+./run_bwer_recall.sh ${inference_dir}/1best_recog/
+tail -n 6 ${inference_dir}/1best_recog/BWER-UWER.results | head -n 5
diff --git a/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh b/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh
new file mode 100755
index 0000000..7d6b6ff
--- /dev/null
+++ b/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh
@@ -0,0 +1,11 @@
+#now_result_name=asr_conformer_acc1_lr002_warm20000/decode_asr_asr_model_valid.acc.ave
+#hotword_type=ocr_1ngram_top10_hotwords_list
+hot_exp_suf=$1
+
+
+python compute_wer_details.py --verbose 1 \
+ --ref ${hot_exp_suf}/token.ref \
+ --ref_ocr ${hot_exp_suf}/ocr.list \
+ --rec_name base \
+ --rec_file ${hot_exp_suf}/token.proc \
+ > ${hot_exp_suf}/BWER-UWER.results
diff --git a/examples/industrial_data_pretraining/lcbnet/utils b/examples/industrial_data_pretraining/lcbnet/utils
new file mode 120000
index 0000000..be5e5a3
--- /dev/null
+++ b/examples/industrial_data_pretraining/lcbnet/utils
@@ -0,0 +1 @@
+../../aishell/paraformer/utils
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index a44c649..551dd8b 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -7,10 +7,10 @@
model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model_revision="v2.0.4",
- vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
- vad_model_revision="v2.0.4",
- punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
- punc_model_revision="v2.0.4",
+ # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+ # vad_model_revision="v2.0.4",
+ # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+ # punc_model_revision="v2.0.4",
# spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
# spk_model_revision="v2.0.2",
)
@@ -43,4 +43,4 @@
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
res = model.generate(input=[speech], batch_size_s=300, is_final=True)
-'''
\ No newline at end of file
+'''
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 921ede8..ec3c3f3 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -28,7 +28,7 @@
from funasr.models.campplus.cluster_backend import ClusterBackend
except:
print("If you want to use the speaker diarization, please `pip install hdbscan`")
-
+import pdb
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
"""
@@ -46,6 +46,7 @@
chars = string.ascii_letters + string.digits
if isinstance(data_in, str) and data_in.startswith('http'): # url
data_in = download_from_url(data_in)
+
if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
_, file_extension = os.path.splitext(data_in)
file_extension = file_extension.lower()
@@ -146,7 +147,7 @@
kwargs = download_model(**kwargs)
set_all_random_seed(kwargs.get("seed", 0))
-
+
device = kwargs.get("device", "cuda")
if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
device = "cpu"
@@ -168,7 +169,6 @@
vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
else:
vocab_size = -1
-
# build frontend
frontend = kwargs.get("frontend", None)
kwargs["input_size"] = None
@@ -181,7 +181,6 @@
# build model
model_class = tables.model_classes.get(kwargs["model"])
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
-
model.to(device)
# init_param
@@ -224,9 +223,9 @@
batch_size = kwargs.get("batch_size", 1)
# if kwargs.get("device", "cpu") == "cpu":
# batch_size = 1
-
+
key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)
-
+
speed_stats = {}
asr_result_list = []
num_samples = len(data_list)
@@ -239,6 +238,7 @@
data_batch = data_list[beg_idx:end_idx]
key_batch = key_list[beg_idx:end_idx]
batch = {"data_in": data_batch, "key": key_batch}
+
if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank
batch["data_in"] = data_batch[0]
batch["data_lengths"] = input_len
diff --git a/funasr/frontends/default.py b/funasr/frontends/default.py
index 8ac1ca8..c4bdbd7 100644
--- a/funasr/frontends/default.py
+++ b/funasr/frontends/default.py
@@ -3,7 +3,6 @@
from typing import Tuple
from typing import Union
import logging
-import humanfriendly
import numpy as np
import torch
import torch.nn as nn
@@ -16,8 +15,10 @@
from funasr.frontends.utils.stft import Stft
from funasr.frontends.utils.frontend import Frontend
from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.register import tables
+@tables.register("frontend_classes", "DefaultFrontend")
class DefaultFrontend(nn.Module):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@@ -25,7 +26,7 @@
def __init__(
self,
- fs: Union[int, str] = 16000,
+ fs: int = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
@@ -40,14 +41,14 @@
frontend_conf: Optional[dict] = None,
apply_stft: bool = True,
use_channel: int = None,
+ **kwargs,
):
super().__init__()
- if isinstance(fs, str):
- fs = humanfriendly.parse_size(fs)
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
self.hop_length = hop_length
+ self.fs = fs
if apply_stft:
self.stft = Stft(
@@ -84,8 +85,12 @@
return self.n_mels
def forward(
- self, input: torch.Tensor, input_lengths: torch.Tensor
+ self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list]
) -> Tuple[torch.Tensor, torch.Tensor]:
+ if isinstance(input_lengths, list):
+ input_lengths = torch.tensor(input_lengths)
+ if input.dtype == torch.float64:
+ input = input.float()
# 1. Domain-conversion: e.g. Stft: time -> time-freq
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
@@ -145,7 +150,7 @@
def __init__(
self,
- fs: Union[int, str] = 16000,
+ fs: int = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = None,
@@ -168,9 +173,6 @@
mc: bool = True
):
super().__init__()
- if isinstance(fs, str):
- fs = humanfriendly.parse_size(fs)
-
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
if win_length is None and hop_length is None:
diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py
index 1d252c2..be973c6 100644
--- a/funasr/models/conformer/encoder.py
+++ b/funasr/models/conformer/encoder.py
@@ -47,7 +47,7 @@
from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
from funasr.models.transformer.utils.subsampling import StreamingConvInput
from funasr.register import tables
-
+import pdb
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 49868a8..7d6f729 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -29,7 +29,7 @@
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-
+import pdb
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -62,7 +62,6 @@
crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
-
if bias_encoder_type == 'lstm':
self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
@@ -103,17 +102,16 @@
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
-
+
batch_size = speech.shape[0]
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
dha_pad = kwargs.get("dha_pad")
-
+
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
-
loss_ctc, cer_ctc = None, None
stats = dict()
@@ -128,12 +126,11 @@
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
-
# 2b. Attention decoder branch
loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
)
-
+
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att + loss_pre * self.predictor_weight
@@ -171,22 +168,24 @@
):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
+
if self.predictor_bias == 1:
_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_pad_lens = ys_pad_lens + self.predictor_bias
+
pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
ignore_id=self.ignore_id)
-
# -1. bias encoder
if self.use_decoder_embedding:
hw_embed = self.decoder.embed(hotword_pad)
else:
hw_embed = self.bias_embed(hotword_pad)
+
hw_embed, (_, _) = self.bias_encoder(hw_embed)
_ind = np.arange(0, hotword_pad.shape[0]).tolist()
selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
-
+
# 0. sampler
decoder_out_1st = None
if self.sampling_ratio > 0.0:
@@ -195,7 +194,7 @@
pre_acoustic_embeds, contextual_info)
else:
sematic_embeds = pre_acoustic_embeds
-
+
# 1. Forward decoder
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
@@ -211,7 +210,7 @@
loss_ideal = None
'''
loss_ideal = None
-
+
if decoder_out_1st is None:
decoder_out_1st = decoder_out
# 2. Compute attention loss
@@ -288,10 +287,11 @@
enforce_sorted=False)
_, (h_n, _) = self.bias_encoder(hw_embed)
hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
-
+
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
)
+
decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
@@ -305,38 +305,42 @@
**kwargs,
):
# init beamsearch
+
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
-
+
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
+
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
+
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data[
"batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-
+
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# hotword
self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
-
+
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
-
+
# predictor
predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
@@ -344,8 +348,7 @@
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
-
-
+
decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
pre_acoustic_embeds,
pre_token_length,
diff --git a/funasr/models/lcbnet/__init__.py b/funasr/models/lcbnet/__init__.py
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/funasr/models/lcbnet/__init__.py
diff --git a/funasr/models/lcbnet/attention.py b/funasr/models/lcbnet/attention.py
new file mode 100644
index 0000000..8e8c594
--- /dev/null
+++ b/funasr/models/lcbnet/attention.py
@@ -0,0 +1,112 @@
+#!/usr/bin/env python3
+# -*- coding: utf-8 -*-
+
+# Copyright 2024 yufan
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Multi-Head Attention Return Weight layer definition."""
+
+import math
+
+import torch
+from torch import nn
+
+class MultiHeadedAttentionReturnWeight(nn.Module):
+ """Multi-Head Attention layer.
+
+ Args:
+ n_head (int): The number of heads.
+ n_feat (int): The number of features.
+ dropout_rate (float): Dropout rate.
+
+ """
+
+ def __init__(self, n_head, n_feat, dropout_rate):
+ """Construct an MultiHeadedAttentionReturnWeight object."""
+ super(MultiHeadedAttentionReturnWeight, self).__init__()
+ assert n_feat % n_head == 0
+ # We assume d_v always equals d_k
+ self.d_k = n_feat // n_head
+ self.h = n_head
+ self.linear_q = nn.Linear(n_feat, n_feat)
+ self.linear_k = nn.Linear(n_feat, n_feat)
+ self.linear_v = nn.Linear(n_feat, n_feat)
+ self.linear_out = nn.Linear(n_feat, n_feat)
+ self.attn = None
+ self.dropout = nn.Dropout(p=dropout_rate)
+
+ def forward_qkv(self, query, key, value):
+ """Transform query, key and value.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+
+ Returns:
+ torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
+ torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
+ torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
+
+ """
+ n_batch = query.size(0)
+ q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
+ k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
+ v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
+ q = q.transpose(1, 2) # (batch, head, time1, d_k)
+ k = k.transpose(1, 2) # (batch, head, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head, time2, d_k)
+
+ return q, k, v
+
+ def forward_attention(self, value, scores, mask):
+ """Compute attention context vector.
+
+ Args:
+ value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
+ scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
+ mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Transformed value (#batch, time1, d_model)
+ weighted by the attention score (#batch, time1, time2).
+
+ """
+ n_batch = value.size(0)
+ if mask is not None:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
+ min_value = torch.finfo(scores.dtype).min
+ scores = scores.masked_fill(mask, min_value)
+ self.attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0
+ ) # (batch, head, time1, time2)
+ else:
+ self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+
+ p_attn = self.dropout(self.attn)
+ x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
+ x = (
+ x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x), self.attn  # (batch, time1, d_model), (batch, head, time1, time2)
+
+ def forward(self, query, key, value, mask):
+ """Compute scaled dot product attention.
+
+ Args:
+ query (torch.Tensor): Query tensor (#batch, time1, size).
+ key (torch.Tensor): Key tensor (#batch, time2, size).
+ value (torch.Tensor): Value tensor (#batch, time2, size).
+ mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
+ (#batch, time1, time2).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time1, d_model).
+ torch.Tensor: Attention weights (#batch, n_head, time1, time2).
+
+ """
+ q, k, v = self.forward_qkv(query, key, value)
+ scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
+ return self.forward_attention(v, scores, mask)
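+
+# Usage sketch (shapes assumed, not part of this file): with tensors of
+# shape (batch, time, n_feat),
+#   mha = MultiHeadedAttentionReturnWeight(n_head=4, n_feat=256, dropout_rate=0.1)
+#   out, attn = mha(q, k, v, mask)
+# gives out of shape (batch, time1, 256) and attn of shape (batch, 4, time1, time2).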
+
+
diff --git a/funasr/models/lcbnet/encoder.py b/funasr/models/lcbnet/encoder.py
new file mode 100644
index 0000000..c65823c
--- /dev/null
+++ b/funasr/models/lcbnet/encoder.py
@@ -0,0 +1,392 @@
+# Copyright 2019 Shigeki Karita
+# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
+
+"""Transformer encoder definition."""
+
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import torch
+from torch import nn
+import logging
+
+from funasr.models.transformer.attention import MultiHeadedAttention
+from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight
+from funasr.models.transformer.embedding import PositionalEncoding
+from funasr.models.transformer.layer_norm import LayerNorm
+
+from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
+from funasr.models.transformer.utils.repeat import repeat
+from funasr.register import tables
+
+class EncoderLayer(nn.Module):
+ """Encoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
+ can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+ stochastic_depth_rate (float): Probability to skip this layer.
+ During training, the layer may skip residual computation and return input
+ as-is with given probability.
+ """
+
+ def __init__(
+ self,
+ size,
+ self_attn,
+ feed_forward,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ stochastic_depth_rate=0.0,
+ ):
+ """Construct an EncoderLayer object."""
+ super(EncoderLayer, self).__init__()
+ self.self_attn = self_attn
+ self.feed_forward = feed_forward
+ self.norm1 = LayerNorm(size)
+ self.norm2 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.size = size
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear = nn.Linear(size + size, size)
+ self.stochastic_depth_rate = stochastic_depth_rate
+
+ def forward(self, x, mask, cache=None):
+ """Compute encoded features.
+
+ Args:
+ x (torch.Tensor): Input tensor (#batch, time, size).
+ mask (torch.Tensor): Mask tensor for the input (#batch, time).
+ cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor (#batch, time, size).
+ torch.Tensor: Mask tensor (#batch, time).
+
+ """
+ skip_layer = False
+ # with stochastic depth, residual connection `x + f(x)` becomes
+ # `x <- x + 1 / (1 - p) * f(x)` at training time.
+ stoch_layer_coeff = 1.0
+ if self.training and self.stochastic_depth_rate > 0:
+ skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
+ stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
+
+ if skip_layer:
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+ return x, mask
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm1(x)
+
+ if cache is None:
+ x_q = x
+ else:
+ assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
+ x_q = x[:, -1:, :]
+ residual = residual[:, -1:, :]
+ mask = None if mask is None else mask[:, -1:, :]
+
+ if self.concat_after:
+ x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
+ x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
+ else:
+ x = residual + stoch_layer_coeff * self.dropout(
+ self.self_attn(x_q, x, x, mask)
+ )
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x, mask
+
+@tables.register("encoder_classes", "TransformerTextEncoder")
+class TransformerTextEncoder(nn.Module):
+ """Transformer text encoder module.
+
+ Args:
+ input_size: input dim
+ output_size: dimension of attention
+ attention_heads: the number of heads of multi head attention
+ linear_units: the number of units of position-wise feed forward
+ num_blocks: the number of decoder blocks
+ dropout_rate: dropout rate
+ attention_dropout_rate: dropout rate in attention
+ positional_dropout_rate: dropout rate after adding positional encoding
+ input_layer: input layer type
+ pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
+ normalize_before: whether to use layer_norm before the first block
+ concat_after: whether to concat attention layer's input and output
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied.
+ i.e. x -> x + att(x)
+ positionwise_layer_type: linear or conv1d
+ positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
+ padding_idx: padding_idx for input_layer=embed
+ """
+
+ def __init__(
+ self,
+ input_size: int,
+ output_size: int = 256,
+ attention_heads: int = 4,
+ linear_units: int = 2048,
+ num_blocks: int = 6,
+ dropout_rate: float = 0.1,
+ positional_dropout_rate: float = 0.1,
+ attention_dropout_rate: float = 0.0,
+ pos_enc_class=PositionalEncoding,
+ normalize_before: bool = True,
+ concat_after: bool = False,
+ ):
+ super().__init__()
+ self._output_size = output_size
+
+ self.embed = torch.nn.Sequential(
+ torch.nn.Embedding(input_size, output_size),
+ pos_enc_class(output_size, positional_dropout_rate),
+ )
+
+ self.normalize_before = normalize_before
+
+ positionwise_layer = PositionwiseFeedForward
+ positionwise_layer_args = (
+ output_size,
+ linear_units,
+ dropout_rate,
+ )
+ self.encoders = repeat(
+ num_blocks,
+ lambda lnum: EncoderLayer(
+ output_size,
+ MultiHeadedAttention(
+ attention_heads, output_size, attention_dropout_rate
+ ),
+ positionwise_layer(*positionwise_layer_args),
+ dropout_rate,
+ normalize_before,
+ concat_after,
+ ),
+ )
+ if self.normalize_before:
+ self.after_norm = LayerNorm(output_size)
+
+ def output_size(self) -> int:
+ return self._output_size
+
+ def forward(
+ self,
+ xs_pad: torch.Tensor,
+ ilens: torch.Tensor,
+ ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
+ """Embed positions in tensor.
+
+ Args:
+ xs_pad: input tensor (B, L, D)
+ ilens: input length (B)
+ Returns:
+ position embedded tensor and mask
+ """
+ masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
+ xs_pad = self.embed(xs_pad)
+
+ xs_pad, masks = self.encoders(xs_pad, masks)
+
+ if self.normalize_before:
+ xs_pad = self.after_norm(xs_pad)
+
+ olens = masks.squeeze(1).sum(1)
+ return xs_pad, olens, None
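+
+# Usage sketch (hypothetical sizes): token ids xs (B, L) with lengths ilens (B,)
+#   enc = TransformerTextEncoder(input_size=5000, output_size=256)
+#   xs_out, olens, _ = enc(xs, ilens)  # xs_out: (B, L, 256)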
+
+
+
+
+@tables.register("encoder_classes", "FusionSANEncoder")
+class SelfSrcAttention(nn.Module):
+ """Single decoder layer module.
+
+ Args:
+ size (int): Input dimension.
+ self_attn (torch.nn.Module): Self-attention module instance.
+ `MultiHeadedAttention` instance can be used as the argument.
+ src_attn (torch.nn.Module): Source-attention (cross-attention) module instance.
+ `MultiHeadedAttentionReturnWeight` instance can be used as the argument.
+ feed_forward (torch.nn.Module): Feed-forward module instance.
+ `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
+ can be used as the argument.
+ dropout_rate (float): Dropout rate.
+ normalize_before (bool): Whether to use layer_norm before the first block.
+ concat_after (bool): Whether to concat attention layer's input and output.
+ if True, additional linear will be applied.
+ i.e. x -> x + linear(concat(x, att(x)))
+ if False, no additional linear will be applied. i.e. x -> x + att(x)
+
+
+ """
+ def __init__(
+ self,
+ size,
+ attention_heads,
+ attention_dim,
+ linear_units,
+ self_attention_dropout_rate,
+ src_attention_dropout_rate,
+ positional_dropout_rate,
+ dropout_rate,
+ normalize_before=True,
+ concat_after=False,
+ ):
+ """Construct an SelfSrcAttention object."""
+ super(SelfSrcAttention, self).__init__()
+ self.size = size
+ self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate)
+ self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate)
+ self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate)
+ self.norm1 = LayerNorm(size)
+ self.norm2 = LayerNorm(size)
+ self.norm3 = LayerNorm(size)
+ self.dropout = nn.Dropout(dropout_rate)
+ self.normalize_before = normalize_before
+ self.concat_after = concat_after
+ if self.concat_after:
+ self.concat_linear1 = nn.Linear(size + size, size)
+ self.concat_linear2 = nn.Linear(size + size, size)
+
+ def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
+ """Compute decoded features.
+
+ Args:
+ tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
+ tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
+ memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
+ memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
+ cache (List[torch.Tensor]): List of cached tensors.
+ Each tensor shape should be (#batch, maxlen_out - 1, size).
+
+ Returns:
+ torch.Tensor: Output tensor(#batch, maxlen_out, size).
+ torch.Tensor: Mask for output tensor (#batch, maxlen_out).
+ torch.Tensor: Encoded memory (#batch, maxlen_in, size).
+ torch.Tensor: Encoded memory mask (#batch, maxlen_in).
+
+ """
+ residual = tgt
+ if self.normalize_before:
+ tgt = self.norm1(tgt)
+
+ if cache is None:
+ tgt_q = tgt
+ tgt_q_mask = tgt_mask
+ else:
+ # compute only the last frame query keeping dim: max_time_out -> 1
+ assert cache.shape == (
+ tgt.shape[0],
+ tgt.shape[1] - 1,
+ self.size,
+ ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
+ tgt_q = tgt[:, -1:, :]
+ residual = residual[:, -1:, :]
+ tgt_q_mask = None
+ if tgt_mask is not None:
+ tgt_q_mask = tgt_mask[:, -1:, :]
+
+ if self.concat_after:
+ tgt_concat = torch.cat(
+ (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
+ )
+ x = residual + self.concat_linear1(tgt_concat)
+ else:
+ x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
+ if not self.normalize_before:
+ x = self.norm1(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm2(x)
+ if self.concat_after:
+ x_concat = torch.cat(
+ (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
+ )
+ x = residual + self.concat_linear2(x_concat)
+ else:
+ x, score = self.src_attn(x, memory, memory, memory_mask)
+ x = residual + self.dropout(x)
+ if not self.normalize_before:
+ x = self.norm2(x)
+
+ residual = x
+ if self.normalize_before:
+ x = self.norm3(x)
+ x = residual + self.dropout(self.feed_forward(x))
+ if not self.normalize_before:
+ x = self.norm3(x)
+
+ if cache is not None:
+ x = torch.cat([cache, x], dim=1)
+
+ return x, tgt_mask, memory, memory_mask
+
+
+@tables.register("encoder_classes", "ConvBiasPredictor")
+class ConvPredictor(nn.Module):
+ def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048):
+ super().__init__()
+ self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
+ self.norm1 = LayerNorm(size)
+ self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate)
+ self.norm2 = LayerNorm(size)
+ self.pad = nn.ConstantPad1d((l_order, r_order), 0)
+ self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size)
+ self.output_linear = nn.Linear(size, 1)
+
+
+ def forward(self, text_enc, asr_enc):
+ # stage1 cross-attention
+ residual = text_enc
+ text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None)
+
+ # stage2 FFN
+ residual = text_enc
+ text_enc = self.norm1(text_enc)
+ text_enc = residual + self.feed_forward(text_enc)
+
+        # stage3 Conv predictor
+ text_enc = self.norm2(text_enc)
+ context = text_enc.transpose(1, 2)
+ queries = self.pad(context)
+ memory = self.conv1d(queries)
+ output = memory + context
+ output = output.transpose(1, 2)
+ output = torch.relu(output)
+ output = self.output_linear(output)
+        if output.dim() == 3:
+ output = output.squeeze(2)
+ return output
diff --git a/funasr/models/lcbnet/model.py b/funasr/models/lcbnet/model.py
new file mode 100644
index 0000000..3ac319c
--- /dev/null
+++ b/funasr/models/lcbnet/model.py
@@ -0,0 +1,495 @@
+#!/usr/bin/env python3
+# -*- encoding: utf-8 -*-
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+# MIT License (https://opensource.org/licenses/MIT)
+
+import logging
+from typing import Union, Dict, List, Tuple, Optional
+
+import time
+import torch
+import torch.nn as nn
+from torch.cuda.amp import autocast
+
+from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
+from funasr.models.ctc.ctc import CTC
+from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
+from funasr.metrics.compute_acc import th_accuracy
+# from funasr.models.e2e_asr_common import ErrorCalculator
+from funasr.train_utils.device_funcs import force_gatherable
+from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
+from funasr.utils import postprocess_utils
+from funasr.utils.datadir_writer import DatadirWriter
+from funasr.register import tables
+
+
+@tables.register("model_classes", "LCBNet")
+class LCBNet(nn.Module):
+ """
+ Author: Speech Lab of DAMO Academy, Alibaba Group
+ LCB-NET: LONG-CONTEXT BIASING FOR AUDIO-VISUAL SPEECH RECOGNITION
+ https://arxiv.org/abs/2401.06390
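+
+    A minimal usage sketch (the model path and input files are placeholders;
+    the AutoModel interface from funasr.auto.auto_model is assumed):
+        >>> from funasr import AutoModel
+        >>> model = AutoModel(model="/path/to/lcbnet_model")
+        >>> res = model.generate(input=("audio.wav", "ocr.txt"),
+        ...                      data_type=["sound", "text"])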
+ """
+
+ def __init__(
+ self,
+ specaug: str = None,
+ specaug_conf: dict = None,
+ normalize: str = None,
+ normalize_conf: dict = None,
+ encoder: str = None,
+ encoder_conf: dict = None,
+ decoder: str = None,
+ decoder_conf: dict = None,
+ text_encoder: str = None,
+ text_encoder_conf: dict = None,
+ bias_predictor: str = None,
+ bias_predictor_conf: dict = None,
+ fusion_encoder: str = None,
+ fusion_encoder_conf: dict = None,
+ ctc: str = None,
+ ctc_conf: dict = None,
+ ctc_weight: float = 0.5,
+ interctc_weight: float = 0.0,
+ select_num: int = 2,
+ select_length: int = 3,
+ insert_blank: bool = True,
+ input_size: int = 80,
+ vocab_size: int = -1,
+ ignore_id: int = -1,
+ blank_id: int = 0,
+ sos: int = 1,
+ eos: int = 2,
+ lsm_weight: float = 0.0,
+ length_normalized_loss: bool = False,
+ report_cer: bool = True,
+ report_wer: bool = True,
+ sym_space: str = "<space>",
+ sym_blank: str = "<blank>",
+ # extract_feats_in_collect_stats: bool = True,
+ share_embedding: bool = False,
+ # preencoder: Optional[AbsPreEncoder] = None,
+ # postencoder: Optional[AbsPostEncoder] = None,
+ **kwargs,
+ ):
+
+ super().__init__()
+
+ if specaug is not None:
+ specaug_class = tables.specaug_classes.get(specaug)
+ specaug = specaug_class(**specaug_conf)
+ if normalize is not None:
+ normalize_class = tables.normalize_classes.get(normalize)
+ normalize = normalize_class(**normalize_conf)
+ encoder_class = tables.encoder_classes.get(encoder)
+ encoder = encoder_class(input_size=input_size, **encoder_conf)
+ encoder_output_size = encoder.output_size()
+
+ # lcbnet modules: text encoder, fusion encoder and bias predictor
+ text_encoder_class = tables.encoder_classes.get(text_encoder)
+ text_encoder = text_encoder_class(input_size=vocab_size, **text_encoder_conf)
+ fusion_encoder_class = tables.encoder_classes.get(fusion_encoder)
+ fusion_encoder = fusion_encoder_class(**fusion_encoder_conf)
+ bias_predictor_class = tables.encoder_classes.get(bias_predictor)
+ bias_predictor = bias_predictor_class(**bias_predictor_conf)
+
+ if decoder is not None:
+ decoder_class = tables.decoder_classes.get(decoder)
+ decoder = decoder_class(
+ vocab_size=vocab_size,
+ encoder_output_size=encoder_output_size,
+ **decoder_conf,
+ )
+ if ctc_weight > 0.0:
+ if ctc_conf is None:
+ ctc_conf = {}
+
+ ctc = CTC(
+ odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
+ )
+
+ self.blank_id = blank_id
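+        # note: the sos/eos constructor arguments are overridden below; both
+        # are mapped to the last vocabulary id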
+ self.sos = vocab_size - 1
+ self.eos = vocab_size - 1
+ self.vocab_size = vocab_size
+ self.ignore_id = ignore_id
+ self.ctc_weight = ctc_weight
+ self.specaug = specaug
+ self.normalize = normalize
+ self.encoder = encoder
+ # lcbnet
+ self.text_encoder = text_encoder
+ self.fusion_encoder = fusion_encoder
+ self.bias_predictor = bias_predictor
+ self.select_num = select_num
+ self.select_length = select_length
+ self.insert_blank = insert_blank
+
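+        # self-conditioned inter-CTC: when the encoder supports conditioning,
+        # intermediate CTC posteriors are projected back into its hidden states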
+ if not hasattr(self.encoder, "interctc_use_conditioning"):
+ self.encoder.interctc_use_conditioning = False
+ if self.encoder.interctc_use_conditioning:
+ self.encoder.conditioning_layer = torch.nn.Linear(
+ vocab_size, self.encoder.output_size()
+ )
+ self.interctc_weight = interctc_weight
+
+ # self.error_calculator = None
+ if ctc_weight == 1.0:
+ self.decoder = None
+ else:
+ self.decoder = decoder
+
+ self.criterion_att = LabelSmoothingLoss(
+ size=vocab_size,
+ padding_idx=ignore_id,
+ smoothing=lsm_weight,
+ normalize_length=length_normalized_loss,
+ )
+ #
+ # if report_cer or report_wer:
+ # self.error_calculator = ErrorCalculator(
+ # token_list, sym_space, sym_blank, report_cer, report_wer
+ # )
+ #
+ self.error_calculator = None
+ if ctc_weight == 0.0:
+ self.ctc = None
+ else:
+ self.ctc = ctc
+
+ self.share_embedding = share_embedding
+ if self.share_embedding:
+ self.decoder.embed = None
+
+ self.length_normalized_loss = length_normalized_loss
+ self.beam_search = None
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ text: torch.Tensor,
+ text_lengths: torch.Tensor,
+ **kwargs,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Encoder + Decoder + Calc loss
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ text: (Batch, Length)
+ text_lengths: (Batch,)
+ """
+
+ if len(text_lengths.size()) > 1:
+ text_lengths = text_lengths[:, 0]
+ if len(speech_lengths.size()) > 1:
+ speech_lengths = speech_lengths[:, 0]
+
+ batch_size = speech.shape[0]
+
+ # 1. Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ loss_att, acc_att, cer_att, wer_att = None, None, None, None
+ loss_ctc, cer_ctc = None, None
+ stats = dict()
+
+        # 2a. CTC branch
+ if self.ctc_weight != 0.0:
+ loss_ctc, cer_ctc = self._calc_ctc_loss(
+ encoder_out, encoder_out_lens, text, text_lengths
+ )
+
+ # Collect CTC branch stats
+ stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
+ stats["cer_ctc"] = cer_ctc
+
+ # Intermediate CTC (optional)
+ loss_interctc = 0.0
+ if self.interctc_weight != 0.0 and intermediate_outs is not None:
+ for layer_idx, intermediate_out in intermediate_outs:
+ # we assume intermediate_out has the same length & padding
+ # as those of encoder_out
+ loss_ic, cer_ic = self._calc_ctc_loss(
+ intermediate_out, encoder_out_lens, text, text_lengths
+ )
+ loss_interctc = loss_interctc + loss_ic
+
+                # Collect Intermediate CTC stats
+ stats["loss_interctc_layer{}".format(layer_idx)] = (
+ loss_ic.detach() if loss_ic is not None else None
+ )
+ stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
+
+ loss_interctc = loss_interctc / len(intermediate_outs)
+
+ # calculate whole encoder loss
+ loss_ctc = (
+ 1 - self.interctc_weight
+ ) * loss_ctc + self.interctc_weight * loss_interctc
+
+        # 2b. Attention-decoder branch (skipped for pure-CTC training)
+        if self.ctc_weight != 1.0:
+            loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
+                encoder_out, encoder_out_lens, text, text_lengths
+            )
+
+ # 3. CTC-Att loss definition
+ if self.ctc_weight == 0.0:
+ loss = loss_att
+ elif self.ctc_weight == 1.0:
+ loss = loss_ctc
+ else:
+ loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
+
+ # Collect Attn branch stats
+ stats["loss_att"] = loss_att.detach() if loss_att is not None else None
+ stats["acc"] = acc_att
+ stats["cer"] = cer_att
+ stats["wer"] = wer_att
+
+ # Collect total loss stats
+ stats["loss"] = torch.clone(loss.detach())
+
+ # force_gatherable: to-device and to-tensor if scalar for DataParallel
+ if self.length_normalized_loss:
+ batch_size = int((text_lengths + 1).sum())
+ loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+ return loss, stats, weight
+
+ def encode(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Frontend + Encoder. Note that this method is used by asr_inference.py
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ with autocast(False):
+ # Data augmentation
+ if self.specaug is not None and self.training:
+ speech, speech_lengths = self.specaug(speech, speech_lengths)
+ # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ speech, speech_lengths = self.normalize(speech, speech_lengths)
+ # Forward encoder
+ # feats: (Batch, Length, Dim)
+ # -> encoder_out: (Batch, Length2, Dim2)
+ if self.encoder.interctc_use_conditioning:
+ encoder_out, encoder_out_lens, _ = self.encoder(
+ speech, speech_lengths, ctc=self.ctc
+ )
+ else:
+ encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
+ intermediate_outs = None
+ if isinstance(encoder_out, tuple):
+ intermediate_outs = encoder_out[1]
+ encoder_out = encoder_out[0]
+
+ if intermediate_outs is not None:
+ return (encoder_out, intermediate_outs), encoder_out_lens
+ return encoder_out, encoder_out_lens
+
+ def _calc_att_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
+ ys_in_lens = ys_pad_lens + 1
+
+ # 1. Forward decoder
+ decoder_out, _ = self.decoder(
+ encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
+ )
+
+ # 2. Compute attention loss
+ loss_att = self.criterion_att(decoder_out, ys_out_pad)
+ acc_att = th_accuracy(
+ decoder_out.view(-1, self.vocab_size),
+ ys_out_pad,
+ ignore_label=self.ignore_id,
+ )
+
+ # Compute cer/wer using attention-decoder
+ if self.training or self.error_calculator is None:
+ cer_att, wer_att = None, None
+ else:
+ ys_hat = decoder_out.argmax(dim=-1)
+ cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
+
+ return loss_att, acc_att, cer_att, wer_att
+
+ def _calc_ctc_loss(
+ self,
+ encoder_out: torch.Tensor,
+ encoder_out_lens: torch.Tensor,
+ ys_pad: torch.Tensor,
+ ys_pad_lens: torch.Tensor,
+ ):
+ # Calc CTC loss
+ loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
+
+ # Calc CER using CTC
+ cer_ctc = None
+ if not self.training and self.error_calculator is not None:
+ ys_hat = self.ctc.argmax(encoder_out).data
+ cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
+ return loss_ctc, cer_ctc
+
+ def init_beam_search(self,
+ **kwargs,
+ ):
+ from funasr.models.transformer.search import BeamSearch
+ from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
+ from funasr.models.transformer.scorers.length_bonus import LengthBonus
+
+ # 1. Build ASR model
+ scorers = {}
+
+        if self.ctc is not None:
+ ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
+ scorers.update(
+ ctc=ctc
+ )
+ token_list = kwargs.get("token_list")
+ scorers.update(
+ decoder=self.decoder,
+ length_bonus=LengthBonus(len(token_list)),
+ )
+
+        # 2. Build ngram model
+ # ngram is not supported now
+ ngram = None
+ scorers["ngram"] = ngram
+
+ weights = dict(
+ decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.3),
+ ctc=kwargs.get("decoding_ctc_weight", 0.3),
+ lm=kwargs.get("lm_weight", 0.0),
+ ngram=kwargs.get("ngram_weight", 0.0),
+ length_bonus=kwargs.get("penalty", 0.0),
+ )
+ beam_search = BeamSearch(
+ beam_size=kwargs.get("beam_size", 20),
+ weights=weights,
+ scorers=scorers,
+ sos=self.sos,
+ eos=self.eos,
+ vocab_size=len(token_list),
+ token_list=token_list,
+ pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
+ )
+
+ self.beam_search = beam_search
+
+ def inference(self,
+ data_in,
+ data_lengths=None,
+ key: list=None,
+ tokenizer=None,
+ frontend=None,
+ **kwargs,
+ ):
+
+ if kwargs.get("batch_size", 1) > 1:
+ raise NotImplementedError("batch decoding is not implemented")
+
+ # init beamsearch
+ if self.beam_search is None:
+ logging.info("enable beam_search")
+ self.init_beam_search(**kwargs)
+ self.nbest = kwargs.get("nbest", 1)
+
+ meta_data = {}
+ if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is None:
+                speech_lengths = torch.tensor([speech.shape[1]], dtype=torch.long)  # keep lengths as a tensor for the .to(device) call below
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
+ data_type=kwargs.get("data_type", "sound"),
+ tokenizer=tokenizer)
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ audio_sample_list = sample_list[0]
+            if len(sample_list) > 1:
+ ocr_sample_list = sample_list[1]
+ else:
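+                # fallback OCR ids used when no text stream accompanies the audio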
+ ocr_sample_list = [[294, 0]]
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ frame_shift = 10
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift / 1000
+
+ speech = speech.to(device=kwargs["device"])
+ speech_lengths = speech_lengths.to(device=kwargs["device"])
+ # Encoder
+ encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+ if isinstance(encoder_out, tuple):
+ encoder_out = encoder_out[0]
+
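+        # shift non-zero OCR token ids up by one, keeping id 0 (blank/padding) unchanged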
+ ocr_list_new = [[x + 1 if x != 0 else x for x in sublist] for sublist in ocr_sample_list]
+ ocr = torch.tensor(ocr_list_new).to(device=kwargs["device"])
+ ocr_lengths = ocr.new_full([1], dtype=torch.long, fill_value=ocr.size(1)).to(device=kwargs["device"])
+ ocr, ocr_lens, _ = self.text_encoder(ocr, ocr_lengths)
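+        # bias fusion: the acoustic frames cross-attend over the encoded OCR
+        # text, and the fused output is added back as a residual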
+        fusion_out, _, _, _ = self.fusion_encoder(encoder_out, None, ocr, None)
+ encoder_out = encoder_out + fusion_out
+        # c. Pass the encoder result to the beam search
+ nbest_hyps = self.beam_search(
+ x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
+ )
+
+ nbest_hyps = nbest_hyps[: self.nbest]
+
+ results = []
+ b, n, d = encoder_out.size()
+ for i in range(b):
+
+ for nbest_idx, hyp in enumerate(nbest_hyps):
+ ibest_writer = None
+ if kwargs.get("output_dir") is not None:
+ if not hasattr(self, "writer"):
+ self.writer = DatadirWriter(kwargs.get("output_dir"))
+ ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
+
+ # remove sos/eos and get results
+ last_pos = -1
+ if isinstance(hyp.yseq, list):
+ token_int = hyp.yseq[1:last_pos]
+ else:
+ token_int = hyp.yseq[1:last_pos].tolist()
+
+ # remove blank symbol id, which is assumed to be 0
+ token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
+
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "token": token, "text": text_postprocessed}
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text_postprocessed
+
+ return results, meta_data
+
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index 20b0cc8..a8b1f1f 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -253,8 +253,8 @@
# logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask)
logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
return logits
+
merged_pred = _merge_res(decoder_pred, dha_pred)
- # import pdb; pdb.set_trace()
return merged_pred
else:
return decoder_pred
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index 7748172..84c38f9 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -13,29 +13,25 @@
from funasr.download.file import download_from_url
except:
print("urllib is not installed, if you infer from url, please install it first.")
def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
if isinstance(data_or_path_or_list, (list, tuple)):
if data_type is not None and isinstance(data_type, (list, tuple)):
-
data_types = [data_type] * len(data_or_path_or_list)
data_or_path_or_list_ret = [[] for d in data_type]
for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
-
for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
-
data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
return data_or_path_or_list_ret
else:
return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
-
if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
data_or_path_or_list = download_from_url(data_or_path_or_list)
-
+
if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
if data_type is None or data_type == "sound":
data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
@@ -56,10 +52,22 @@
data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point
data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,]
+ elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark":
+ data_mat = kaldiio.load_mat(data_or_path_or_list)
+ if isinstance(data_mat, tuple):
+ audio_fs, mat = data_mat
+ else:
+ mat = data_mat
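+        # integer PCM: scale by 1/32768 (16-bit full scale) to floats and keep
+        # the first channel if multi-channel; kaldiio is assumed to be imported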
+ if mat.dtype == 'int16' or mat.dtype == 'int32':
+ mat = mat.astype(np.float64)
+ mat = mat / 32768
+        if mat.ndim == 2:
+            mat = mat[:, 0]
+ data_or_path_or_list = mat
else:
pass
# print(f"unsupport data type: {data_or_path_or_list}, return raw data")
-
+
if audio_fs != fs and data_type != "text":
resampler = torchaudio.transforms.Resample(audio_fs, fs)
data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
@@ -81,8 +89,6 @@
return array
def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
- # import pdb;
- # pdb.set_trace()
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
if len(data.shape) < 2:
@@ -100,9 +106,7 @@
data_list.append(data_i)
data_len.append(data_i.shape[0])
data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
- # import pdb;
- # pdb.set_trace()
- # if data_type == "sound":
+
data, data_len = frontend(data, data_len, **kwargs)
if isinstance(data_len, (list, tuple)):
--
Gitblit v1.9.1