From d2c1204d91d7c98be7998e3966bd82e22750293b Mon Sep 17 00:00:00 2001
From: zhifu gao <zhifu.gzf@alibaba-inc.com>
Date: Mon, 04 Mar 2024 17:50:29 +0800
Subject: [PATCH] Revert "Dev yf" (#1418)
---
/dev/null | 495 ---------------------------------------------
funasr/models/seaco_paraformer/model.py | 18
funasr/auto/auto_model.py | 12
funasr/models/contextual_paraformer/model.py | 39 +-
funasr/models/conformer/encoder.py | 2
examples/industrial_data_pretraining/contextual_paraformer/demo.py | 0
funasr/utils/load_utils.py | 28 +-
examples/industrial_data_pretraining/seaco_paraformer/demo.py | 10
examples/industrial_data_pretraining/contextual_paraformer/demo.sh | 2
funasr/frontends/default.py | 20 -
funasr/train_utils/load_pretrained_model.py | 3
11 files changed, 61 insertions(+), 568 deletions(-)
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.py b/examples/industrial_data_pretraining/contextual_paraformer/demo.py
old mode 100755
new mode 100644
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo.sh b/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
old mode 100755
new mode 100644
index 1bd4f7f..8fc66f3
--- a/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
+++ b/examples/industrial_data_pretraining/contextual_paraformer/demo.sh
@@ -2,7 +2,7 @@
model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
model_revision="v2.0.4"
-python ../../../funasr/bin/inference.py \
+python funasr/bin/inference.py \
+model=${model} \
+model_revision=${model_revision} \
+input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh b/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh
deleted file mode 100755
index 282f4f1..0000000
--- a/examples/industrial_data_pretraining/contextual_paraformer/demo2.sh
+++ /dev/null
@@ -1,9 +0,0 @@
-python -m funasr.bin.inference \
---config-path="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
---config-name="config.yaml" \
-++init_param="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/model.pb" \
-++tokenizer_conf.token_list="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/tokens.txt" \
-++frontend_conf.cmvn_file="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/am.mvn" \
-++input="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/asr_example_zh.wav" \
-++output_dir="./outputs/debug2" \
-++device="" \
diff --git a/examples/industrial_data_pretraining/contextual_paraformer/path.sh b/examples/industrial_data_pretraining/contextual_paraformer/path.sh
deleted file mode 100755
index 1a6d67e..0000000
--- a/examples/industrial_data_pretraining/contextual_paraformer/path.sh
+++ /dev/null
@@ -1,6 +0,0 @@
-export FUNASR_DIR=$PWD/../../../
-
-# NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C
-export PYTHONIOENCODING=UTF-8
-export PATH=$FUNASR_DIR/funasr/bin:$PATH
-export PYTHONPATH=$FUNASR_DIR/funasr/bin:$FUNASR_DIR/funasr:$FUNASR_DIR:$PYTHONPATH
diff --git a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py b/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
deleted file mode 100755
index e72d871..0000000
--- a/examples/industrial_data_pretraining/lcbnet/compute_wer_details.py
+++ /dev/null
@@ -1,702 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-
-from enum import Enum
-import re, sys, unicodedata
-import codecs
-import argparse
-from tqdm import tqdm
-import os
-import pdb
-remove_tag = False
-spacelist = [" ", "\t", "\r", "\n"]
-puncts = [
- "!",
- ",",
- "?",
- "、",
- "。",
- "！",
- "，",
- "；",
- "？",
- "：",
- "「",
- "」",
- "︰",
- "『",
- "』",
- "《",
- "》",
-]
-
-
-class Code(Enum):
- match = 1
- substitution = 2
- insertion = 3
- deletion = 4
-
-
-class WordError(object):
- def __init__(self):
- self.errors = {
- Code.substitution: 0,
- Code.insertion: 0,
- Code.deletion: 0,
- }
- self.ref_words = 0
-
- def get_wer(self):
- assert self.ref_words != 0
- errors = (
- self.errors[Code.substitution]
- + self.errors[Code.insertion]
- + self.errors[Code.deletion]
- )
- return 100.0 * errors / self.ref_words
-
- def get_result_string(self):
- return (
- f"error_rate={self.get_wer():.4f}, "
- f"ref_words={self.ref_words}, "
- f"subs={self.errors[Code.substitution]}, "
- f"ins={self.errors[Code.insertion]}, "
- f"dels={self.errors[Code.deletion]}"
- )
-
-
-def characterize(string):
- res = []
- i = 0
- while i < len(string):
- char = string[i]
- if char in puncts:
- i += 1
- continue
- cat1 = unicodedata.category(char)
- # https://unicodebook.readthedocs.io/unicode.html#unicode-categories
- if cat1 == "Zs" or cat1 == "Cn" or char in spacelist: # space or not assigned
- i += 1
- continue
- if cat1 == "Lo": # letter-other
- res.append(char)
- i += 1
- else:
- # some input looks like: <unk><noise>, we want to separate it to two words.
- sep = " "
- if char == "<":
- sep = ">"
- j = i + 1
- while j < len(string):
- c = string[j]
- if ord(c) >= 128 or (c in spacelist) or (c == sep):
- break
- j += 1
- if j < len(string) and string[j] == ">":
- j += 1
- res.append(string[i:j])
- i = j
- return res
-
-
-def stripoff_tags(x):
- if not x:
- return ""
- chars = []
- i = 0
- T = len(x)
- while i < T:
- if x[i] == "<":
- while i < T and x[i] != ">":
- i += 1
- i += 1
- else:
- chars.append(x[i])
- i += 1
- return "".join(chars)
-
-
-def normalize(sentence, ignore_words, cs, split=None):
- """sentence, ignore_words are both in unicode"""
- new_sentence = []
- for token in sentence:
- x = token
- if not cs:
- x = x.upper()
- if x in ignore_words:
- continue
- if remove_tag:
- x = stripoff_tags(x)
- if not x:
- continue
- if split and x in split:
- new_sentence += split[x]
- else:
- new_sentence.append(x)
- return new_sentence
-
-
-class Calculator:
- def __init__(self):
- self.data = {}
- self.space = []
- self.cost = {}
- self.cost["cor"] = 0
- self.cost["sub"] = 1
- self.cost["del"] = 1
- self.cost["ins"] = 1
-
- def calculate(self, lab, rec):
- # Initialization
- lab.insert(0, "")
- rec.insert(0, "")
- while len(self.space) < len(lab):
- self.space.append([])
- for row in self.space:
- for element in row:
- element["dist"] = 0
- element["error"] = "non"
- while len(row) < len(rec):
- row.append({"dist": 0, "error": "non"})
- for i in range(len(lab)):
- self.space[i][0]["dist"] = i
- self.space[i][0]["error"] = "del"
- for j in range(len(rec)):
- self.space[0][j]["dist"] = j
- self.space[0][j]["error"] = "ins"
- self.space[0][0]["error"] = "non"
- for token in lab:
- if token not in self.data and len(token) > 0:
- self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
- for token in rec:
- if token not in self.data and len(token) > 0:
- self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
- # Computing edit distance
- for i, lab_token in enumerate(lab):
- for j, rec_token in enumerate(rec):
- if i == 0 or j == 0:
- continue
- min_dist = sys.maxsize
- min_error = "none"
- dist = self.space[i - 1][j]["dist"] + self.cost["del"]
- error = "del"
- if dist < min_dist:
- min_dist = dist
- min_error = error
- dist = self.space[i][j - 1]["dist"] + self.cost["ins"]
- error = "ins"
- if dist < min_dist:
- min_dist = dist
- min_error = error
- if lab_token == rec_token.replace("<BIAS>", ""):
- dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"]
- error = "cor"
- else:
- dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"]
- error = "sub"
- if dist < min_dist:
- min_dist = dist
- min_error = error
- self.space[i][j]["dist"] = min_dist
- self.space[i][j]["error"] = min_error
- # Tracing back
- result = {
- "lab": [],
- "rec": [],
- "code": [],
- "all": 0,
- "cor": 0,
- "sub": 0,
- "ins": 0,
- "del": 0,
- }
- i = len(lab) - 1
- j = len(rec) - 1
- while True:
- if self.space[i][j]["error"] == "cor": # correct
- if len(lab[i]) > 0:
- self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
- self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1
- result["all"] = result["all"] + 1
- result["cor"] = result["cor"] + 1
- result["lab"].insert(0, lab[i])
- result["rec"].insert(0, rec[j])
- result["code"].insert(0, Code.match)
- i = i - 1
- j = j - 1
- elif self.space[i][j]["error"] == "sub": # substitution
- if len(lab[i]) > 0:
- self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
- self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1
- result["all"] = result["all"] + 1
- result["sub"] = result["sub"] + 1
- result["lab"].insert(0, lab[i])
- result["rec"].insert(0, rec[j])
- result["code"].insert(0, Code.substitution)
- i = i - 1
- j = j - 1
- elif self.space[i][j]["error"] == "del": # deletion
- if len(lab[i]) > 0:
- self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1
- self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1
- result["all"] = result["all"] + 1
- result["del"] = result["del"] + 1
- result["lab"].insert(0, lab[i])
- result["rec"].insert(0, "")
- result["code"].insert(0, Code.deletion)
- i = i - 1
- elif self.space[i][j]["error"] == "ins": # insertion
- if len(rec[j]) > 0:
- self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1
- result["ins"] = result["ins"] + 1
- result["lab"].insert(0, "")
- result["rec"].insert(0, rec[j])
- result["code"].insert(0, Code.insertion)
- j = j - 1
- elif self.space[i][j]["error"] == "non": # starting point
- break
- else: # shouldn't reach here
- print(
- "this should not happen , i = {i} , j = {j} , error = {error}".format(
- i=i, j=j, error=self.space[i][j]["error"]
- )
- )
- return result
-
- def overall(self):
- result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
- for token in self.data:
- result["all"] = result["all"] + self.data[token]["all"]
- result["cor"] = result["cor"] + self.data[token]["cor"]
- result["sub"] = result["sub"] + self.data[token]["sub"]
- result["ins"] = result["ins"] + self.data[token]["ins"]
- result["del"] = result["del"] + self.data[token]["del"]
- return result
-
- def cluster(self, data):
- result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0}
- for token in data:
- if token in self.data:
- result["all"] = result["all"] + self.data[token]["all"]
- result["cor"] = result["cor"] + self.data[token]["cor"]
- result["sub"] = result["sub"] + self.data[token]["sub"]
- result["ins"] = result["ins"] + self.data[token]["ins"]
- result["del"] = result["del"] + self.data[token]["del"]
- return result
-
- def keys(self):
- return list(self.data.keys())
-
-
-def width(string):
- return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string)
-
-
-def default_cluster(word):
- unicode_names = [unicodedata.name(char) for char in word]
- for i in reversed(range(len(unicode_names))):
- if unicode_names[i].startswith("DIGIT"): # 1
- unicode_names[i] = "Number" # 'DIGIT'
- elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[
- i
- ].startswith("CJK COMPATIBILITY IDEOGRAPH"):
- # e.g. 是
- unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH'
- elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[
- i
- ].startswith("LATIN SMALL LETTER"):
- # A / a
- unicode_names[i] = "English" # 'LATIN LETTER'
- elif unicode_names[i].startswith("HIRAGANA LETTER"): # は こ め
- unicode_names[i] = "Japanese" # 'GANA LETTER'
- elif (
- unicode_names[i].startswith("AMPERSAND")
- or unicode_names[i].startswith("APOSTROPHE")
- or unicode_names[i].startswith("COMMERCIAL AT")
- or unicode_names[i].startswith("DEGREE CELSIUS")
- or unicode_names[i].startswith("EQUALS SIGN")
- or unicode_names[i].startswith("FULL STOP")
- or unicode_names[i].startswith("HYPHEN-MINUS")
- or unicode_names[i].startswith("LOW LINE")
- or unicode_names[i].startswith("NUMBER SIGN")
- or unicode_names[i].startswith("PLUS SIGN")
- or unicode_names[i].startswith("SEMICOLON")
- ):
- # & / ' / @ / ℃ / = / . / - / _ / # / + / ;
- del unicode_names[i]
- else:
- return "Other"
- if len(unicode_names) == 0:
- return "Other"
- if len(unicode_names) == 1:
- return unicode_names[0]
- for i in range(len(unicode_names) - 1):
- if unicode_names[i] != unicode_names[i + 1]:
- return "Other"
- return unicode_names[0]
-
-
-def get_args():
- parser = argparse.ArgumentParser(description="wer cal")
- parser.add_argument("--ref", type=str, help="Text input path")
- parser.add_argument("--ref_ocr", type=str, help="Text input path")
- parser.add_argument("--rec_name", type=str, action="append", default=[])
- parser.add_argument("--rec_file", type=str, action="append", default=[])
- parser.add_argument("--verbose", type=int, default=1, help="show")
- parser.add_argument("--char", type=bool, default=True, help="show")
- args = parser.parse_args()
- return args
-
-
-def main(args):
- cluster_file = ""
- ignore_words = set()
- tochar = args.char
- verbose = args.verbose
- padding_symbol = " "
- case_sensitive = False
- max_words_per_line = sys.maxsize
- split = None
-
- if not case_sensitive:
- ig = set([w.upper() for w in ignore_words])
- ignore_words = ig
-
- default_clusters = {}
- default_words = {}
- ref_file = args.ref
- ref_ocr = args.ref_ocr
- rec_files = args.rec_file
- rec_names = args.rec_name
- assert len(rec_files) == len(rec_names)
-
- # load ocr
- ref_ocr_dict = {}
- with codecs.open(ref_ocr, "r", "utf-8") as fh:
- for line in fh:
- if "$" in line:
- line = line.replace("$", " ")
- if tochar:
- array = characterize(line)
- else:
- array = line.strip().split()
- if len(array) == 0:
- continue
- fid = array[0]
- ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split)
-
- if split and not case_sensitive:
- newsplit = dict()
- for w in split:
- words = split[w]
- for i in range(len(words)):
- words[i] = words[i].upper()
- newsplit[w.upper()] = words
- split = newsplit
-
- rec_sets = {}
- calculators_dict = dict()
- ub_wer_dict = dict()
- hotwords_related_dict = dict() # record recall-related stats
- for i, hyp_file in enumerate(rec_files):
- rec_sets[rec_names[i]] = dict()
- with codecs.open(hyp_file, "r", "utf-8") as fh:
- for line in fh:
- if tochar:
- array = characterize(line)
- else:
- array = line.strip().split()
- if len(array) == 0:
- continue
- fid = array[0]
- rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split)
-
- calculators_dict[rec_names[i]] = Calculator()
- ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()}
- hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0}
- # tp: hotword is in the label and also in the rec
- # tn: hotword is not in the label and not in the rec
- # fp: hotword is not in the label but appears in the rec
- # fn: hotword is in the label but missing from the rec
-
- # record wrong label but in ocr
- wrong_rec_but_in_ocr_dict = {}
- for rec_name in rec_names:
- wrong_rec_but_in_ocr_dict[rec_name] = 0
-
- _file_total_len = 0
- with os.popen("cat {} | wc -l".format(ref_file)) as pipe:
- _file_total_len = int(pipe.read().strip())
-
- # compute error rate on the interaction of reference file and hyp file
- for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len):
- if tochar:
- array = characterize(line)
- else:
- array = line.rstrip('\n').split()
- if len(array) == 0: continue
- fid = array[0]
- lab = normalize(array[1:], ignore_words, case_sensitive, split)
-
- if verbose:
- print('\nutt: %s' % fid)
-
- ocr_text = ref_ocr_dict[fid]
- ocr_set = set(ocr_text)
- print('ocr: {}'.format(" ".join(ocr_text)))
- list_match = [] # label tokens that also appear in the OCR text
- list_not_mathch = []
- tmp_error = 0
- tmp_match = 0
- for index in range(len(lab)):
- # text_list.append(uttlist[index+1])
- if lab[index] not in ocr_set:
- tmp_error += 1
- list_not_mathch.append(lab[index])
- else:
- tmp_match += 1
- list_match.append(lab[index])
- print('label in ocr: {}'.format(" ".join(list_match)))
-
- # for each reco file
- base_wrong_ocr_wer = None
- ocr_wrong_ocr_wer = None
-
- for rec_name in rec_names:
- rec_set = rec_sets[rec_name]
- if fid not in rec_set:
- continue
- rec = rec_set[fid]
-
- # print(rec)
- for word in rec + lab:
- if word not in default_words:
- default_cluster_name = default_cluster(word)
- if default_cluster_name not in default_clusters:
- default_clusters[default_cluster_name] = {}
- if word not in default_clusters[default_cluster_name]:
- default_clusters[default_cluster_name][word] = 1
- default_words[word] = default_cluster_name
-
- result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy())
- if verbose:
- if result['all'] != 0:
- wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
- else:
- wer = 0.0
- print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ')
- print('N=%d C=%d S=%d D=%d I=%d' %
- (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
-
-
- # print(result['rec'])
- wrong_rec_but_in_ocr = []
- for idx in range(len(result['lab'])):
- if result['lab'][idx] != "":
- if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""):
- if result['lab'][idx] in list_match:
- wrong_rec_but_in_ocr.append(result['lab'][idx])
- wrong_rec_but_in_ocr_dict[rec_name] += 1
- print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr)))
-
- if rec_name == "base":
- base_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
- if "ocr" in rec_name or "hot" in rec_name:
- ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr)
- if ocr_wrong_ocr_wer < base_wrong_ocr_wer:
- print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
- elif ocr_wrong_ocr_wer > base_wrong_ocr_wer:
- print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer))
-
- # recall = 0
- # false_alarm = 0
- # for idx in range(len(result['lab'])):
- # if "<BIAS>" in result['rec'][idx]:
- # if result['rec'][idx].replace("<BIAS>", "") in list_match:
- # recall += 1
- # else:
- # false_alarm += 1
- # print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format(
- # recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0
- # ))
- # tp: hotword is in the label and also in the rec
- # tn: hotword is not in the label and not in the rec
- # fp: hotword is not in the label but appears in the rec
- # fn: hotword is in the label but missing from the rec
- _rec_list = [word.replace("<BIAS>", "") for word in rec]
- _label_list = [word for word in lab]
- _tp = _tn = _fp = _fn = 0
- hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list]
- hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list]
- for badhotword in hot_bad_list:
- count = len([word for word in _rec_list if word == badhotword])
- # print(f"bad {badhotword} count: {count}")
- # for word in _rec_list:
- # if badhotword == word:
- # count += 1
- if count == 0:
- hotwords_related_dict[rec_name]['tn'] += 1
- _tn += 1
- # fp: 0
- else:
- hotwords_related_dict[rec_name]['fp'] += count
- _fp += count
- # tn: 0
- # if badhotword in _rec_list:
- # hotwords_related_dict[rec_name]['fp'] += 1
- # else:
- # hotwords_related_dict[rec_name]['tn'] += 1
- for hotword in hot_true_list:
- true_count = len([word for word in _label_list if hotword == word])
- rec_count = len([word for word in _rec_list if hotword == word])
- # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}")
- if rec_count == true_count:
- hotwords_related_dict[rec_name]['tp'] += true_count
- _tp += true_count
- elif rec_count > true_count:
- hotwords_related_dict[rec_name]['tp'] += true_count
- # fp: not in the label but in the rec
- hotwords_related_dict[rec_name]['fp'] += rec_count - true_count
- _tp += true_count
- _fp += rec_count - true_count
- else:
- hotwords_related_dict[rec_name]['tp'] += rec_count
- # fn: hotword in the label but not in the rec
- hotwords_related_dict[rec_name]['fn'] += true_count - rec_count
- _tp += rec_count
- _fn += true_count - rec_count
- print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format(
- _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0
- ))
-
- # if hotword in _rec_list:
- # hotwords_related_dict[rec_name]['tp'] += 1
- # else:
- # hotwords_related_dict[rec_name]['fn'] += 1
- # compute u_wer, b_wer and wer
- for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]):
- if code == Code.match:
- ub_wer_dict[rec_name]["wer"].ref_words += 1
- if lab_word in hot_true_list:
- # tmp_ref.append(ref_tokens[ref_idx])
- ub_wer_dict[rec_name]["b_wer"].ref_words += 1
- else:
- ub_wer_dict[rec_name]["u_wer"].ref_words += 1
- elif code == Code.substitution:
- ub_wer_dict[rec_name]["wer"].ref_words += 1
- ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1
- if lab_word in hot_true_list:
- # tmp_ref.append(ref_tokens[ref_idx])
- ub_wer_dict[rec_name]["b_wer"].ref_words += 1
- ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1
- else:
- ub_wer_dict[rec_name]["u_wer"].ref_words += 1
- ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1
- elif code == Code.deletion:
- ub_wer_dict[rec_name]["wer"].ref_words += 1
- ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1
- if lab_word in hot_true_list:
- # tmp_ref.append(ref_tokens[ref_idx])
- ub_wer_dict[rec_name]["b_wer"].ref_words += 1
- ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1
- else:
- ub_wer_dict[rec_name]["u_wer"].ref_words += 1
- ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1
- elif code == Code.insertion:
- ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1
- if rec_word in hot_true_list:
- ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1
- else:
- ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1
-
- space = {}
- space['lab'] = []
- space['rec'] = []
- for idx in range(len(result['lab'])):
- len_lab = width(result['lab'][idx])
- len_rec = width(result['rec'][idx])
- length = max(len_lab, len_rec)
- space['lab'].append(length - len_lab)
- space['rec'].append(length - len_rec)
- upper_lab = len(result['lab'])
- upper_rec = len(result['rec'])
- lab1, rec1 = 0, 0
- while lab1 < upper_lab or rec1 < upper_rec:
- if verbose > 1:
- print('lab(%s):' % fid.encode('utf-8'), end=' ')
- else:
- print('lab:', end=' ')
- lab2 = min(upper_lab, lab1 + max_words_per_line)
- for idx in range(lab1, lab2):
- token = result['lab'][idx]
- print('{token}'.format(token=token), end='')
- for n in range(space['lab'][idx]):
- print(padding_symbol, end='')
- print(' ', end='')
- print()
- if verbose > 1:
- print('rec(%s):' % fid.encode('utf-8'), end=' ')
- else:
- print('rec:', end=' ')
-
- rec2 = min(upper_rec, rec1 + max_words_per_line)
- for idx in range(rec1, rec2):
- token = result['rec'][idx]
- print('{token}'.format(token=token), end='')
- for n in range(space['rec'][idx]):
- print(padding_symbol, end='')
- print(' ', end='')
- print()
- # print('\n', end='\n')
- lab1 = lab2
- rec1 = rec2
- print('\n', end='\n')
- # break
- if verbose:
- print('===========================================================================')
- print()
-
- print(wrong_rec_but_in_ocr_dict)
- for rec_name in rec_names:
- result = calculators_dict[rec_name].overall()
-
- if result['all'] != 0:
- wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all']
- else:
- wer = 0.0
- print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ')
- print('N=%d C=%d S=%d D=%d I=%d' %
- (result['all'], result['cor'], result['sub'], result['del'], result['ins']))
- print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}")
- print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}")
- print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}")
-
- print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format(
- hotwords_related_dict[rec_name]['tp'],
- hotwords_related_dict[rec_name]['tn'],
- hotwords_related_dict[rec_name]['fp'],
- hotwords_related_dict[rec_name]['fn'],
- sum([v for k, v in hotwords_related_dict[rec_name].items()]),
- hotwords_related_dict[rec_name]['tp'] / (
- hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn']
- ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0
- ))
-
- # tp: hotword is in the label and also in the rec
- # tn: hotword is not in the label and not in the rec
- # fp: hotword is not in the label but appears in the rec
- # fn: hotword is in the label but missing from the rec
- if not verbose:
- print()
- print()
-
-
-if __name__ == "__main__":
- args = get_args()
-
- # print("")
- print(args)
- main(args)
-
diff --git a/examples/industrial_data_pretraining/lcbnet/demo.py b/examples/industrial_data_pretraining/lcbnet/demo.py
deleted file mode 100755
index 4ca5255..0000000
--- a/examples/industrial_data_pretraining/lcbnet/demo.py
+++ /dev/null
@@ -1,13 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-from funasr import AutoModel
-
-model = AutoModel(model="iic/LCB-NET",
- model_revision="v1.0.0")
-
-res = model.generate(input=("https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav","https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt"),data_type=("sound", "text"))
-
-print(res)
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/lcbnet/demo.sh b/examples/industrial_data_pretraining/lcbnet/demo.sh
deleted file mode 100755
index 2f226bc..0000000
--- a/examples/industrial_data_pretraining/lcbnet/demo.sh
+++ /dev/null
@@ -1,72 +0,0 @@
-file_dir="/home/yf352572/.cache/modelscope/hub/iic/LCB-NET/"
-CUDA_VISIBLE_DEVICES="0,1"
-inference_device="cuda"
-
-if [ ${inference_device} == "cuda" ]; then
- nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}')
-else
- inference_batch_size=1
- CUDA_VISIBLE_DEVICES=""
- for JOB in $(seq ${nj}); do
- CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1,"
- done
-fi
-
-inference_dir="outputs/slidespeech_dev"
-_logdir="${inference_dir}/logdir"
-echo "inference_dir: ${inference_dir}"
-
-mkdir -p "${_logdir}"
-key_file1=${file_dir}/dev/wav.scp
-key_file2=${file_dir}/dev/ocr.txt
-split_scps1=
-split_scps2=
-for JOB in $(seq "${nj}"); do
- split_scps1+=" ${_logdir}/wav.${JOB}.scp"
- split_scps2+=" ${_logdir}/ocr.${JOB}.txt"
-done
-utils/split_scp.pl "${key_file1}" ${split_scps1}
-utils/split_scp.pl "${key_file2}" ${split_scps2}
-
-gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ })
-for JOB in $(seq ${nj}); do
- {
- id=$((JOB-1))
- gpuid=${gpuid_list_array[$id]}
-
- export CUDA_VISIBLE_DEVICES=${gpuid}
-
- python -m funasr.bin.inference \
- --config-path=${file_dir} \
- --config-name="config.yaml" \
- ++init_param=${file_dir}/model.pt \
- ++tokenizer_conf.token_list=${file_dir}/tokens.txt \
- ++input=[${_logdir}/wav.${JOB}.scp,${_logdir}/ocr.${JOB}.txt] \
- +data_type='["kaldi_ark", "text"]' \
- ++tokenizer_conf.bpemodel=${file_dir}/bpe.pt \
- ++normalize_conf.stats_file=${file_dir}/am.mvn \
- ++output_dir="${inference_dir}/${JOB}" \
- ++device="${inference_device}" \
- ++ncpu=1 \
- ++disable_log=true &> ${_logdir}/log.${JOB}.txt
-
- }&
-done
-wait
-
-
-mkdir -p ${inference_dir}/1best_recog
-
-for JOB in $(seq "${nj}"); do
- cat "${inference_dir}/${JOB}/1best_recog/token" >> "${inference_dir}/1best_recog/token"
-done
-
-echo "Computing WER ..."
-sed -e 's/ /\t/' -e 's/ //g' -e 's/▁/ /g' -e 's/\t /\t/' ${inference_dir}/1best_recog/token > ${inference_dir}/1best_recog/token.proc
-cp ${file_dir}/dev/text ${inference_dir}/1best_recog/token.ref
-cp ${file_dir}/dev/ocr.list ${inference_dir}/1best_recog/ocr.list
-python utils/compute_wer.py ${inference_dir}/1best_recog/token.ref ${inference_dir}/1best_recog/token.proc ${inference_dir}/1best_recog/token.cer
-tail -n 3 ${inference_dir}/1best_recog/token.cer
-
-./run_bwer_recall.sh ${inference_dir}/1best_recog/
-tail -n 6 ${inference_dir}/1best_recog/BWER-UWER.results |head -n 5
diff --git a/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh b/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh
deleted file mode 100755
index 7d6b6ff..0000000
--- a/examples/industrial_data_pretraining/lcbnet/run_bwer_recall.sh
+++ /dev/null
@@ -1,11 +0,0 @@
-#now_result_name=asr_conformer_acc1_lr002_warm20000/decode_asr_asr_model_valid.acc.ave
-#hotword_type=ocr_1ngram_top10_hotwords_list
-hot_exp_suf=$1
-
-
-python compute_wer_details.py --v 1 \
- --ref ${hot_exp_suf}/token.ref \
- --ref_ocr ${hot_exp_suf}/ocr.list \
- --rec_name base \
- --rec_file ${hot_exp_suf}/token.proc \
- > ${hot_exp_suf}/BWER-UWER.results
diff --git a/examples/industrial_data_pretraining/lcbnet/utils b/examples/industrial_data_pretraining/lcbnet/utils
deleted file mode 120000
index be5e5a3..0000000
--- a/examples/industrial_data_pretraining/lcbnet/utils
+++ /dev/null
@@ -1 +0,0 @@
-../../aishell/paraformer/utils
\ No newline at end of file
diff --git a/examples/industrial_data_pretraining/seaco_paraformer/demo.py b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
index 551dd8b..a44c649 100644
--- a/examples/industrial_data_pretraining/seaco_paraformer/demo.py
+++ b/examples/industrial_data_pretraining/seaco_paraformer/demo.py
@@ -7,10 +7,10 @@
model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
model_revision="v2.0.4",
- # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
- # vad_model_revision="v2.0.4",
- # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
- # punc_model_revision="v2.0.4",
+ vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+ vad_model_revision="v2.0.4",
+ punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch",
+ punc_model_revision="v2.0.4",
# spk_model="damo/speech_campplus_sv_zh-cn_16k-common",
# spk_model_revision="v2.0.2",
)
@@ -43,4 +43,4 @@
wav_file = os.path.join(model.model_path, "example/asr_example.wav")
speech, sample_rate = soundfile.read(wav_file)
res = model.generate(input=[speech], batch_size_s=300, is_final=True)
-'''
+'''
\ No newline at end of file
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index ec3c3f3..921ede8 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -28,7 +28,7 @@
from funasr.models.campplus.cluster_backend import ClusterBackend
except:
print("If you want to use the speaker diarization, please `pip install hdbscan`")
-import pdb
+
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
"""
@@ -46,7 +46,6 @@
chars = string.ascii_letters + string.digits
if isinstance(data_in, str) and data_in.startswith('http'): # url
data_in = download_from_url(data_in)
-
if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt;
_, file_extension = os.path.splitext(data_in)
file_extension = file_extension.lower()
@@ -147,7 +146,7 @@
kwargs = download_model(**kwargs)
set_all_random_seed(kwargs.get("seed", 0))
-
+
device = kwargs.get("device", "cuda")
if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
device = "cpu"
@@ -169,6 +168,7 @@
vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1
else:
vocab_size = -1
+
# build frontend
frontend = kwargs.get("frontend", None)
kwargs["input_size"] = None
@@ -181,6 +181,7 @@
# build model
model_class = tables.model_classes.get(kwargs["model"])
model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size)
+
model.to(device)
# init_param
@@ -223,9 +224,9 @@
batch_size = kwargs.get("batch_size", 1)
# if kwargs.get("device", "cpu") == "cpu":
# batch_size = 1
-
+
key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key)
-
+
speed_stats = {}
asr_result_list = []
num_samples = len(data_list)
@@ -238,7 +239,6 @@
data_batch = data_list[beg_idx:end_idx]
key_batch = key_list[beg_idx:end_idx]
batch = {"data_in": data_batch, "key": key_batch}
-
if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank
batch["data_in"] = data_batch[0]
batch["data_lengths"] = input_len
diff --git a/funasr/frontends/default.py b/funasr/frontends/default.py
index c4bdbd7..8ac1ca8 100644
--- a/funasr/frontends/default.py
+++ b/funasr/frontends/default.py
@@ -3,6 +3,7 @@
from typing import Tuple
from typing import Union
import logging
+import humanfriendly
import numpy as np
import torch
import torch.nn as nn
@@ -15,10 +16,8 @@
from funasr.frontends.utils.stft import Stft
from funasr.frontends.utils.frontend import Frontend
from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.register import tables
-@tables.register("frontend_classes", "DefaultFrontend")
class DefaultFrontend(nn.Module):
"""Conventional frontend structure for ASR.
Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN
@@ -26,7 +25,7 @@
def __init__(
self,
- fs: int = 16000,
+ fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = 128,
@@ -41,14 +40,14 @@
frontend_conf: Optional[dict] = None,
apply_stft: bool = True,
use_channel: int = None,
- **kwargs,
):
super().__init__()
+ if isinstance(fs, str):
+ fs = humanfriendly.parse_size(fs)
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
self.hop_length = hop_length
- self.fs = fs
if apply_stft:
self.stft = Stft(
@@ -85,12 +84,8 @@
return self.n_mels
def forward(
- self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list]
+ self, input: torch.Tensor, input_lengths: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
- if isinstance(input_lengths, list):
- input_lengths = torch.tensor(input_lengths)
- if input.dtype == torch.float64:
- input = input.float()
# 1. Domain-conversion: e.g. Stft: time -> time-freq
if self.stft is not None:
input_stft, feats_lens = self._compute_stft(input, input_lengths)
@@ -150,7 +145,7 @@
def __init__(
self,
- fs: int = 16000,
+ fs: Union[int, str] = 16000,
n_fft: int = 512,
win_length: int = None,
hop_length: int = None,
@@ -173,6 +168,9 @@
mc: bool = True
):
super().__init__()
+ if isinstance(fs, str):
+ fs = humanfriendly.parse_size(fs)
+
# Deepcopy (In general, dict shouldn't be used as default arg)
frontend_conf = copy.deepcopy(frontend_conf)
if win_length is None and hop_length is None:
diff --git a/funasr/models/conformer/encoder.py b/funasr/models/conformer/encoder.py
index be973c6..1d252c2 100644
--- a/funasr/models/conformer/encoder.py
+++ b/funasr/models/conformer/encoder.py
@@ -47,7 +47,7 @@
from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad
from funasr.models.transformer.utils.subsampling import StreamingConvInput
from funasr.register import tables
-import pdb
+
class ConvolutionModule(nn.Module):
"""ConvolutionModule in Conformer model.
diff --git a/funasr/models/contextual_paraformer/model.py b/funasr/models/contextual_paraformer/model.py
index 7d6f729..49868a8 100644
--- a/funasr/models/contextual_paraformer/model.py
+++ b/funasr/models/contextual_paraformer/model.py
@@ -29,7 +29,7 @@
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-import pdb
+
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
@@ -62,6 +62,7 @@
crit_attn_weight = kwargs.get("crit_attn_weight", 0.0)
crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0)
bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0)
+
if bias_encoder_type == 'lstm':
self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate)
@@ -102,16 +103,17 @@
text_lengths = text_lengths[:, 0]
if len(speech_lengths.size()) > 1:
speech_lengths = speech_lengths[:, 0]
-
+
batch_size = speech.shape[0]
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
dha_pad = kwargs.get("dha_pad")
-
+
# 1. Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
loss_ctc, cer_ctc = None, None
stats = dict()
@@ -126,11 +128,12 @@
stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
stats["cer_ctc"] = cer_ctc
+
# 2b. Attention decoder branch
loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss(
encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths
)
-
+
# 3. CTC-Att loss definition
if self.ctc_weight == 0.0:
loss = loss_att + loss_pre * self.predictor_weight
@@ -168,24 +171,22 @@
):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
-
if self.predictor_bias == 1:
_, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
ys_pad_lens = ys_pad_lens + self.predictor_bias
-
pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask,
ignore_id=self.ignore_id)
+
# -1. bias encoder
if self.use_decoder_embedding:
hw_embed = self.decoder.embed(hotword_pad)
else:
hw_embed = self.bias_embed(hotword_pad)
-
hw_embed, (_, _) = self.bias_encoder(hw_embed)
_ind = np.arange(0, hotword_pad.shape[0]).tolist()
selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]]
contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device)
-
+
# 0. sampler
decoder_out_1st = None
if self.sampling_ratio > 0.0:
@@ -194,7 +195,7 @@
pre_acoustic_embeds, contextual_info)
else:
sematic_embeds = pre_acoustic_embeds
-
+
# 1. Forward decoder
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info
@@ -210,7 +211,7 @@
loss_ideal = None
'''
loss_ideal = None
-
+
if decoder_out_1st is None:
decoder_out_1st = decoder_out
# 2. Compute attention loss
@@ -287,11 +288,10 @@
enforce_sorted=False)
_, (h_n, _) = self.bias_encoder(hw_embed)
hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1)
-
+
decoder_outs = self.decoder(
encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale
)
-
decoder_out = decoder_outs[0]
decoder_out = torch.log_softmax(decoder_out, dim=-1)
return decoder_out, ys_pad_lens
@@ -305,42 +305,38 @@
**kwargs,
):
# init beamsearch
-
is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None
is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
-
+
meta_data = {}
# extract fbank feats
time1 = time.perf_counter()
-
audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
-
time2 = time.perf_counter()
meta_data["load_data"] = f"{time2 - time1:0.3f}"
-
speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
frontend=frontend)
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data[
"batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-
+
speech = speech.to(device=kwargs["device"])
speech_lengths = speech_lengths.to(device=kwargs["device"])
# hotword
self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend)
-
+
# Encoder
encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
-
+
# predictor
predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \
@@ -348,7 +344,8 @@
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
-
+
+
decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens,
pre_acoustic_embeds,
pre_token_length,
diff --git a/funasr/models/lcbnet/__init__.py b/funasr/models/lcbnet/__init__.py
deleted file mode 100644
index e69de29..0000000
--- a/funasr/models/lcbnet/__init__.py
+++ /dev/null
diff --git a/funasr/models/lcbnet/attention.py b/funasr/models/lcbnet/attention.py
deleted file mode 100644
index 8e8c594..0000000
--- a/funasr/models/lcbnet/attention.py
+++ /dev/null
@@ -1,112 +0,0 @@
-#!/usr/bin/env python3
-# -*- coding: utf-8 -*-
-
-# Copyright 2024 yufan
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-"""Multi-Head Attention Return Weight layer definition."""
-
-import math
-
-import torch
-from torch import nn
-
-class MultiHeadedAttentionReturnWeight(nn.Module):
- """Multi-Head Attention layer.
-
- Args:
- n_head (int): The number of heads.
- n_feat (int): The number of features.
- dropout_rate (float): Dropout rate.
-
- """
-
- def __init__(self, n_head, n_feat, dropout_rate):
- """Construct an MultiHeadedAttentionReturnWeight object."""
- super(MultiHeadedAttentionReturnWeight, self).__init__()
- assert n_feat % n_head == 0
- # We assume d_v always equals d_k
- self.d_k = n_feat // n_head
- self.h = n_head
- self.linear_q = nn.Linear(n_feat, n_feat)
- self.linear_k = nn.Linear(n_feat, n_feat)
- self.linear_v = nn.Linear(n_feat, n_feat)
- self.linear_out = nn.Linear(n_feat, n_feat)
- self.attn = None
- self.dropout = nn.Dropout(p=dropout_rate)
-
- def forward_qkv(self, query, key, value):
- """Transform query, key and value.
-
- Args:
- query (torch.Tensor): Query tensor (#batch, time1, size).
- key (torch.Tensor): Key tensor (#batch, time2, size).
- value (torch.Tensor): Value tensor (#batch, time2, size).
-
- Returns:
- torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k).
- torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k).
- torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k).
-
- """
- n_batch = query.size(0)
- q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k)
- k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k)
- v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k)
- q = q.transpose(1, 2) # (batch, head, time1, d_k)
- k = k.transpose(1, 2) # (batch, head, time2, d_k)
- v = v.transpose(1, 2) # (batch, head, time2, d_k)
-
- return q, k, v
-
- def forward_attention(self, value, scores, mask):
- """Compute attention context vector.
-
- Args:
- value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k).
- scores (torch.Tensor): Attention score (#batch, n_head, time1, time2).
- mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2).
-
- Returns:
- torch.Tensor: Transformed value (#batch, time1, d_model)
- weighted by the attention score (#batch, time1, time2).
-
- """
- n_batch = value.size(0)
- if mask is not None:
- mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2)
- min_value = torch.finfo(scores.dtype).min
- scores = scores.masked_fill(mask, min_value)
- self.attn = torch.softmax(scores, dim=-1).masked_fill(
- mask, 0.0
- ) # (batch, head, time1, time2)
- else:
- self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
-
- p_attn = self.dropout(self.attn)
- x = torch.matmul(p_attn, value) # (batch, head, time1, d_k)
- x = (
- x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k)
- ) # (batch, time1, d_model)
-
- return self.linear_out(x), self.attn # (batch, time1, d_model)
-
- def forward(self, query, key, value, mask):
- """Compute scaled dot product attention.
-
- Args:
- query (torch.Tensor): Query tensor (#batch, time1, size).
- key (torch.Tensor): Key tensor (#batch, time2, size).
- value (torch.Tensor): Value tensor (#batch, time2, size).
- mask (torch.Tensor): Mask tensor (#batch, 1, time2) or
- (#batch, time1, time2).
-
- Returns:
- torch.Tensor: Output tensor (#batch, time1, d_model).
-
- """
- q, k, v = self.forward_qkv(query, key, value)
- scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k)
- return self.forward_attention(v, scores, mask)
-
-
diff --git a/funasr/models/lcbnet/encoder.py b/funasr/models/lcbnet/encoder.py
deleted file mode 100644
index c65823c..0000000
--- a/funasr/models/lcbnet/encoder.py
+++ /dev/null
@@ -1,392 +0,0 @@
-# Copyright 2019 Shigeki Karita
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-"""Transformer encoder definition."""
-
-from typing import List
-from typing import Optional
-from typing import Tuple
-
-import torch
-from torch import nn
-import logging
-
-from funasr.models.transformer.attention import MultiHeadedAttention
-from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight
-from funasr.models.transformer.embedding import PositionalEncoding
-from funasr.models.transformer.layer_norm import LayerNorm
-
-from funasr.models.transformer.utils.nets_utils import make_pad_mask
-from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward
-from funasr.models.transformer.utils.repeat import repeat
-from funasr.register import tables
-
-class EncoderLayer(nn.Module):
- """Encoder layer module.
-
- Args:
- size (int): Input dimension.
- self_attn (torch.nn.Module): Self-attention module instance.
- `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance
- can be used as the argument.
- feed_forward (torch.nn.Module): Feed-forward module instance.
- `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
- can be used as the argument.
- dropout_rate (float): Dropout rate.
- normalize_before (bool): Whether to use layer_norm before the first block.
- concat_after (bool): Whether to concat attention layer's input and output.
- if True, additional linear will be applied.
- i.e. x -> x + linear(concat(x, att(x)))
- if False, no additional linear will be applied. i.e. x -> x + att(x)
- stochastic_depth_rate (float): Proability to skip this layer.
- During training, the layer may skip residual computation and return input
- as-is with given probability.
- """
-
- def __init__(
- self,
- size,
- self_attn,
- feed_forward,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- stochastic_depth_rate=0.0,
- ):
- """Construct an EncoderLayer object."""
- super(EncoderLayer, self).__init__()
- self.self_attn = self_attn
- self.feed_forward = feed_forward
- self.norm1 = LayerNorm(size)
- self.norm2 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.size = size
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- if self.concat_after:
- self.concat_linear = nn.Linear(size + size, size)
- self.stochastic_depth_rate = stochastic_depth_rate
-
- def forward(self, x, mask, cache=None):
- """Compute encoded features.
-
- Args:
- x_input (torch.Tensor): Input tensor (#batch, time, size).
- mask (torch.Tensor): Mask tensor for the input (#batch, time).
- cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size).
-
- Returns:
- torch.Tensor: Output tensor (#batch, time, size).
- torch.Tensor: Mask tensor (#batch, time).
-
- """
- skip_layer = False
- # with stochastic depth, residual connection `x + f(x)` becomes
- # `x <- x + 1 / (1 - p) * f(x)` at training time.
- stoch_layer_coeff = 1.0
- if self.training and self.stochastic_depth_rate > 0:
- skip_layer = torch.rand(1).item() < self.stochastic_depth_rate
- stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate)
-
- if skip_layer:
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
- return x, mask
-
- residual = x
- if self.normalize_before:
- x = self.norm1(x)
-
- if cache is None:
- x_q = x
- else:
- assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size)
- x_q = x[:, -1:, :]
- residual = residual[:, -1:, :]
- mask = None if mask is None else mask[:, -1:, :]
-
- if self.concat_after:
- x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1)
- x = residual + stoch_layer_coeff * self.concat_linear(x_concat)
- else:
- x = residual + stoch_layer_coeff * self.dropout(
- self.self_attn(x_q, x, x, mask)
- )
- if not self.normalize_before:
- x = self.norm1(x)
-
- residual = x
- if self.normalize_before:
- x = self.norm2(x)
- x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x))
- if not self.normalize_before:
- x = self.norm2(x)
-
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
-
- return x, mask
-
-@tables.register("encoder_classes", "TransformerTextEncoder")
-class TransformerTextEncoder(nn.Module):
- """Transformer text encoder module.
-
- Args:
- input_size: input dim
- output_size: dimension of attention
- attention_heads: the number of heads of multi head attention
- linear_units: the number of units of position-wise feed forward
- num_blocks: the number of decoder blocks
- dropout_rate: dropout rate
- attention_dropout_rate: dropout rate in attention
- positional_dropout_rate: dropout rate after adding positional encoding
- input_layer: input layer type
- pos_enc_class: PositionalEncoding or ScaledPositionalEncoding
- normalize_before: whether to use layer_norm before the first block
- concat_after: whether to concat attention layer's input and output
- if True, additional linear will be applied.
- i.e. x -> x + linear(concat(x, att(x)))
- if False, no additional linear will be applied.
- i.e. x -> x + att(x)
- positionwise_layer_type: linear of conv1d
- positionwise_conv_kernel_size: kernel size of positionwise conv1d layer
- padding_idx: padding_idx for input_layer=embed
- """
-
- def __init__(
- self,
- input_size: int,
- output_size: int = 256,
- attention_heads: int = 4,
- linear_units: int = 2048,
- num_blocks: int = 6,
- dropout_rate: float = 0.1,
- positional_dropout_rate: float = 0.1,
- attention_dropout_rate: float = 0.0,
- pos_enc_class=PositionalEncoding,
- normalize_before: bool = True,
- concat_after: bool = False,
- ):
- super().__init__()
- self._output_size = output_size
-
- self.embed = torch.nn.Sequential(
- torch.nn.Embedding(input_size, output_size),
- pos_enc_class(output_size, positional_dropout_rate),
- )
-
- self.normalize_before = normalize_before
-
- positionwise_layer = PositionwiseFeedForward
- positionwise_layer_args = (
- output_size,
- linear_units,
- dropout_rate,
- )
- self.encoders = repeat(
- num_blocks,
- lambda lnum: EncoderLayer(
- output_size,
- MultiHeadedAttention(
- attention_heads, output_size, attention_dropout_rate
- ),
- positionwise_layer(*positionwise_layer_args),
- dropout_rate,
- normalize_before,
- concat_after,
- ),
- )
- if self.normalize_before:
- self.after_norm = LayerNorm(output_size)
-
- def output_size(self) -> int:
- return self._output_size
-
- def forward(
- self,
- xs_pad: torch.Tensor,
- ilens: torch.Tensor,
- ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
- """Embed positions in tensor.
-
- Args:
- xs_pad: input tensor (B, L, D)
- ilens: input length (B)
- Returns:
- position embedded tensor and mask
- """
- masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device)
- xs_pad = self.embed(xs_pad)
-
- xs_pad, masks = self.encoders(xs_pad, masks)
-
- if self.normalize_before:
- xs_pad = self.after_norm(xs_pad)
-
- olens = masks.squeeze(1).sum(1)
- return xs_pad, olens, None
-
-
-
-
-@tables.register("encoder_classes", "FusionSANEncoder")
-class SelfSrcAttention(nn.Module):
- """Single decoder layer module.
-
- Args:
- size (int): Input dimension.
- self_attn (torch.nn.Module): Self-attention module instance.
- `MultiHeadedAttention` instance can be used as the argument.
- src_attn (torch.nn.Module): Self-attention module instance.
- `MultiHeadedAttention` instance can be used as the argument.
- feed_forward (torch.nn.Module): Feed-forward module instance.
- `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance
- can be used as the argument.
- dropout_rate (float): Dropout rate.
- normalize_before (bool): Whether to use layer_norm before the first block.
- concat_after (bool): Whether to concat attention layer's input and output.
- if True, additional linear will be applied.
- i.e. x -> x + linear(concat(x, att(x)))
- if False, no additional linear will be applied. i.e. x -> x + att(x)
-
-
- """
- def __init__(
- self,
- size,
- attention_heads,
- attention_dim,
- linear_units,
- self_attention_dropout_rate,
- src_attention_dropout_rate,
- positional_dropout_rate,
- dropout_rate,
- normalize_before=True,
- concat_after=False,
- ):
- """Construct an SelfSrcAttention object."""
- super(SelfSrcAttention, self).__init__()
- self.size = size
- self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate)
- self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate)
- self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate)
- self.norm1 = LayerNorm(size)
- self.norm2 = LayerNorm(size)
- self.norm3 = LayerNorm(size)
- self.dropout = nn.Dropout(dropout_rate)
- self.normalize_before = normalize_before
- self.concat_after = concat_after
- if self.concat_after:
- self.concat_linear1 = nn.Linear(size + size, size)
- self.concat_linear2 = nn.Linear(size + size, size)
-
- def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None):
- """Compute decoded features.
-
- Args:
- tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size).
- tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out).
- memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size).
- memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in).
- cache (List[torch.Tensor]): List of cached tensors.
- Each tensor shape should be (#batch, maxlen_out - 1, size).
-
- Returns:
- torch.Tensor: Output tensor(#batch, maxlen_out, size).
- torch.Tensor: Mask for output tensor (#batch, maxlen_out).
- torch.Tensor: Encoded memory (#batch, maxlen_in, size).
- torch.Tensor: Encoded memory mask (#batch, maxlen_in).
-
- """
- residual = tgt
- if self.normalize_before:
- tgt = self.norm1(tgt)
-
- if cache is None:
- tgt_q = tgt
- tgt_q_mask = tgt_mask
- else:
- # compute only the last frame query keeping dim: max_time_out -> 1
- assert cache.shape == (
- tgt.shape[0],
- tgt.shape[1] - 1,
- self.size,
- ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}"
- tgt_q = tgt[:, -1:, :]
- residual = residual[:, -1:, :]
- tgt_q_mask = None
- if tgt_mask is not None:
- tgt_q_mask = tgt_mask[:, -1:, :]
-
- if self.concat_after:
- tgt_concat = torch.cat(
- (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1
- )
- x = residual + self.concat_linear1(tgt_concat)
- else:
- x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask))
- if not self.normalize_before:
- x = self.norm1(x)
-
- residual = x
- if self.normalize_before:
- x = self.norm2(x)
- if self.concat_after:
- x_concat = torch.cat(
- (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1
- )
- x = residual + self.concat_linear2(x_concat)
- else:
- x, score = self.src_attn(x, memory, memory, memory_mask)
- x = residual + self.dropout(x)
- if not self.normalize_before:
- x = self.norm2(x)
-
- residual = x
- if self.normalize_before:
- x = self.norm3(x)
- x = residual + self.dropout(self.feed_forward(x))
- if not self.normalize_before:
- x = self.norm3(x)
-
- if cache is not None:
- x = torch.cat([cache, x], dim=1)
-
- return x, tgt_mask, memory, memory_mask
-
-
-@tables.register("encoder_classes", "ConvBiasPredictor")
-class ConvPredictor(nn.Module):
- def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048):
- super().__init__()
- self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate)
- self.norm1 = LayerNorm(size)
- self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate)
- self.norm2 = LayerNorm(size)
- self.pad = nn.ConstantPad1d((l_order, r_order), 0)
- self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size)
- self.output_linear = nn.Linear(size, 1)
-
-
- def forward(self, text_enc, asr_enc):
- # stage1 cross-attention
- residual = text_enc
- text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None)
-
- # stage2 FFN
- residual = text_enc
- text_enc = self.norm1(text_enc)
- text_enc = residual + self.feed_forward(text_enc)
-
- # stage Conv predictor
- text_enc = self.norm2(text_enc)
- context = text_enc.transpose(1, 2)
- queries = self.pad(context)
- memory = self.conv1d(queries)
- output = memory + context
- output = output.transpose(1, 2)
- output = torch.relu(output)
- output = self.output_linear(output)
-        if output.dim() == 3:
- output = output.squeeze(2)
- return output
diff --git a/funasr/models/lcbnet/model.py b/funasr/models/lcbnet/model.py
deleted file mode 100644
index 3ac319c..0000000
--- a/funasr/models/lcbnet/model.py
+++ /dev/null
@@ -1,495 +0,0 @@
-#!/usr/bin/env python3
-# -*- encoding: utf-8 -*-
-# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
-# MIT License (https://opensource.org/licenses/MIT)
-
-import logging
-from typing import Union, Dict, List, Tuple, Optional
-
-import time
-import torch
-import torch.nn as nn
-from torch.cuda.amp import autocast
-
-from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
-from funasr.models.ctc.ctc import CTC
-from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
-from funasr.metrics.compute_acc import th_accuracy
-# from funasr.models.e2e_asr_common import ErrorCalculator
-from funasr.train_utils.device_funcs import force_gatherable
-from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-from funasr.utils import postprocess_utils
-from funasr.utils.datadir_writer import DatadirWriter
-from funasr.register import tables
-
-import pdb
-@tables.register("model_classes", "LCBNet")
-class LCBNet(nn.Module):
- """
- Author: Speech Lab of DAMO Academy, Alibaba Group
- LCB-NET: LONG-CONTEXT BIASING FOR AUDIO-VISUAL SPEECH RECOGNITION
- https://arxiv.org/abs/2401.06390
- """
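-    # Module layout as wired below: an acoustic encoder with an optional CTC
-    # branch and attention decoder, plus a text encoder, fusion encoder and
-    # bias predictor for the long-context biasing text; at inference the text
-    # encoding is fused back into the acoustic encoding before beam search.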
-
- def __init__(
- self,
- specaug: str = None,
- specaug_conf: dict = None,
- normalize: str = None,
- normalize_conf: dict = None,
- encoder: str = None,
- encoder_conf: dict = None,
- decoder: str = None,
- decoder_conf: dict = None,
- text_encoder: str = None,
- text_encoder_conf: dict = None,
- bias_predictor: str = None,
- bias_predictor_conf: dict = None,
- fusion_encoder: str = None,
- fusion_encoder_conf: dict = None,
- ctc: str = None,
- ctc_conf: dict = None,
- ctc_weight: float = 0.5,
- interctc_weight: float = 0.0,
- select_num: int = 2,
- select_length: int = 3,
- insert_blank: bool = True,
- input_size: int = 80,
- vocab_size: int = -1,
- ignore_id: int = -1,
- blank_id: int = 0,
- sos: int = 1,
- eos: int = 2,
- lsm_weight: float = 0.0,
- length_normalized_loss: bool = False,
- report_cer: bool = True,
- report_wer: bool = True,
- sym_space: str = "<space>",
- sym_blank: str = "<blank>",
- # extract_feats_in_collect_stats: bool = True,
- share_embedding: bool = False,
- # preencoder: Optional[AbsPreEncoder] = None,
- # postencoder: Optional[AbsPostEncoder] = None,
- **kwargs,
- ):
-
- super().__init__()
-
- if specaug is not None:
- specaug_class = tables.specaug_classes.get(specaug)
- specaug = specaug_class(**specaug_conf)
- if normalize is not None:
- normalize_class = tables.normalize_classes.get(normalize)
- normalize = normalize_class(**normalize_conf)
- encoder_class = tables.encoder_classes.get(encoder)
- encoder = encoder_class(input_size=input_size, **encoder_conf)
- encoder_output_size = encoder.output_size()
-
- # lcbnet modules: text encoder, fusion encoder and bias predictor
- text_encoder_class = tables.encoder_classes.get(text_encoder)
- text_encoder = text_encoder_class(input_size=vocab_size, **text_encoder_conf)
- fusion_encoder_class = tables.encoder_classes.get(fusion_encoder)
- fusion_encoder = fusion_encoder_class(**fusion_encoder_conf)
- bias_predictor_class = tables.encoder_classes.get(bias_predictor)
- bias_predictor = bias_predictor_class(**bias_predictor_conf)
-
-
- if decoder is not None:
- decoder_class = tables.decoder_classes.get(decoder)
- decoder = decoder_class(
- vocab_size=vocab_size,
- encoder_output_size=encoder_output_size,
- **decoder_conf,
- )
- if ctc_weight > 0.0:
-
- if ctc_conf is None:
- ctc_conf = {}
-
- ctc = CTC(
- odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf
- )
-
- self.blank_id = blank_id
- self.sos = vocab_size - 1
- self.eos = vocab_size - 1
- self.vocab_size = vocab_size
- self.ignore_id = ignore_id
- self.ctc_weight = ctc_weight
- self.specaug = specaug
- self.normalize = normalize
- self.encoder = encoder
- # lcbnet
- self.text_encoder = text_encoder
- self.fusion_encoder = fusion_encoder
- self.bias_predictor = bias_predictor
- self.select_num = select_num
- self.select_length = select_length
- self.insert_blank = insert_blank
-
- if not hasattr(self.encoder, "interctc_use_conditioning"):
- self.encoder.interctc_use_conditioning = False
- if self.encoder.interctc_use_conditioning:
- self.encoder.conditioning_layer = torch.nn.Linear(
- vocab_size, self.encoder.output_size()
- )
- self.interctc_weight = interctc_weight
-
- # self.error_calculator = None
- if ctc_weight == 1.0:
- self.decoder = None
- else:
- self.decoder = decoder
-
- self.criterion_att = LabelSmoothingLoss(
- size=vocab_size,
- padding_idx=ignore_id,
- smoothing=lsm_weight,
- normalize_length=length_normalized_loss,
- )
- #
- # if report_cer or report_wer:
- # self.error_calculator = ErrorCalculator(
- # token_list, sym_space, sym_blank, report_cer, report_wer
- # )
- #
- self.error_calculator = None
- if ctc_weight == 0.0:
- self.ctc = None
- else:
- self.ctc = ctc
-
- self.share_embedding = share_embedding
- if self.share_embedding:
- self.decoder.embed = None
-
- self.length_normalized_loss = length_normalized_loss
- self.beam_search = None
-
- def forward(
- self,
- speech: torch.Tensor,
- speech_lengths: torch.Tensor,
- text: torch.Tensor,
- text_lengths: torch.Tensor,
- **kwargs,
- ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
- """Encoder + Decoder + Calc loss
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- text: (Batch, Length)
- text_lengths: (Batch,)
- """
-
- if len(text_lengths.size()) > 1:
- text_lengths = text_lengths[:, 0]
- if len(speech_lengths.size()) > 1:
- speech_lengths = speech_lengths[:, 0]
-
- batch_size = speech.shape[0]
-
- # 1. Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
-
- loss_att, acc_att, cer_att, wer_att = None, None, None, None
- loss_ctc, cer_ctc = None, None
- stats = dict()
-
- # decoder: CTC branch
- if self.ctc_weight != 0.0:
- loss_ctc, cer_ctc = self._calc_ctc_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # Collect CTC branch stats
- stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None
- stats["cer_ctc"] = cer_ctc
-
- # Intermediate CTC (optional)
- loss_interctc = 0.0
- if self.interctc_weight != 0.0 and intermediate_outs is not None:
- for layer_idx, intermediate_out in intermediate_outs:
- # we assume intermediate_out has the same length & padding
- # as those of encoder_out
- loss_ic, cer_ic = self._calc_ctc_loss(
- intermediate_out, encoder_out_lens, text, text_lengths
- )
- loss_interctc = loss_interctc + loss_ic
-
-                # Collect intermediate CTC stats
- stats["loss_interctc_layer{}".format(layer_idx)] = (
- loss_ic.detach() if loss_ic is not None else None
- )
- stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic
-
- loss_interctc = loss_interctc / len(intermediate_outs)
-
- # calculate whole encoder loss
- loss_ctc = (
- 1 - self.interctc_weight
- ) * loss_ctc + self.interctc_weight * loss_interctc
-
- # decoder: Attention decoder branch
- loss_att, acc_att, cer_att, wer_att = self._calc_att_loss(
- encoder_out, encoder_out_lens, text, text_lengths
- )
-
- # 3. CTC-Att loss definition
- if self.ctc_weight == 0.0:
- loss = loss_att
- elif self.ctc_weight == 1.0:
- loss = loss_ctc
- else:
- loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att
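-            # e.g. with the default ctc_weight=0.5 the CTC and attention branches
-            # contribute equally to the final loss.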
-
- # Collect Attn branch stats
- stats["loss_att"] = loss_att.detach() if loss_att is not None else None
- stats["acc"] = acc_att
- stats["cer"] = cer_att
- stats["wer"] = wer_att
-
- # Collect total loss stats
- stats["loss"] = torch.clone(loss.detach())
-
- # force_gatherable: to-device and to-tensor if scalar for DataParallel
- if self.length_normalized_loss:
- batch_size = int((text_lengths + 1).sum())
- loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
- return loss, stats, weight
-
-
- def encode(
- self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs,
- ) -> Tuple[torch.Tensor, torch.Tensor]:
- """Frontend + Encoder. Note that this method is used by asr_inference.py
- Args:
- speech: (Batch, Length, ...)
- speech_lengths: (Batch, )
- """
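-        # Everything under autocast(False) runs in fp32: SpecAug (training only),
-        # feature normalization, and the encoder forward itself.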
- with autocast(False):
- # Data augmentation
- if self.specaug is not None and self.training:
- speech, speech_lengths = self.specaug(speech, speech_lengths)
- # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
- if self.normalize is not None:
- speech, speech_lengths = self.normalize(speech, speech_lengths)
- # Forward encoder
- # feats: (Batch, Length, Dim)
- # -> encoder_out: (Batch, Length2, Dim2)
- if self.encoder.interctc_use_conditioning:
- encoder_out, encoder_out_lens, _ = self.encoder(
- speech, speech_lengths, ctc=self.ctc
- )
- else:
- encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths)
- intermediate_outs = None
- if isinstance(encoder_out, tuple):
- intermediate_outs = encoder_out[1]
- encoder_out = encoder_out[0]
-
- if intermediate_outs is not None:
- return (encoder_out, intermediate_outs), encoder_out_lens
- return encoder_out, encoder_out_lens
-
- def _calc_att_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id)
- ys_in_lens = ys_pad_lens + 1
-
- # 1. Forward decoder
- decoder_out, _ = self.decoder(
- encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens
- )
-
- # 2. Compute attention loss
- loss_att = self.criterion_att(decoder_out, ys_out_pad)
- acc_att = th_accuracy(
- decoder_out.view(-1, self.vocab_size),
- ys_out_pad,
- ignore_label=self.ignore_id,
- )
-
- # Compute cer/wer using attention-decoder
- if self.training or self.error_calculator is None:
- cer_att, wer_att = None, None
- else:
- ys_hat = decoder_out.argmax(dim=-1)
- cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu())
-
- return loss_att, acc_att, cer_att, wer_att
-
- def _calc_ctc_loss(
- self,
- encoder_out: torch.Tensor,
- encoder_out_lens: torch.Tensor,
- ys_pad: torch.Tensor,
- ys_pad_lens: torch.Tensor,
- ):
- # Calc CTC loss
- loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens)
-
- # Calc CER using CTC
- cer_ctc = None
- if not self.training and self.error_calculator is not None:
- ys_hat = self.ctc.argmax(encoder_out).data
- cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True)
- return loss_ctc, cer_ctc
-
- def init_beam_search(self,
- **kwargs,
- ):
- from funasr.models.transformer.search import BeamSearch
- from funasr.models.transformer.scorers.ctc import CTCPrefixScorer
- from funasr.models.transformer.scorers.length_bonus import LengthBonus
-
- # 1. Build ASR model
- scorers = {}
-
-        if self.ctc is not None:
- ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos)
- scorers.update(
- ctc=ctc
- )
- token_list = kwargs.get("token_list")
- scorers.update(
- decoder=self.decoder,
- length_bonus=LengthBonus(len(token_list)),
- )
-
-
- # 3. Build ngram model
- # ngram is not supported now
- ngram = None
- scorers["ngram"] = ngram
-
- weights = dict(
- decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.3),
- ctc=kwargs.get("decoding_ctc_weight", 0.3),
- lm=kwargs.get("lm_weight", 0.0),
- ngram=kwargs.get("ngram_weight", 0.0),
- length_bonus=kwargs.get("penalty", 0.0),
- )
- beam_search = BeamSearch(
- beam_size=kwargs.get("beam_size", 20),
- weights=weights,
- scorers=scorers,
- sos=self.sos,
- eos=self.eos,
- vocab_size=len(token_list),
- token_list=token_list,
- pre_beam_score_key=None if self.ctc_weight == 1.0 else "full",
- )
-
- self.beam_search = beam_search
-
- def inference(self,
- data_in,
- data_lengths=None,
- key: list=None,
- tokenizer=None,
- frontend=None,
- **kwargs,
- ):
-
- if kwargs.get("batch_size", 1) > 1:
- raise NotImplementedError("batch decoding is not implemented")
-
- # init beamsearch
- if self.beam_search is None:
- logging.info("enable beam_search")
- self.init_beam_search(**kwargs)
- self.nbest = kwargs.get("nbest", 1)
-
- meta_data = {}
- if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank
- speech, speech_lengths = data_in, data_lengths
- if len(speech.shape) < 3:
- speech = speech[None, :, :]
- if speech_lengths is None:
- speech_lengths = speech.shape[1]
- else:
- # extract fbank feats
- time1 = time.perf_counter()
- sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000),
- data_type=kwargs.get("data_type", "sound"),
- tokenizer=tokenizer)
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- audio_sample_list = sample_list[0]
- if len(sample_list) >1:
- ocr_sample_list = sample_list[1]
- else:
- ocr_sample_list = [[294, 0]]
- speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
- frontend=frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- frame_shift = 10
- meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift / 1000
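-            # 10 ms frame shift: total frames * 10 / 1000 = seconds of audio in this batch.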
-
- speech = speech.to(device=kwargs["device"])
- speech_lengths = speech_lengths.to(device=kwargs["device"])
- # Encoder
- encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
- if isinstance(encoder_out, tuple):
- encoder_out = encoder_out[0]
-
- ocr_list_new = [[x + 1 if x != 0 else x for x in sublist] for sublist in ocr_sample_list]
- ocr = torch.tensor(ocr_list_new).to(device=kwargs["device"])
- ocr_lengths = ocr.new_full([1], dtype=torch.long, fill_value=ocr.size(1)).to(device=kwargs["device"])
- ocr, ocr_lens, _ = self.text_encoder(ocr, ocr_lengths)
-        fusion_out, _, _, _ = self.fusion_encoder(encoder_out, None, ocr, None)
- encoder_out = encoder_out + fusion_out
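-        # The OCR/biasing text is encoded, cross-attended with the acoustic
-        # encoding in the fusion encoder, and added back residually before
-        # beam search.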
-        # c. Pass the encoder result to the beam search
- nbest_hyps = self.beam_search(
- x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0)
- )
-
- nbest_hyps = nbest_hyps[: self.nbest]
-
- results = []
- b, n, d = encoder_out.size()
- for i in range(b):
-
- for nbest_idx, hyp in enumerate(nbest_hyps):
- ibest_writer = None
- if kwargs.get("output_dir") is not None:
- if not hasattr(self, "writer"):
- self.writer = DatadirWriter(kwargs.get("output_dir"))
- ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"]
-
- # remove sos/eos and get results
- last_pos = -1
- if isinstance(hyp.yseq, list):
- token_int = hyp.yseq[1:last_pos]
- else:
- token_int = hyp.yseq[1:last_pos].tolist()
-
- # remove blank symbol id, which is assumed to be 0
- token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
-
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
- result_i = {"key": key[i], "token": token, "text": text_postprocessed}
- results.append(result_i)
-
- if ibest_writer is not None:
- ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text_postprocessed
-
- return results, meta_data
-
diff --git a/funasr/models/seaco_paraformer/model.py b/funasr/models/seaco_paraformer/model.py
index a8b1f1f..20b0cc8 100644
--- a/funasr/models/seaco_paraformer/model.py
+++ b/funasr/models/seaco_paraformer/model.py
@@ -30,7 +30,7 @@
from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
-import pdb
+
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
from torch.cuda.amp import autocast
else:
@@ -128,7 +128,7 @@
hotword_pad = kwargs.get("hotword_pad")
hotword_lengths = kwargs.get("hotword_lengths")
dha_pad = kwargs.get("dha_pad")
-
+
batch_size = speech.shape[0]
# for data-parallel
text = text[:, : text_lengths.max()]
@@ -209,20 +209,17 @@
nfilter=50,
seaco_weight=1.0):
# decoder forward
-
decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True)
-
decoder_pred = torch.log_softmax(decoder_out, dim=-1)
if hw_list is not None:
hw_lengths = [len(i) for i in hw_list]
hw_list_ = [torch.Tensor(i).long() for i in hw_list]
hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device)
selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device))
-
contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device)
num_hot_word = contextual_info.shape[1]
_contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device)
-
+
# ASF Core
if nfilter > 0 and nfilter < num_hot_word:
hotword_scores = self.seaco_decoder.forward_asf6(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
@@ -242,7 +239,7 @@
cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens)
dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens)
merged = self._merge(cif_attended, dec_attended)
-
+
dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation
dha_pred = torch.log_softmax(dha_output, dim=-1)
def _merge_res(dec_output, dha_output):
@@ -256,8 +253,8 @@
# logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask)
logits = dec_output * dha_mask + dha_output[:,:,:] * (1-dha_mask)
return logits
-
merged_pred = _merge_res(decoder_pred, dha_pred)
+ # import pdb; pdb.set_trace()
return merged_pred
else:
return decoder_pred
@@ -307,6 +304,7 @@
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
self.nbest = kwargs.get("nbest", 1)
+
meta_data = {}
# extract fbank feats
@@ -332,7 +330,6 @@
if isinstance(encoder_out, tuple):
encoder_out = encoder_out[0]
-
# predictor
predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens)
pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \
@@ -341,14 +338,15 @@
if torch.max(pre_token_length) < 1:
return []
+
decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens,
pre_acoustic_embeds,
pre_token_length,
hw_list=self.hotword_list)
-
# decoder_out, _ = decoder_outs[0], decoder_outs[1]
_, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
pre_token_length)
+
results = []
b, n, d = decoder_out.size()
for i in range(b):
diff --git a/funasr/train_utils/load_pretrained_model.py b/funasr/train_utils/load_pretrained_model.py
index 0c46449..ea23725 100644
--- a/funasr/train_utils/load_pretrained_model.py
+++ b/funasr/train_utils/load_pretrained_model.py
@@ -7,7 +7,7 @@
import torch
import torch.nn
import torch.optim
-import pdb
+
def filter_state_dict(
dst_state: Dict[str, Union[float, torch.Tensor]],
@@ -63,7 +63,6 @@
dst_state = obj.state_dict()
print(f"ckpt: {path}")
-
if oss_bucket is None:
src_state = torch.load(path, map_location=map_location)
else:
diff --git a/funasr/utils/load_utils.py b/funasr/utils/load_utils.py
index 84c38f9..7748172 100644
--- a/funasr/utils/load_utils.py
+++ b/funasr/utils/load_utils.py
@@ -13,25 +13,29 @@
from funasr.download.file import download_from_url
except:
print("urllib is not installed, if you infer from url, please install it first.")
-import pdb
+
def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs):
if isinstance(data_or_path_or_list, (list, tuple)):
if data_type is not None and isinstance(data_type, (list, tuple)):
+
data_types = [data_type] * len(data_or_path_or_list)
data_or_path_or_list_ret = [[] for d in data_type]
for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)):
+
for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)):
+
data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs)
data_or_path_or_list_ret[j].append(data_or_path_or_list_j)
return data_or_path_or_list_ret
else:
return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list]
+
if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file
data_or_path_or_list = download_from_url(data_or_path_or_list)
-
+
if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file
if data_type is None or data_type == "sound":
data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list)
@@ -52,22 +56,10 @@
data_or_path_or_list = tokenizer.encode(data_or_path_or_list)
elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point
data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,]
- elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark":
- data_mat = kaldiio.load_mat(data_or_path_or_list)
- if isinstance(data_mat, tuple):
- audio_fs, mat = data_mat
- else:
- mat = data_mat
- if mat.dtype == 'int16' or mat.dtype == 'int32':
- mat = mat.astype(np.float64)
- mat = mat / 32768
- if mat.ndim ==2:
- mat = mat[:,0]
- data_or_path_or_list = mat
else:
pass
         # print(f"unsupported data type: {data_or_path_or_list}, return raw data")
-
+
if audio_fs != fs and data_type != "text":
resampler = torchaudio.transforms.Resample(audio_fs, fs)
data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :]
@@ -89,6 +81,8 @@
return array
def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs):
+ # import pdb;
+ # pdb.set_trace()
if isinstance(data, np.ndarray):
data = torch.from_numpy(data)
if len(data.shape) < 2:
@@ -106,7 +100,9 @@
data_list.append(data_i)
data_len.append(data_i.shape[0])
data = pad_sequence(data_list, batch_first=True) # data: [batch, N]
-
+ # import pdb;
+ # pdb.set_trace()
+ # if data_type == "sound":
data, data_len = frontend(data, data_len, **kwargs)
if isinstance(data_len, (list, tuple)):
--
Gitblit v1.9.1