old mode 100644
new mode 100755
old mode 100644
new mode 100755
| | |
| | | model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" |
| | | model_revision="v2.0.4" |
| | | |
| | | python funasr/bin/inference.py \ |
| | | python ../../../funasr/bin/inference.py \ |
| | | +model=${model} \ |
| | | +model_revision=${model_revision} \ |
| | | +input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav" \ |
| New file |
| | |
| | | python -m funasr.bin.inference \ |
| | | --config-path="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \ |
| | | --config-name="config.yaml" \ |
| | | ++init_param="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/model.pb" \ |
| | | ++tokenizer_conf.token_list="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/tokens.txt" \ |
| | | ++frontend_conf.cmvn_file="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/am.mvn" \ |
| | | ++input="/nfs/yufan.yf/workspace/model_download/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/asr_example_zh.wav" \ |
| | | ++output_dir="./outputs/debug2" \ |
| | | ++device="" \ |
| New file |
| | |
| | | export FUNASR_DIR=$PWD/../../../ |
| | | |
| | | # NOTE(kan-bayashi): Use UTF-8 in Python to avoid UnicodeDecodeError when LC_ALL=C |
| | | export PYTHONIOENCODING=UTF-8 |
| | | export PATH=$FUNASR_DIR/funasr/bin:$PATH |
| | | export PYTHONPATH=$FUNASR_DIR/funasr/bin:$FUNASR_DIR/funasr:$FUNASR_DIR:$PYTHONPATH |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- coding: utf-8 -*- |
| | | |
| | | |
| | | from enum import Enum |
| | | import re, sys, unicodedata |
| | | import codecs |
| | | import argparse |
| | | from tqdm import tqdm |
| | | import os |
| | | import pdb |
| | | remove_tag = False |
| | | spacelist = [" ", "\t", "\r", "\n"] |
| | | puncts = [ |
| | | "!", |
| | | ",", |
| | | "?", |
| | | "、", |
| | | "。", |
| | | "!", |
| | | ",", |
| | | ";", |
| | | "?", |
| | | ":", |
| | | "「", |
| | | "」", |
| | | "︰", |
| | | "『", |
| | | "』", |
| | | "《", |
| | | "》", |
| | | ] |
| | | |
| | | |
| | | class Code(Enum): |
| | | match = 1 |
| | | substitution = 2 |
| | | insertion = 3 |
| | | deletion = 4 |
| | | |
| | | |
| | | class WordError(object): |
| | | def __init__(self): |
| | | self.errors = { |
| | | Code.substitution: 0, |
| | | Code.insertion: 0, |
| | | Code.deletion: 0, |
| | | } |
| | | self.ref_words = 0 |
| | | |
| | | def get_wer(self): |
| | | assert self.ref_words != 0 |
| | | errors = ( |
| | | self.errors[Code.substitution] |
| | | + self.errors[Code.insertion] |
| | | + self.errors[Code.deletion] |
| | | ) |
| | | return 100.0 * errors / self.ref_words |
| | | |
| | | def get_result_string(self): |
| | | return ( |
| | | f"error_rate={self.get_wer():.4f}, " |
| | | f"ref_words={self.ref_words}, " |
| | | f"subs={self.errors[Code.substitution]}, " |
| | | f"ins={self.errors[Code.insertion]}, " |
| | | f"dels={self.errors[Code.deletion]}" |
| | | ) |
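| | | # A minimal worked example (editor's sketch, assuming the class above): |
| | | #   we = WordError() |
| | | #   we.ref_words = 50 |
| | | #   we.errors[Code.substitution] = 2; we.errors[Code.insertion] = 1; we.errors[Code.deletion] = 1 |
| | | #   we.get_wer()  # -> 100.0 * (2 + 1 + 1) / 50 = 8.0 |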
| | | |
| | | |
| | | def characterize(string): |
| | | res = [] |
| | | i = 0 |
| | | while i < len(string): |
| | | char = string[i] |
| | | if char in puncts: |
| | | i += 1 |
| | | continue |
| | | cat1 = unicodedata.category(char) |
| | | # https://unicodebook.readthedocs.io/unicode.html#unicode-categories |
| | | if cat1 == "Zs" or cat1 == "Cn" or char in spacelist: # space or not assigned |
| | | i += 1 |
| | | continue |
| | | if cat1 == "Lo": # letter-other |
| | | res.append(char) |
| | | i += 1 |
| | | else: |
| | | # some input looks like <unk><noise>; we want to separate it into two words. |
| | | sep = " " |
| | | if char == "<": |
| | | sep = ">" |
| | | j = i + 1 |
| | | while j < len(string): |
| | | c = string[j] |
| | | if ord(c) >= 128 or (c in spacelist) or (c == sep): |
| | | break |
| | | j += 1 |
| | | if j < len(string) and string[j] == ">": |
| | | j += 1 |
| | | res.append(string[i:j]) |
| | | i = j |
| | | return res |
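| | | # Example (editor's sketch of the behavior above): spaces and listed punctuation |
| | | # are dropped, each CJK character becomes one token, and ASCII runs / <...> tags |
| | | # are grouped into single tokens: |
| | | #   characterize("你好 <unk>ok") -> ['你', '好', '<unk>', 'ok'] |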
| | | |
| | | |
| | | def stripoff_tags(x): |
| | | if not x: |
| | | return "" |
| | | chars = [] |
| | | i = 0 |
| | | T = len(x) |
| | | while i < T: |
| | | if x[i] == "<": |
| | | while i < T and x[i] != ">": |
| | | i += 1 |
| | | i += 1 |
| | | else: |
| | | chars.append(x[i]) |
| | | i += 1 |
| | | return "".join(chars) |
| | | |
| | | |
| | | def normalize(sentence, ignore_words, cs, split=None): |
| | | """sentence, ignore_words are both in unicode""" |
| | | new_sentence = [] |
| | | for token in sentence: |
| | | x = token |
| | | if not cs: |
| | | x = x.upper() |
| | | if x in ignore_words: |
| | | continue |
| | | if remove_tag: |
| | | x = stripoff_tags(x) |
| | | if not x: |
| | | continue |
| | | if split and x in split: |
| | | new_sentence += split[x] |
| | | else: |
| | | new_sentence.append(x) |
| | | return new_sentence |
| | | |
| | | |
| | | class Calculator: |
| | | def __init__(self): |
| | | self.data = {} |
| | | self.space = [] |
| | | self.cost = {} |
| | | self.cost["cor"] = 0 |
| | | self.cost["sub"] = 1 |
| | | self.cost["del"] = 1 |
| | | self.cost["ins"] = 1 |
| | | |
| | | def calculate(self, lab, rec): |
| | | # Initialization |
| | | lab.insert(0, "") |
| | | rec.insert(0, "") |
| | | while len(self.space) < len(lab): |
| | | self.space.append([]) |
| | | for row in self.space: |
| | | for element in row: |
| | | element["dist"] = 0 |
| | | element["error"] = "non" |
| | | while len(row) < len(rec): |
| | | row.append({"dist": 0, "error": "non"}) |
| | | for i in range(len(lab)): |
| | | self.space[i][0]["dist"] = i |
| | | self.space[i][0]["error"] = "del" |
| | | for j in range(len(rec)): |
| | | self.space[0][j]["dist"] = j |
| | | self.space[0][j]["error"] = "ins" |
| | | self.space[0][0]["error"] = "non" |
| | | for token in lab: |
| | | if token not in self.data and len(token) > 0: |
| | | self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} |
| | | for token in rec: |
| | | if token not in self.data and len(token) > 0: |
| | | self.data[token] = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} |
| | | # Computing edit distance |
| | | for i, lab_token in enumerate(lab): |
| | | for j, rec_token in enumerate(rec): |
| | | if i == 0 or j == 0: |
| | | continue |
| | | min_dist = sys.maxsize |
| | | min_error = "none" |
| | | dist = self.space[i - 1][j]["dist"] + self.cost["del"] |
| | | error = "del" |
| | | if dist < min_dist: |
| | | min_dist = dist |
| | | min_error = error |
| | | dist = self.space[i][j - 1]["dist"] + self.cost["ins"] |
| | | error = "ins" |
| | | if dist < min_dist: |
| | | min_dist = dist |
| | | min_error = error |
| | | if lab_token == rec_token.replace("<BIAS>", ""): |
| | | dist = self.space[i - 1][j - 1]["dist"] + self.cost["cor"] |
| | | error = "cor" |
| | | else: |
| | | dist = self.space[i - 1][j - 1]["dist"] + self.cost["sub"] |
| | | error = "sub" |
| | | if dist < min_dist: |
| | | min_dist = dist |
| | | min_error = error |
| | | self.space[i][j]["dist"] = min_dist |
| | | self.space[i][j]["error"] = min_error |
| | | # Tracing back |
| | | result = { |
| | | "lab": [], |
| | | "rec": [], |
| | | "code": [], |
| | | "all": 0, |
| | | "cor": 0, |
| | | "sub": 0, |
| | | "ins": 0, |
| | | "del": 0, |
| | | } |
| | | i = len(lab) - 1 |
| | | j = len(rec) - 1 |
| | | while True: |
| | | if self.space[i][j]["error"] == "cor": # correct |
| | | if len(lab[i]) > 0: |
| | | self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1 |
| | | self.data[lab[i]]["cor"] = self.data[lab[i]]["cor"] + 1 |
| | | result["all"] = result["all"] + 1 |
| | | result["cor"] = result["cor"] + 1 |
| | | result["lab"].insert(0, lab[i]) |
| | | result["rec"].insert(0, rec[j]) |
| | | result["code"].insert(0, Code.match) |
| | | i = i - 1 |
| | | j = j - 1 |
| | | elif self.space[i][j]["error"] == "sub": # substitution |
| | | if len(lab[i]) > 0: |
| | | self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1 |
| | | self.data[lab[i]]["sub"] = self.data[lab[i]]["sub"] + 1 |
| | | result["all"] = result["all"] + 1 |
| | | result["sub"] = result["sub"] + 1 |
| | | result["lab"].insert(0, lab[i]) |
| | | result["rec"].insert(0, rec[j]) |
| | | result["code"].insert(0, Code.substitution) |
| | | i = i - 1 |
| | | j = j - 1 |
| | | elif self.space[i][j]["error"] == "del": # deletion |
| | | if len(lab[i]) > 0: |
| | | self.data[lab[i]]["all"] = self.data[lab[i]]["all"] + 1 |
| | | self.data[lab[i]]["del"] = self.data[lab[i]]["del"] + 1 |
| | | result["all"] = result["all"] + 1 |
| | | result["del"] = result["del"] + 1 |
| | | result["lab"].insert(0, lab[i]) |
| | | result["rec"].insert(0, "") |
| | | result["code"].insert(0, Code.deletion) |
| | | i = i - 1 |
| | | elif self.space[i][j]["error"] == "ins": # insertion |
| | | if len(rec[j]) > 0: |
| | | self.data[rec[j]]["ins"] = self.data[rec[j]]["ins"] + 1 |
| | | result["ins"] = result["ins"] + 1 |
| | | result["lab"].insert(0, "") |
| | | result["rec"].insert(0, rec[j]) |
| | | result["code"].insert(0, Code.insertion) |
| | | j = j - 1 |
| | | elif self.space[i][j]["error"] == "non": # starting point |
| | | break |
| | | else: # shouldn't reach here |
| | | print( |
| | | "this should not happen , i = {i} , j = {j} , error = {error}".format( |
| | | i=i, j=j, error=self.space[i][j]["error"] |
| | | ) |
| | | ) |
| | | return result |
| | | |
| | | def overall(self): |
| | | result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} |
| | | for token in self.data: |
| | | result["all"] = result["all"] + self.data[token]["all"] |
| | | result["cor"] = result["cor"] + self.data[token]["cor"] |
| | | result["sub"] = result["sub"] + self.data[token]["sub"] |
| | | result["ins"] = result["ins"] + self.data[token]["ins"] |
| | | result["del"] = result["del"] + self.data[token]["del"] |
| | | return result |
| | | |
| | | def cluster(self, data): |
| | | result = {"all": 0, "cor": 0, "sub": 0, "ins": 0, "del": 0} |
| | | for token in data: |
| | | if token in self.data: |
| | | result["all"] = result["all"] + self.data[token]["all"] |
| | | result["cor"] = result["cor"] + self.data[token]["cor"] |
| | | result["sub"] = result["sub"] + self.data[token]["sub"] |
| | | result["ins"] = result["ins"] + self.data[token]["ins"] |
| | | result["del"] = result["del"] + self.data[token]["del"] |
| | | return result |
| | | |
| | | def keys(self): |
| | | return list(self.data.keys()) |
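| | | # Editor's usage sketch for Calculator (hypothetical inputs). Note that |
| | | # calculate() prepends a sentinel to both lists in place, which is why |
| | | # main() below passes lab.copy() / rec.copy(): |
| | | #   calc = Calculator() |
| | | #   r = calc.calculate(list("今天天气"), list("今天气")) |
| | | #   r['all'], r['cor'], r['del']  # -> 4, 3, 1; r['code'] holds per-token Code values |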
| | | |
| | | |
| | | def width(string): |
| | | return sum(1 + (unicodedata.east_asian_width(c) in "AFW") for c in string) |
| | | |
| | | |
| | | def default_cluster(word): |
| | | unicode_names = [unicodedata.name(char) for char in word] |
| | | for i in reversed(range(len(unicode_names))): |
| | | if unicode_names[i].startswith("DIGIT"): # 1 |
| | | unicode_names[i] = "Number" # 'DIGIT' |
| | | elif unicode_names[i].startswith("CJK UNIFIED IDEOGRAPH") or unicode_names[ |
| | | i |
| | | ].startswith("CJK COMPATIBILITY IDEOGRAPH"): |
| | | # 明 / 郎 |
| | | unicode_names[i] = "Mandarin" # 'CJK IDEOGRAPH' |
| | | elif unicode_names[i].startswith("LATIN CAPITAL LETTER") or unicode_names[ |
| | | i |
| | | ].startswith("LATIN SMALL LETTER"): |
| | | # A / a |
| | | unicode_names[i] = "English" # 'LATIN LETTER' |
| | | elif unicode_names[i].startswith("HIRAGANA LETTER"): # は こ め |
| | | unicode_names[i] = "Japanese" # 'GANA LETTER' |
| | | elif ( |
| | | unicode_names[i].startswith("AMPERSAND") |
| | | or unicode_names[i].startswith("APOSTROPHE") |
| | | or unicode_names[i].startswith("COMMERCIAL AT") |
| | | or unicode_names[i].startswith("DEGREE CELSIUS") |
| | | or unicode_names[i].startswith("EQUALS SIGN") |
| | | or unicode_names[i].startswith("FULL STOP") |
| | | or unicode_names[i].startswith("HYPHEN-MINUS") |
| | | or unicode_names[i].startswith("LOW LINE") |
| | | or unicode_names[i].startswith("NUMBER SIGN") |
| | | or unicode_names[i].startswith("PLUS SIGN") |
| | | or unicode_names[i].startswith("SEMICOLON") |
| | | ): |
| | | # & / ' / @ / ℃ / = / . / - / _ / # / + / ; |
| | | del unicode_names[i] |
| | | else: |
| | | return "Other" |
| | | if len(unicode_names) == 0: |
| | | return "Other" |
| | | if len(unicode_names) == 1: |
| | | return unicode_names[0] |
| | | for i in range(len(unicode_names) - 1): |
| | | if unicode_names[i] != unicode_names[i + 1]: |
| | | return "Other" |
| | | return unicode_names[0] |
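| | | # Example (editor's sketch): default_cluster maps a token to a coarse script |
| | | # class, returning "Other" for mixed scripts: |
| | | #   default_cluster("明") -> "Mandarin" |
| | | #   default_cluster("ok") -> "English" |
| | | #   default_cluster("ok明") -> "Other" |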
| | | |
| | | |
| | | def get_args(): |
| | | parser = argparse.ArgumentParser(description="WER calculation with OCR/hotword statistics") |
| | | parser.add_argument("--ref", type=str, help="reference text path") |
| | | parser.add_argument("--ref_ocr", type=str, help="reference OCR/hotword text path") |
| | | parser.add_argument("--rec_name", type=str, action="append", default=[]) |
| | | parser.add_argument("--rec_file", type=str, action="append", default=[]) |
| | | parser.add_argument("--verbose", type=int, default=1, help="verbosity level") |
| | | # note: argparse's type=bool treats any non-empty string as True |
| | | parser.add_argument("--char", type=bool, default=True, help="tokenize into characters") |
| | | args = parser.parse_args() |
| | | return args |
| | | |
| | | |
| | | def main(args): |
| | | cluster_file = "" |
| | | ignore_words = set() |
| | | tochar = args.char |
| | | verbose = args.verbose |
| | | padding_symbol = " " |
| | | case_sensitive = False |
| | | max_words_per_line = sys.maxsize |
| | | split = None |
| | | |
| | | if not case_sensitive: |
| | | ig = set([w.upper() for w in ignore_words]) |
| | | ignore_words = ig |
| | | |
| | | default_clusters = {} |
| | | default_words = {} |
| | | ref_file = args.ref |
| | | ref_ocr = args.ref_ocr |
| | | rec_files = args.rec_file |
| | | rec_names = args.rec_name |
| | | assert len(rec_files) == len(rec_names) |
| | | |
| | | # load ocr |
| | | ref_ocr_dict = {} |
| | | with codecs.open(ref_ocr, "r", "utf-8") as fh: |
| | | for line in fh: |
| | | if "$" in line: |
| | | line = line.replace("$", " ") |
| | | if tochar: |
| | | array = characterize(line) |
| | | else: |
| | | array = line.strip().split() |
| | | if len(array) == 0: |
| | | continue |
| | | fid = array[0] |
| | | ref_ocr_dict[fid] = normalize(array[1:], ignore_words, case_sensitive, split) |
| | | |
| | | if split and not case_sensitive: |
| | | newsplit = dict() |
| | | for w in split: |
| | | words = split[w] |
| | | for i in range(len(words)): |
| | | words[i] = words[i].upper() |
| | | newsplit[w.upper()] = words |
| | | split = newsplit |
| | | |
| | | rec_sets = {} |
| | | calculators_dict = dict() |
| | | ub_wer_dict = dict() |
| | | hotwords_related_dict = dict()  # records recall-related counts |
| | | for i, hyp_file in enumerate(rec_files): |
| | | rec_sets[rec_names[i]] = dict() |
| | | with codecs.open(hyp_file, "r", "utf-8") as fh: |
| | | for line in fh: |
| | | if tochar: |
| | | array = characterize(line) |
| | | else: |
| | | array = line.strip().split() |
| | | if len(array) == 0: |
| | | continue |
| | | fid = array[0] |
| | | rec_sets[rec_names[i]][fid] = normalize(array[1:], ignore_words, case_sensitive, split) |
| | | |
| | | calculators_dict[rec_names[i]] = Calculator() |
| | | ub_wer_dict[rec_names[i]] = {"u_wer": WordError(), "b_wer": WordError(), "wer": WordError()} |
| | | hotwords_related_dict[rec_names[i]] = {'tp': 0, 'tn': 0, 'fp': 0, 'fn': 0} |
| | | # tp: hotword in the label and in the rec |
| | | # tn: hotword neither in the label nor in the rec |
| | | # fp: hotword not in the label but in the rec |
| | | # fn: hotword in the label but not in the rec |
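| | | # Editor's note: from these counters the script reports recall as |
| | | # tp / (tp + fn); precision, if wanted, would be tp / (tp + fp) |
| | | # (standard definitions; precision is not computed by this script). |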
| | | |
| | | # record wrong label but in ocr |
| | | wrong_rec_but_in_ocr_dict = {} |
| | | for rec_name in rec_names: |
| | | wrong_rec_but_in_ocr_dict[rec_name] = 0 |
| | | |
| | | _file_total_len = 0 |
| | | with os.popen("cat {} | wc -l".format(ref_file)) as pipe: |
| | | _file_total_len = int(pipe.read().strip()) |
| | | |
| | | # compute error rate on the intersection of the reference file and hyp file |
| | | for line in tqdm(open(ref_file, 'r', encoding='utf-8'), total=_file_total_len): |
| | | if tochar: |
| | | array = characterize(line) |
| | | else: |
| | | array = line.rstrip('\n').split() |
| | | if len(array) == 0: continue |
| | | fid = array[0] |
| | | lab = normalize(array[1:], ignore_words, case_sensitive, split) |
| | | |
| | | if verbose: |
| | | print('\nutt: %s' % fid) |
| | | |
| | | ocr_text = ref_ocr_dict[fid] |
| | | ocr_set = set(ocr_text) |
| | | print('ocr: {}'.format(" ".join(ocr_text))) |
| | | list_match = []  # label tokens that also appear in the ocr |
| | | list_not_match = [] |
| | | tmp_error = 0 |
| | | tmp_match = 0 |
| | | for index in range(len(lab)): |
| | | # text_list.append(uttlist[index+1]) |
| | | if lab[index] not in ocr_set: |
| | | tmp_error += 1 |
| | | list_not_match.append(lab[index]) |
| | | else: |
| | | tmp_match += 1 |
| | | list_match.append(lab[index]) |
| | | print('label in ocr: {}'.format(" ".join(list_match))) |
| | | |
| | | # for each reco file |
| | | base_wrong_ocr_wer = None |
| | | ocr_wrong_ocr_wer = None |
| | | |
| | | for rec_name in rec_names: |
| | | rec_set = rec_sets[rec_name] |
| | | if fid not in rec_set: |
| | | continue |
| | | rec = rec_set[fid] |
| | | |
| | | # print(rec) |
| | | for word in rec + lab: |
| | | if word not in default_words: |
| | | default_cluster_name = default_cluster(word) |
| | | if default_cluster_name not in default_clusters: |
| | | default_clusters[default_cluster_name] = {} |
| | | if word not in default_clusters[default_cluster_name]: |
| | | default_clusters[default_cluster_name][word] = 1 |
| | | default_words[word] = default_cluster_name |
| | | |
| | | result = calculators_dict[rec_name].calculate(lab.copy(), rec.copy()) |
| | | if verbose: |
| | | if result['all'] != 0: |
| | | wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] |
| | | else: |
| | | wer = 0.0 |
| | | print('WER(%s): %4.2f %%' % (rec_name, wer), end=' ') |
| | | print('N=%d C=%d S=%d D=%d I=%d' % |
| | | (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) |
| | | |
| | | |
| | | # print(result['rec']) |
| | | wrong_rec_but_in_ocr = [] |
| | | for idx in range(len(result['lab'])): |
| | | if result['lab'][idx] != "": |
| | | if result['lab'][idx] != result['rec'][idx].replace("<BIAS>", ""): |
| | | if result['lab'][idx] in list_match: |
| | | wrong_rec_but_in_ocr.append(result['lab'][idx]) |
| | | wrong_rec_but_in_ocr_dict[rec_name] += 1 |
| | | print('wrong_rec_but_in_ocr: {}'.format(" ".join(wrong_rec_but_in_ocr))) |
| | | |
| | | if rec_name == "base": |
| | | base_wrong_ocr_wer = len(wrong_rec_but_in_ocr) |
| | | if "ocr" in rec_name or "hot" in rec_name: |
| | | ocr_wrong_ocr_wer = len(wrong_rec_but_in_ocr) |
| | | if ocr_wrong_ocr_wer < base_wrong_ocr_wer: |
| | | print("{} {} helps, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer)) |
| | | elif ocr_wrong_ocr_wer > base_wrong_ocr_wer: |
| | | print("{} {} hurts, {} -> {}".format(fid, rec_name, base_wrong_ocr_wer, ocr_wrong_ocr_wer)) |
| | | |
| | | # recall = 0 |
| | | # false_alarm = 0 |
| | | # for idx in range(len(result['lab'])): |
| | | # if "<BIAS>" in result['rec'][idx]: |
| | | # if result['rec'][idx].replace("<BIAS>", "") in list_match: |
| | | # recall += 1 |
| | | # else: |
| | | # false_alarm += 1 |
| | | # print("bias hotwords recall: {}, fa: {}, list_match {}, recall: {:.2f}, fa: {:.2f}".format( |
| | | # recall, false_alarm, len(list_match), recall / len(list_match) if len(list_match) != 0 else 0, false_alarm / len(list_match) if len(list_match) != 0 else 0 |
| | | # )) |
| | | # tp: hotword in the label and in the rec |
| | | # tn: hotword neither in the label nor in the rec |
| | | # fp: hotword not in the label but in the rec |
| | | # fn: hotword in the label but not in the rec |
| | | _rec_list = [word.replace("<BIAS>", "") for word in rec] |
| | | _label_list = [word for word in lab] |
| | | _tp = _tn = _fp = _fn = 0 |
| | | hot_true_list = [hotword for hotword in ocr_text if hotword in _label_list] |
| | | hot_bad_list = [hotword for hotword in ocr_text if hotword not in _label_list] |
| | | for badhotword in hot_bad_list: |
| | | count = len([word for word in _rec_list if word == badhotword]) |
| | | # print(f"bad {badhotword} count: {count}") |
| | | # for word in _rec_list: |
| | | # if badhotword == word: |
| | | # count += 1 |
| | | if count == 0: |
| | | hotwords_related_dict[rec_name]['tn'] += 1 |
| | | _tn += 1 |
| | | # fp: 0 |
| | | else: |
| | | hotwords_related_dict[rec_name]['fp'] += count |
| | | _fp += count |
| | | # tn: 0 |
| | | # if badhotword in _rec_list: |
| | | # hotwords_related_dict[rec_name]['fp'] += 1 |
| | | # else: |
| | | # hotwords_related_dict[rec_name]['tn'] += 1 |
| | | for hotword in hot_true_list: |
| | | true_count = len([word for word in _label_list if hotword == word]) |
| | | rec_count = len([word for word in _rec_list if hotword == word]) |
| | | # print(f"good {hotword} true_count: {true_count}, rec_count: {rec_count}") |
| | | if rec_count == true_count: |
| | | hotwords_related_dict[rec_name]['tp'] += true_count |
| | | _tp += true_count |
| | | elif rec_count > true_count: |
| | | hotwords_related_dict[rec_name]['tp'] += true_count |
| | | # fp: not in the label but in the rec |
| | | hotwords_related_dict[rec_name]['fp'] += rec_count - true_count |
| | | _tp += true_count |
| | | _fp += rec_count - true_count |
| | | else: |
| | | hotwords_related_dict[rec_name]['tp'] += rec_count |
| | | # fn: hotword in the label but not in the rec |
| | | hotwords_related_dict[rec_name]['fn'] += true_count - rec_count |
| | | _tp += rec_count |
| | | _fn += true_count - rec_count |
| | | print("hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%".format( |
| | | _tp, _tn, _fp, _fn, sum([_tp, _tn, _fp, _fn]), _tp / (_tp + _fn) * 100 if (_tp + _fn) != 0 else 0 |
| | | )) |
| | | |
| | | # if hotword in _rec_list: |
| | | # hotwords_related_dict[rec_name]['tp'] += 1 |
| | | # else: |
| | | # hotwords_related_dict[rec_name]['fn'] += 1 |
| | | # compute U-WER, B-WER and WER |
| | | for code, rec_word, lab_word in zip(result["code"], result["rec"], result["lab"]): |
| | | if code == Code.match: |
| | | ub_wer_dict[rec_name]["wer"].ref_words += 1 |
| | | if lab_word in hot_true_list: |
| | | # tmp_ref.append(ref_tokens[ref_idx]) |
| | | ub_wer_dict[rec_name]["b_wer"].ref_words += 1 |
| | | else: |
| | | ub_wer_dict[rec_name]["u_wer"].ref_words += 1 |
| | | elif code == Code.substitution: |
| | | ub_wer_dict[rec_name]["wer"].ref_words += 1 |
| | | ub_wer_dict[rec_name]["wer"].errors[Code.substitution] += 1 |
| | | if lab_word in hot_true_list: |
| | | # tmp_ref.append(ref_tokens[ref_idx]) |
| | | ub_wer_dict[rec_name]["b_wer"].ref_words += 1 |
| | | ub_wer_dict[rec_name]["b_wer"].errors[Code.substitution] += 1 |
| | | else: |
| | | ub_wer_dict[rec_name]["u_wer"].ref_words += 1 |
| | | ub_wer_dict[rec_name]["u_wer"].errors[Code.substitution] += 1 |
| | | elif code == Code.deletion: |
| | | ub_wer_dict[rec_name]["wer"].ref_words += 1 |
| | | ub_wer_dict[rec_name]["wer"].errors[Code.deletion] += 1 |
| | | if lab_word in hot_true_list: |
| | | # tmp_ref.append(ref_tokens[ref_idx]) |
| | | ub_wer_dict[rec_name]["b_wer"].ref_words += 1 |
| | | ub_wer_dict[rec_name]["b_wer"].errors[Code.deletion] += 1 |
| | | else: |
| | | ub_wer_dict[rec_name]["u_wer"].ref_words += 1 |
| | | ub_wer_dict[rec_name]["u_wer"].errors[Code.deletion] += 1 |
| | | elif code == Code.insertion: |
| | | ub_wer_dict[rec_name]["wer"].errors[Code.insertion] += 1 |
| | | if rec_word in hot_true_list: |
| | | ub_wer_dict[rec_name]["b_wer"].errors[Code.insertion] += 1 |
| | | else: |
| | | ub_wer_dict[rec_name]["u_wer"].errors[Code.insertion] += 1 |
| | | |
| | | space = {} |
| | | space['lab'] = [] |
| | | space['rec'] = [] |
| | | for idx in range(len(result['lab'])): |
| | | len_lab = width(result['lab'][idx]) |
| | | len_rec = width(result['rec'][idx]) |
| | | length = max(len_lab, len_rec) |
| | | space['lab'].append(length - len_lab) |
| | | space['rec'].append(length - len_rec) |
| | | upper_lab = len(result['lab']) |
| | | upper_rec = len(result['rec']) |
| | | lab1, rec1 = 0, 0 |
| | | while lab1 < upper_lab or rec1 < upper_rec: |
| | | if verbose > 1: |
| | | print('lab(%s):' % fid.encode('utf-8'), end=' ') |
| | | else: |
| | | print('lab:', end=' ') |
| | | lab2 = min(upper_lab, lab1 + max_words_per_line) |
| | | for idx in range(lab1, lab2): |
| | | token = result['lab'][idx] |
| | | print('{token}'.format(token=token), end='') |
| | | for n in range(space['lab'][idx]): |
| | | print(padding_symbol, end='') |
| | | print(' ', end='') |
| | | print() |
| | | if verbose > 1: |
| | | print('rec(%s):' % fid.encode('utf-8'), end=' ') |
| | | else: |
| | | print('rec:', end=' ') |
| | | |
| | | rec2 = min(upper_rec, rec1 + max_words_per_line) |
| | | for idx in range(rec1, rec2): |
| | | token = result['rec'][idx] |
| | | print('{token}'.format(token=token), end='') |
| | | for n in range(space['rec'][idx]): |
| | | print(padding_symbol, end='') |
| | | print(' ', end='') |
| | | print() |
| | | # print('\n', end='\n') |
| | | lab1 = lab2 |
| | | rec1 = rec2 |
| | | print('\n', end='\n') |
| | | # break |
| | | if verbose: |
| | | print('===========================================================================') |
| | | print() |
| | | |
| | | print(wrong_rec_but_in_ocr_dict) |
| | | for rec_name in rec_names: |
| | | result = calculators_dict[rec_name].overall() |
| | | |
| | | if result['all'] != 0: |
| | | wer = float(result['ins'] + result['sub'] + result['del']) * 100.0 / result['all'] |
| | | else: |
| | | wer = 0.0 |
| | | print('{} Overall -> {:4.2f} %'.format(rec_name, wer), end=' ') |
| | | print('N=%d C=%d S=%d D=%d I=%d' % |
| | | (result['all'], result['cor'], result['sub'], result['del'], result['ins'])) |
| | | print(f"WER: {ub_wer_dict[rec_name]['wer'].get_result_string()}") |
| | | print(f"U-WER: {ub_wer_dict[rec_name]['u_wer'].get_result_string()}") |
| | | print(f"B-WER: {ub_wer_dict[rec_name]['b_wer'].get_result_string()}") |
| | | |
| | | print('hotword: tp: {}, tn: {}, fp: {}, fn: {}, all: {}, recall: {:.2f}%'.format( |
| | | hotwords_related_dict[rec_name]['tp'], |
| | | hotwords_related_dict[rec_name]['tn'], |
| | | hotwords_related_dict[rec_name]['fp'], |
| | | hotwords_related_dict[rec_name]['fn'], |
| | | sum([v for k, v in hotwords_related_dict[rec_name].items()]), |
| | | hotwords_related_dict[rec_name]['tp'] / ( |
| | | hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] |
| | | ) * 100 if hotwords_related_dict[rec_name]['tp'] + hotwords_related_dict[rec_name]['fn'] != 0 else 0 |
| | | )) |
| | | |
| | | # tp: hotword in the label and in the rec |
| | | # tn: hotword neither in the label nor in the rec |
| | | # fp: hotword not in the label but in the rec |
| | | # fn: hotword in the label but not in the rec |
| | | if not verbose: |
| | | print() |
| | | print() |
| | | |
| | | |
| | | if __name__ == "__main__": |
| | | args = get_args() |
| | | |
| | | # print("") |
| | | print(args) |
| | | main(args) |
| | | |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | from funasr import AutoModel |
| | | |
| | | model = AutoModel(model="iic/LCB-NET", |
| | | model_revision="v1.0.0") |
| | | |
| | | res = model.generate(input=("https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/asr_example.wav","https://www.modelscope.cn/api/v1/models/iic/LCB-NET/repo?Revision=master&FilePath=example/ocr.txt"),data_type=("sound", "text")) |
| | | |
| | | print(res) |
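| | | # The tuple input pairs one audio source with one OCR text file; data_type |
| | | # lists the modality of each element in the same order. |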
| New file |
| | |
| | | file_dir="/home/yf352572/.cache/modelscope/hub/iic/LCB-NET/" |
| | | CUDA_VISIBLE_DEVICES="0,1" |
| | | inference_device="cuda" |
| | | |
| | | if [ ${inference_device} == "cuda" ]; then |
| | | nj=$(echo $CUDA_VISIBLE_DEVICES | awk -F "," '{print NF}') |
| | | else |
| | | inference_batch_size=1 |
| | | nj=1  # number of CPU decoding jobs; nj is otherwise unset on this path |
| | | CUDA_VISIBLE_DEVICES="" |
| | | for JOB in $(seq ${nj}); do |
| | | CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"-1," |
| | | done |
| | | fi |
| | | |
| | | inference_dir="outputs/slidespeech_dev" |
| | | _logdir="${inference_dir}/logdir" |
| | | echo "inference_dir: ${inference_dir}" |
| | | |
| | | mkdir -p "${_logdir}" |
| | | key_file1=${file_dir}/dev/wav.scp |
| | | key_file2=${file_dir}/dev/ocr.txt |
| | | split_scps1= |
| | | split_scps2= |
| | | for JOB in $(seq "${nj}"); do |
| | | split_scps1+=" ${_logdir}/wav.${JOB}.scp" |
| | | split_scps2+=" ${_logdir}/ocr.${JOB}.txt" |
| | | done |
| | | utils/split_scp.pl "${key_file1}" ${split_scps1} |
| | | utils/split_scp.pl "${key_file2}" ${split_scps2} |
| | | |
| | | gpuid_list_array=(${CUDA_VISIBLE_DEVICES//,/ }) |
| | | for JOB in $(seq ${nj}); do |
| | | { |
| | | id=$((JOB-1)) |
| | | gpuid=${gpuid_list_array[$id]} |
| | | |
| | | export CUDA_VISIBLE_DEVICES=${gpuid} |
| | | |
| | | python -m funasr.bin.inference \ |
| | | --config-path=${file_dir} \ |
| | | --config-name="config.yaml" \ |
| | | ++init_param=${file_dir}/model.pt \ |
| | | ++tokenizer_conf.token_list=${file_dir}/tokens.txt \ |
| | | ++input=[${_logdir}/wav.${JOB}.scp,${_logdir}/ocr.${JOB}.txt] \ |
| | | +data_type='["kaldi_ark", "text"]' \ |
| | | ++tokenizer_conf.bpemodel=${file_dir}/bpe.pt \ |
| | | ++normalize_conf.stats_file=${file_dir}/am.mvn \ |
| | | ++output_dir="${inference_dir}/${JOB}" \ |
| | | ++device="${inference_device}" \ |
| | | ++ncpu=1 \ |
| | | ++disable_log=true &> ${_logdir}/log.${JOB}.txt |
| | | |
| | | }& |
| | | done |
| | | wait |
| | | |
| | | |
| | | mkdir -p ${inference_dir}/1best_recog |
| | | |
| | | for JOB in $(seq "${nj}"); do |
| | | cat "${inference_dir}/${JOB}/1best_recog/token" >> "${inference_dir}/1best_recog/token" |
| | | done |
| | | |
| | | echo "Computing WER ..." |
| | | sed -e 's/ /\t/' -e 's/ //g' -e 's/▁/ /g' -e 's/\t /\t/' ${inference_dir}/1best_recog/token > ${inference_dir}/1best_recog/token.proc |
| | | cp ${file_dir}/dev/text ${inference_dir}/1best_recog/token.ref |
| | | cp ${file_dir}/dev/ocr.list ${inference_dir}/1best_recog/ocr.list |
| | | python utils/compute_wer.py ${inference_dir}/1best_recog/token.ref ${inference_dir}/1best_recog/token.proc ${inference_dir}/1best_recog/token.cer |
| | | tail -n 3 ${inference_dir}/1best_recog/token.cer |
| | | |
| | | ./run_bwer_recall.sh ${inference_dir}/1best_recog/ |
| | | tail -n 6 ${inference_dir}/1best_recog/BWER-UWER.results | head -n 5 |
| New file |
| | |
| | | #now_result_name=asr_conformer_acc1_lr002_warm20000/decode_asr_asr_model_valid.acc.ave |
| | | #hotword_type=ocr_1ngram_top10_hotwords_list |
| | | hot_exp_suf=$1 |
| | | |
| | | |
| | | python compute_wer_details.py --verbose 1 \ |
| | | --ref ${hot_exp_suf}/token.ref \ |
| | | --ref_ocr ${hot_exp_suf}/ocr.list \ |
| | | --rec_name base \ |
| | | --rec_file ${hot_exp_suf}/token.proc \ |
| | | > ${hot_exp_suf}/BWER-UWER.results |
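| | | # Note: B-WER is measured over reference words that appear in the OCR/hotword |
| | | # list, U-WER over the remaining words; see the ub_wer_dict logic in |
| | | # compute_wer_details.py above. |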
| New file |
| | |
| | | ../../aishell/paraformer/utils |
| | |
| | | |
| | | model = AutoModel(model="iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch", |
| | | model_revision="v2.0.4", |
| | | vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", |
| | | vad_model_revision="v2.0.4", |
| | | punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", |
| | | punc_model_revision="v2.0.4", |
| | | # vad_model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch", |
| | | # vad_model_revision="v2.0.4", |
| | | # punc_model="damo/punc_ct-transformer_zh-cn-common-vocab272727-pytorch", |
| | | # punc_model_revision="v2.0.4", |
| | | # spk_model="damo/speech_campplus_sv_zh-cn_16k-common", |
| | | # spk_model_revision="v2.0.2", |
| | | ) |
| | |
| | | wav_file = os.path.join(model.model_path, "example/asr_example.wav") |
| | | speech, sample_rate = soundfile.read(wav_file) |
| | | res = model.generate(input=[speech], batch_size_s=300, is_final=True) |
| | | ''' |
| | | ''' |
| | |
| | | from funasr.models.campplus.cluster_backend import ClusterBackend |
| | | except: |
| | | print("If you want to use the speaker diarization, please `pip install hdbscan`") |
| | | |
| | | import pdb |
| | | |
| | | def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None): |
| | | """ |
| | |
| | | chars = string.ascii_letters + string.digits |
| | | if isinstance(data_in, str) and data_in.startswith('http'): # url |
| | | data_in = download_from_url(data_in) |
| | | |
| | | if isinstance(data_in, str) and os.path.exists(data_in): # wav_path; filelist: wav.scp, file.jsonl;text.txt; |
| | | _, file_extension = os.path.splitext(data_in) |
| | | file_extension = file_extension.lower() |
| | |
| | | kwargs = download_model(**kwargs) |
| | | |
| | | set_all_random_seed(kwargs.get("seed", 0)) |
| | | |
| | | |
| | | device = kwargs.get("device", "cuda") |
| | | if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0: |
| | | device = "cpu" |
| | |
| | | vocab_size = len(kwargs["token_list"]) if kwargs["token_list"] is not None else -1 |
| | | else: |
| | | vocab_size = -1 |
| | | |
| | | # build frontend |
| | | frontend = kwargs.get("frontend", None) |
| | | kwargs["input_size"] = None |
| | |
| | | # build model |
| | | model_class = tables.model_classes.get(kwargs["model"]) |
| | | model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=vocab_size) |
| | | |
| | | model.to(device) |
| | | |
| | | # init_param |
| | |
| | | batch_size = kwargs.get("batch_size", 1) |
| | | # if kwargs.get("device", "cpu") == "cpu": |
| | | # batch_size = 1 |
| | | |
| | | |
| | | key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None), key=key) |
| | | |
| | | |
| | | speed_stats = {} |
| | | asr_result_list = [] |
| | | num_samples = len(data_list) |
| | |
| | | data_batch = data_list[beg_idx:end_idx] |
| | | key_batch = key_list[beg_idx:end_idx] |
| | | batch = {"data_in": data_batch, "key": key_batch} |
| | | |
| | | if (end_idx - beg_idx) == 1 and kwargs.get("data_type", None) == "fbank": # fbank |
| | | batch["data_in"] = data_batch[0] |
| | | batch["data_lengths"] = input_len |
| | |
| | | from typing import Tuple |
| | | from typing import Union |
| | | import logging |
| | | import humanfriendly |
| | | import numpy as np |
| | | import torch |
| | | import torch.nn as nn |
| | |
| | | from funasr.frontends.utils.stft import Stft |
| | | from funasr.frontends.utils.frontend import Frontend |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.register import tables |
| | | |
| | | |
| | | @tables.register("frontend_classes", "DefaultFrontend") |
| | | class DefaultFrontend(nn.Module): |
| | | """Conventional frontend structure for ASR. |
| | | Stft -> WPE -> MVDR-Beamformer -> Power-spec -> Mel-Fbank -> CMVN |
| | |
| | | |
| | | def __init__( |
| | | self, |
| | | fs: Union[int, str] = 16000, |
| | | fs: int = 16000, |
| | | n_fft: int = 512, |
| | | win_length: int = None, |
| | | hop_length: int = 128, |
| | |
| | | frontend_conf: Optional[dict] = None, |
| | | apply_stft: bool = True, |
| | | use_channel: int = None, |
| | | **kwargs, |
| | | ): |
| | | super().__init__() |
| | | if isinstance(fs, str): |
| | | fs = humanfriendly.parse_size(fs) |
| | | |
| | | # Deepcopy (In general, dict shouldn't be used as default arg) |
| | | frontend_conf = copy.deepcopy(frontend_conf) |
| | | self.hop_length = hop_length |
| | | self.fs = fs |
| | | |
| | | if apply_stft: |
| | | self.stft = Stft( |
| | |
| | | return self.n_mels |
| | | |
| | | def forward( |
| | | self, input: torch.Tensor, input_lengths: torch.Tensor |
| | | self, input: torch.Tensor, input_lengths: Union[torch.Tensor, list] |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | if isinstance(input_lengths, list): |
| | | input_lengths = torch.tensor(input_lengths) |
| | | if input.dtype == torch.float64: |
| | | input = input.float() |
| | | # 1. Domain-conversion: e.g. Stft: time -> time-freq |
| | | if self.stft is not None: |
| | | input_stft, feats_lens = self._compute_stft(input, input_lengths) |
| | |
| | | |
| | | def __init__( |
| | | self, |
| | | fs: Union[int, str] = 16000, |
| | | fs: int = 16000, |
| | | n_fft: int = 512, |
| | | win_length: int = None, |
| | | hop_length: int = None, |
| | |
| | | mc: bool = True |
| | | ): |
| | | super().__init__() |
| | | if isinstance(fs, str): |
| | | fs = humanfriendly.parse_size(fs) |
| | | |
| | | # Deepcopy (In general, dict shouldn't be used as default arg) |
| | | frontend_conf = copy.deepcopy(frontend_conf) |
| | | if win_length is None and hop_length is None: |
| | |
| | | from funasr.models.transformer.utils.subsampling import Conv2dSubsamplingPad |
| | | from funasr.models.transformer.utils.subsampling import StreamingConvInput |
| | | from funasr.register import tables |
| | | |
| | | import pdb |
| | | |
| | | class ConvolutionModule(nn.Module): |
| | | """ConvolutionModule in Conformer model. |
| | |
| | | from funasr.models.transformer.utils.add_sos_eos import add_sos_eos |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | import pdb |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | |
| | | crit_attn_weight = kwargs.get("crit_attn_weight", 0.0) |
| | | crit_attn_smooth = kwargs.get("crit_attn_smooth", 0.0) |
| | | bias_encoder_dropout_rate = kwargs.get("bias_encoder_dropout_rate", 0.0) |
| | | |
| | | |
| | | if bias_encoder_type == 'lstm': |
| | | self.bias_encoder = torch.nn.LSTM(inner_dim, inner_dim, 1, batch_first=True, dropout=bias_encoder_dropout_rate) |
| | |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | |
| | | batch_size = speech.shape[0] |
| | | |
| | | hotword_pad = kwargs.get("hotword_pad") |
| | | hotword_lengths = kwargs.get("hotword_lengths") |
| | | dha_pad = kwargs.get("dha_pad") |
| | | |
| | | |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | | |
| | | loss_ctc, cer_ctc = None, None |
| | | |
| | | stats = dict() |
| | |
| | | stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None |
| | | stats["cer_ctc"] = cer_ctc |
| | | |
| | | |
| | | # 2b. Attention decoder branch |
| | | loss_att, acc_att, cer_att, wer_att, loss_pre, loss_ideal = self._calc_att_clas_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths, hotword_pad, hotword_lengths |
| | | ) |
| | | |
| | | |
| | | # 3. CTC-Att loss definition |
| | | if self.ctc_weight == 0.0: |
| | | loss = loss_att + loss_pre * self.predictor_weight |
| | |
| | | ): |
| | | encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( |
| | | encoder_out.device) |
| | | |
| | | if self.predictor_bias == 1: |
| | | _, ys_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) |
| | | ys_pad_lens = ys_pad_lens + self.predictor_bias |
| | | |
| | | pre_acoustic_embeds, pre_token_length, _, _ = self.predictor(encoder_out, ys_pad, encoder_out_mask, |
| | | ignore_id=self.ignore_id) |
| | | |
| | | # -1. bias encoder |
| | | if self.use_decoder_embedding: |
| | | hw_embed = self.decoder.embed(hotword_pad) |
| | | else: |
| | | hw_embed = self.bias_embed(hotword_pad) |
| | | |
| | | hw_embed, (_, _) = self.bias_encoder(hw_embed) |
| | | _ind = np.arange(0, hotword_pad.shape[0]).tolist() |
| | | selected = hw_embed[_ind, [i - 1 for i in hotword_lengths.detach().cpu().tolist()]] |
| | | contextual_info = selected.squeeze(0).repeat(ys_pad.shape[0], 1, 1).to(ys_pad.device) |
| | | |
| | | |
| | | # 0. sampler |
| | | decoder_out_1st = None |
| | | if self.sampling_ratio > 0.0: |
| | |
| | | pre_acoustic_embeds, contextual_info) |
| | | else: |
| | | sematic_embeds = pre_acoustic_embeds |
| | | |
| | | |
| | | # 1. Forward decoder |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=contextual_info |
| | |
| | | loss_ideal = None |
| | | ''' |
| | | loss_ideal = None |
| | | |
| | | |
| | | if decoder_out_1st is None: |
| | | decoder_out_1st = decoder_out |
| | | # 2. Compute attention loss |
| | |
| | | enforce_sorted=False) |
| | | _, (h_n, _) = self.bias_encoder(hw_embed) |
| | | hw_embed = h_n.repeat(encoder_out.shape[0], 1, 1) |
| | | |
| | | |
| | | decoder_outs = self.decoder( |
| | | encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, contextual_info=hw_embed, clas_scale=clas_scale |
| | | ) |
| | | |
| | | decoder_out = decoder_outs[0] |
| | | decoder_out = torch.log_softmax(decoder_out, dim=-1) |
| | | return decoder_out, ys_pad_lens |
| | |
| | | **kwargs, |
| | | ): |
| | | # init beamsearch |
| | | |
| | | is_use_ctc = kwargs.get("decoding_ctc_weight", 0.0) > 0.00001 and self.ctc != None |
| | | is_use_lm = kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None |
| | | if self.beam_search is None and (is_use_lm or is_use_ctc): |
| | | logging.info("enable beam_search") |
| | | self.init_beam_search(**kwargs) |
| | | self.nbest = kwargs.get("nbest", 1) |
| | | |
| | | |
| | | meta_data = {} |
| | | |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) |
| | | |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), |
| | | frontend=frontend) |
| | | time3 = time.perf_counter() |
| | | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" |
| | | meta_data[ |
| | | "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 |
| | | |
| | | |
| | | speech = speech.to(device=kwargs["device"]) |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | |
| | | # hotword |
| | | self.hotword_list = self.generate_hotwords_list(kwargs.get("hotword", None), tokenizer=tokenizer, frontend=frontend) |
| | | |
| | | |
| | | # Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | if isinstance(encoder_out, tuple): |
| | | encoder_out = encoder_out[0] |
| | | |
| | | |
| | | # predictor |
| | | predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens) |
| | | pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = predictor_outs[0], predictor_outs[1], \ |
| | |
| | | pre_token_length = pre_token_length.round().long() |
| | | if torch.max(pre_token_length) < 1: |
| | | return [] |
| | | |
| | | |
| | | |
| | | decoder_outs = self.cal_decoder_with_predictor(encoder_out, encoder_out_lens, |
| | | pre_acoustic_embeds, |
| | | pre_token_length, |
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- coding: utf-8 -*- |
| | | |
| | | # Copyright 2024 yufan |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | """Multi-Head Attention Return Weight layer definition.""" |
| | | |
| | | import math |
| | | |
| | | import torch |
| | | from torch import nn |
| | | |
| | | class MultiHeadedAttentionReturnWeight(nn.Module): |
| | | """Multi-Head Attention layer. |
| | | |
| | | Args: |
| | | n_head (int): The number of heads. |
| | | n_feat (int): The number of features. |
| | | dropout_rate (float): Dropout rate. |
| | | |
| | | """ |
| | | |
| | | def __init__(self, n_head, n_feat, dropout_rate): |
| | | """Construct an MultiHeadedAttentionReturnWeight object.""" |
| | | super(MultiHeadedAttentionReturnWeight, self).__init__() |
| | | assert n_feat % n_head == 0 |
| | | # We assume d_v always equals d_k |
| | | self.d_k = n_feat // n_head |
| | | self.h = n_head |
| | | self.linear_q = nn.Linear(n_feat, n_feat) |
| | | self.linear_k = nn.Linear(n_feat, n_feat) |
| | | self.linear_v = nn.Linear(n_feat, n_feat) |
| | | self.linear_out = nn.Linear(n_feat, n_feat) |
| | | self.attn = None |
| | | self.dropout = nn.Dropout(p=dropout_rate) |
| | | |
| | | def forward_qkv(self, query, key, value): |
| | | """Transform query, key and value. |
| | | |
| | | Args: |
| | | query (torch.Tensor): Query tensor (#batch, time1, size). |
| | | key (torch.Tensor): Key tensor (#batch, time2, size). |
| | | value (torch.Tensor): Value tensor (#batch, time2, size). |
| | | |
| | | Returns: |
| | | torch.Tensor: Transformed query tensor (#batch, n_head, time1, d_k). |
| | | torch.Tensor: Transformed key tensor (#batch, n_head, time2, d_k). |
| | | torch.Tensor: Transformed value tensor (#batch, n_head, time2, d_k). |
| | | |
| | | """ |
| | | n_batch = query.size(0) |
| | | q = self.linear_q(query).view(n_batch, -1, self.h, self.d_k) |
| | | k = self.linear_k(key).view(n_batch, -1, self.h, self.d_k) |
| | | v = self.linear_v(value).view(n_batch, -1, self.h, self.d_k) |
| | | q = q.transpose(1, 2) # (batch, head, time1, d_k) |
| | | k = k.transpose(1, 2) # (batch, head, time2, d_k) |
| | | v = v.transpose(1, 2) # (batch, head, time2, d_k) |
| | | |
| | | return q, k, v |
| | | |
| | | def forward_attention(self, value, scores, mask): |
| | | """Compute attention context vector. |
| | | |
| | | Args: |
| | | value (torch.Tensor): Transformed value (#batch, n_head, time2, d_k). |
| | | scores (torch.Tensor): Attention score (#batch, n_head, time1, time2). |
| | | mask (torch.Tensor): Mask (#batch, 1, time2) or (#batch, time1, time2). |
| | | |
| | | Returns: |
| | | torch.Tensor: Transformed value (#batch, time1, d_model) |
| | | weighted by the attention score (#batch, time1, time2). |
| | | |
| | | """ |
| | | n_batch = value.size(0) |
| | | if mask is not None: |
| | | mask = mask.unsqueeze(1).eq(0) # (batch, 1, *, time2) |
| | | min_value = torch.finfo(scores.dtype).min |
| | | scores = scores.masked_fill(mask, min_value) |
| | | self.attn = torch.softmax(scores, dim=-1).masked_fill( |
| | | mask, 0.0 |
| | | ) # (batch, head, time1, time2) |
| | | else: |
| | | self.attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2) |
| | | |
| | | p_attn = self.dropout(self.attn) |
| | | x = torch.matmul(p_attn, value) # (batch, head, time1, d_k) |
| | | x = ( |
| | | x.transpose(1, 2).contiguous().view(n_batch, -1, self.h * self.d_k) |
| | | ) # (batch, time1, d_model) |
| | | |
| | | return self.linear_out(x), self.attn  # (batch, time1, d_model), (batch, head, time1, time2) |
| | | |
| | | def forward(self, query, key, value, mask): |
| | | """Compute scaled dot product attention. |
| | | |
| | | Args: |
| | | query (torch.Tensor): Query tensor (#batch, time1, size). |
| | | key (torch.Tensor): Key tensor (#batch, time2, size). |
| | | value (torch.Tensor): Value tensor (#batch, time2, size). |
| | | mask (torch.Tensor): Mask tensor (#batch, 1, time2) or |
| | | (#batch, time1, time2). |
| | | |
| | | Returns: |
| | | torch.Tensor: Output tensor (#batch, time1, d_model). |
| | | |
| | | """ |
| | | q, k, v = self.forward_qkv(query, key, value) |
| | | scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(self.d_k) |
| | | return self.forward_attention(v, scores, mask) |
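| | | # Editor's usage sketch (hypothetical shapes, assuming the class above): |
| | | #   mha = MultiHeadedAttentionReturnWeight(n_head=4, n_feat=256, dropout_rate=0.0) |
| | | #   q = torch.randn(2, 10, 256)        # (batch, time1, n_feat) |
| | | #   m = torch.randn(2, 20, 256)        # (batch, time2, n_feat) |
| | | #   out, w = mha(q, m, m, mask=None)   # out: (2, 10, 256), w: (2, 4, 10, 20) |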
| | | |
| | | |
| New file |
| | |
| | | # Copyright 2019 Shigeki Karita |
| | | # Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0) |
| | | |
| | | """Transformer encoder definition.""" |
| | | |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | |
| | | import torch |
| | | from torch import nn |
| | | import logging |
| | | |
| | | from funasr.models.transformer.attention import MultiHeadedAttention |
| | | from funasr.models.lcbnet.attention import MultiHeadedAttentionReturnWeight |
| | | from funasr.models.transformer.embedding import PositionalEncoding |
| | | from funasr.models.transformer.layer_norm import LayerNorm |
| | | |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask |
| | | from funasr.models.transformer.positionwise_feed_forward import PositionwiseFeedForward |
| | | from funasr.models.transformer.utils.repeat import repeat |
| | | from funasr.register import tables |
| | | |
| | | class EncoderLayer(nn.Module): |
| | | """Encoder layer module. |
| | | |
| | | Args: |
| | | size (int): Input dimension. |
| | | self_attn (torch.nn.Module): Self-attention module instance. |
| | | `MultiHeadedAttention` or `RelPositionMultiHeadedAttention` instance |
| | | can be used as the argument. |
| | | feed_forward (torch.nn.Module): Feed-forward module instance. |
| | | `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance |
| | | can be used as the argument. |
| | | dropout_rate (float): Dropout rate. |
| | | normalize_before (bool): Whether to use layer_norm before the first block. |
| | | concat_after (bool): Whether to concat attention layer's input and output. |
| | | if True, additional linear will be applied. |
| | | i.e. x -> x + linear(concat(x, att(x))) |
| | | if False, no additional linear will be applied. i.e. x -> x + att(x) |
| | | stochastic_depth_rate (float): Probability to skip this layer. |
| | | During training, the layer may skip residual computation and return input |
| | | as-is with given probability. |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | size, |
| | | self_attn, |
| | | feed_forward, |
| | | dropout_rate, |
| | | normalize_before=True, |
| | | concat_after=False, |
| | | stochastic_depth_rate=0.0, |
| | | ): |
| | | """Construct an EncoderLayer object.""" |
| | | super(EncoderLayer, self).__init__() |
| | | self.self_attn = self_attn |
| | | self.feed_forward = feed_forward |
| | | self.norm1 = LayerNorm(size) |
| | | self.norm2 = LayerNorm(size) |
| | | self.dropout = nn.Dropout(dropout_rate) |
| | | self.size = size |
| | | self.normalize_before = normalize_before |
| | | self.concat_after = concat_after |
| | | if self.concat_after: |
| | | self.concat_linear = nn.Linear(size + size, size) |
| | | self.stochastic_depth_rate = stochastic_depth_rate |
| | | |
| | | def forward(self, x, mask, cache=None): |
| | | """Compute encoded features. |
| | | |
| | | Args: |
| | | x (torch.Tensor): Input tensor (#batch, time, size). |
| | | mask (torch.Tensor): Mask tensor for the input (#batch, time). |
| | | cache (torch.Tensor): Cache tensor of the input (#batch, time - 1, size). |
| | | |
| | | Returns: |
| | | torch.Tensor: Output tensor (#batch, time, size). |
| | | torch.Tensor: Mask tensor (#batch, time). |
| | | |
| | | """ |
| | | skip_layer = False |
| | | # with stochastic depth, residual connection `x + f(x)` becomes |
| | | # `x <- x + 1 / (1 - p) * f(x)` at training time. |
| | | stoch_layer_coeff = 1.0 |
| | | if self.training and self.stochastic_depth_rate > 0: |
| | | skip_layer = torch.rand(1).item() < self.stochastic_depth_rate |
| | | stoch_layer_coeff = 1.0 / (1 - self.stochastic_depth_rate) |
| | | |
| | | if skip_layer: |
| | | if cache is not None: |
| | | x = torch.cat([cache, x], dim=1) |
| | | return x, mask |
| | | |
| | | residual = x |
| | | if self.normalize_before: |
| | | x = self.norm1(x) |
| | | |
| | | if cache is None: |
| | | x_q = x |
| | | else: |
| | | assert cache.shape == (x.shape[0], x.shape[1] - 1, self.size) |
| | | x_q = x[:, -1:, :] |
| | | residual = residual[:, -1:, :] |
| | | mask = None if mask is None else mask[:, -1:, :] |
| | | |
| | | if self.concat_after: |
| | | x_concat = torch.cat((x, self.self_attn(x_q, x, x, mask)), dim=-1) |
| | | x = residual + stoch_layer_coeff * self.concat_linear(x_concat) |
| | | else: |
| | | x = residual + stoch_layer_coeff * self.dropout( |
| | | self.self_attn(x_q, x, x, mask) |
| | | ) |
| | | if not self.normalize_before: |
| | | x = self.norm1(x) |
| | | |
| | | residual = x |
| | | if self.normalize_before: |
| | | x = self.norm2(x) |
| | | x = residual + stoch_layer_coeff * self.dropout(self.feed_forward(x)) |
| | | if not self.normalize_before: |
| | | x = self.norm2(x) |
| | | |
| | | if cache is not None: |
| | | x = torch.cat([cache, x], dim=1) |
| | | |
| | | return x, mask |
| | | |
| | | @tables.register("encoder_classes", "TransformerTextEncoder") |
| | | class TransformerTextEncoder(nn.Module): |
| | | """Transformer text encoder module. |
| | | |
| | | Args: |
| | | input_size: input dim |
| | | output_size: dimension of attention |
| | | attention_heads: the number of heads of multi head attention |
| | | linear_units: the number of units of position-wise feed forward |
| | | num_blocks: the number of decoder blocks |
| | | dropout_rate: dropout rate |
| | | attention_dropout_rate: dropout rate in attention |
| | | positional_dropout_rate: dropout rate after adding positional encoding |
| | | input_layer: input layer type |
| | | pos_enc_class: PositionalEncoding or ScaledPositionalEncoding |
| | | normalize_before: whether to use layer_norm before the first block |
| | | concat_after: whether to concat attention layer's input and output |
| | | if True, additional linear will be applied. |
| | | i.e. x -> x + linear(concat(x, att(x))) |
| | | if False, no additional linear will be applied. |
| | | i.e. x -> x + att(x) |
| | | positionwise_layer_type: linear or conv1d |
| | | positionwise_conv_kernel_size: kernel size of positionwise conv1d layer |
| | | padding_idx: padding_idx for input_layer=embed |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | input_size: int, |
| | | output_size: int = 256, |
| | | attention_heads: int = 4, |
| | | linear_units: int = 2048, |
| | | num_blocks: int = 6, |
| | | dropout_rate: float = 0.1, |
| | | positional_dropout_rate: float = 0.1, |
| | | attention_dropout_rate: float = 0.0, |
| | | pos_enc_class=PositionalEncoding, |
| | | normalize_before: bool = True, |
| | | concat_after: bool = False, |
| | | ): |
| | | super().__init__() |
| | | self._output_size = output_size |
| | | |
| | | self.embed = torch.nn.Sequential( |
| | | torch.nn.Embedding(input_size, output_size), |
| | | pos_enc_class(output_size, positional_dropout_rate), |
| | | ) |
| | | |
| | | self.normalize_before = normalize_before |
| | | |
| | | positionwise_layer = PositionwiseFeedForward |
| | | positionwise_layer_args = ( |
| | | output_size, |
| | | linear_units, |
| | | dropout_rate, |
| | | ) |
| | | self.encoders = repeat( |
| | | num_blocks, |
| | | lambda lnum: EncoderLayer( |
| | | output_size, |
| | | MultiHeadedAttention( |
| | | attention_heads, output_size, attention_dropout_rate |
| | | ), |
| | | positionwise_layer(*positionwise_layer_args), |
| | | dropout_rate, |
| | | normalize_before, |
| | | concat_after, |
| | | ), |
| | | ) |
| | | if self.normalize_before: |
| | | self.after_norm = LayerNorm(output_size) |
| | | |
| | | def output_size(self) -> int: |
| | | return self._output_size |
| | | |
| | | def forward( |
| | | self, |
| | | xs_pad: torch.Tensor, |
| | | ilens: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: |
| | | """Embed positions in tensor. |
| | | |
| | | Args: |
| | | xs_pad: input tensor (B, L, D) |
| | | ilens: input length (B) |
| | | Returns: |
| | | position embedded tensor and mask |
| | | """ |
| | | masks = (~make_pad_mask(ilens)[:, None, :]).to(xs_pad.device) |
| | | xs_pad = self.embed(xs_pad) |
| | | |
| | | xs_pad, masks = self.encoders(xs_pad, masks) |
| | | |
| | | if self.normalize_before: |
| | | xs_pad = self.after_norm(xs_pad) |
| | | |
| | | olens = masks.squeeze(1).sum(1) |
| | | return xs_pad, olens, None |
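| | | # Editor's usage sketch (hypothetical shapes): |
| | | #   enc = TransformerTextEncoder(input_size=5000, output_size=256) |
| | | #   xs = torch.randint(0, 5000, (2, 30))   # (B, L) token ids |
| | | #   out, olens, _ = enc(xs, torch.tensor([30, 25])) |
| | | #   out.shape -> torch.Size([2, 30, 256]); olens -> tensor([30, 25]) |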
| | | |
| | | |
| | | |
| | | |
| | | @tables.register("encoder_classes", "FusionSANEncoder") |
| | | class SelfSrcAttention(nn.Module): |
| | | """Single decoder layer module. |
| | | |
| | | Args: |
| | | size (int): Input dimension. |
| | | self_attn (torch.nn.Module): Self-attention module instance. |
| | | `MultiHeadedAttention` instance can be used as the argument. |
| | | src_attn (torch.nn.Module): Source-attention module instance. |
| | | `MultiHeadedAttention` instance can be used as the argument. |
| | | feed_forward (torch.nn.Module): Feed-forward module instance. |
| | | `PositionwiseFeedForward`, `MultiLayeredConv1d`, or `Conv1dLinear` instance |
| | | can be used as the argument. |
| | | dropout_rate (float): Dropout rate. |
| | | normalize_before (bool): Whether to use layer_norm before the first block. |
| | | concat_after (bool): Whether to concat attention layer's input and output. |
| | | if True, additional linear will be applied. |
| | | i.e. x -> x + linear(concat(x, att(x))) |
| | | if False, no additional linear will be applied. i.e. x -> x + att(x) |
| | | |
| | | |
| | | """ |
| | | def __init__( |
| | | self, |
| | | size, |
| | | attention_heads, |
| | | attention_dim, |
| | | linear_units, |
| | | self_attention_dropout_rate, |
| | | src_attention_dropout_rate, |
| | | positional_dropout_rate, |
| | | dropout_rate, |
| | | normalize_before=True, |
| | | concat_after=False, |
| | | ): |
| | | """Construct an SelfSrcAttention object.""" |
| | | super(SelfSrcAttention, self).__init__() |
| | | self.size = size |
| | | self.self_attn = MultiHeadedAttention(attention_heads, attention_dim, self_attention_dropout_rate) |
| | | self.src_attn = MultiHeadedAttentionReturnWeight(attention_heads, attention_dim, src_attention_dropout_rate) |
| | | self.feed_forward = PositionwiseFeedForward(attention_dim, linear_units, positional_dropout_rate) |
| | | self.norm1 = LayerNorm(size) |
| | | self.norm2 = LayerNorm(size) |
| | | self.norm3 = LayerNorm(size) |
| | | self.dropout = nn.Dropout(dropout_rate) |
| | | self.normalize_before = normalize_before |
| | | self.concat_after = concat_after |
| | | if self.concat_after: |
| | | self.concat_linear1 = nn.Linear(size + size, size) |
| | | self.concat_linear2 = nn.Linear(size + size, size) |
| | | |
| | | def forward(self, tgt, tgt_mask, memory, memory_mask, cache=None): |
| | | """Compute decoded features. |
| | | |
| | | Args: |
| | | tgt (torch.Tensor): Input tensor (#batch, maxlen_out, size). |
| | | tgt_mask (torch.Tensor): Mask for input tensor (#batch, maxlen_out). |
| | | memory (torch.Tensor): Encoded memory, float32 (#batch, maxlen_in, size). |
| | | memory_mask (torch.Tensor): Encoded memory mask (#batch, maxlen_in). |
| | | cache (List[torch.Tensor]): List of cached tensors. |
| | | Each tensor shape should be (#batch, maxlen_out - 1, size). |
| | | |
| | | Returns: |
| | | torch.Tensor: Output tensor(#batch, maxlen_out, size). |
| | | torch.Tensor: Mask for output tensor (#batch, maxlen_out). |
| | | torch.Tensor: Encoded memory (#batch, maxlen_in, size). |
| | | torch.Tensor: Encoded memory mask (#batch, maxlen_in). |
| | | |
| | | """ |
| | | residual = tgt |
| | | if self.normalize_before: |
| | | tgt = self.norm1(tgt) |
| | | |
| | | if cache is None: |
| | | tgt_q = tgt |
| | | tgt_q_mask = tgt_mask |
| | | else: |
| | | # compute only the last frame query keeping dim: max_time_out -> 1 |
| | | assert cache.shape == ( |
| | | tgt.shape[0], |
| | | tgt.shape[1] - 1, |
| | | self.size, |
| | | ), f"{cache.shape} == {(tgt.shape[0], tgt.shape[1] - 1, self.size)}" |
| | | tgt_q = tgt[:, -1:, :] |
| | | residual = residual[:, -1:, :] |
| | | tgt_q_mask = None |
| | | if tgt_mask is not None: |
| | | tgt_q_mask = tgt_mask[:, -1:, :] |
| | | |
| | | if self.concat_after: |
| | | tgt_concat = torch.cat( |
| | | (tgt_q, self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)), dim=-1 |
| | | ) |
| | | x = residual + self.concat_linear1(tgt_concat) |
| | | else: |
| | | x = residual + self.dropout(self.self_attn(tgt_q, tgt, tgt, tgt_q_mask)) |
| | | if not self.normalize_before: |
| | | x = self.norm1(x) |
| | | |
| | | residual = x |
| | | if self.normalize_before: |
| | | x = self.norm2(x) |
| | | if self.concat_after: |
| | | x_concat = torch.cat( |
| | | (x, self.src_attn(x, memory, memory, memory_mask)), dim=-1 |
| | | ) |
| | | x = residual + self.concat_linear2(x_concat) |
| | | else: |
| | | x, score = self.src_attn(x, memory, memory, memory_mask) |
| | | x = residual + self.dropout(x) |
| | | if not self.normalize_before: |
| | | x = self.norm2(x) |
| | | |
| | | residual = x |
| | | if self.normalize_before: |
| | | x = self.norm3(x) |
| | | x = residual + self.dropout(self.feed_forward(x)) |
| | | if not self.normalize_before: |
| | | x = self.norm3(x) |
| | | |
| | | if cache is not None: |
| | | x = torch.cat([cache, x], dim=1) |
| | | |
| | | return x, tgt_mask, memory, memory_mask |
| | | |
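| | | # A hedged sketch of driving this fusion layer (shapes illustrative):
| | | # audio features (tgt) attend over the encoded bias text (memory) via
| | | # src_attn, mirroring the call in LCBNet.inference below.
| | | #
| | | #   fusion = SelfSrcAttention(size=256, attention_heads=4, attention_dim=256,
| | | #                             linear_units=2048, self_attention_dropout_rate=0.0,
| | | #                             src_attention_dropout_rate=0.0,
| | | #                             positional_dropout_rate=0.1, dropout_rate=0.1)
| | | #   x, _, _, _ = fusion(audio_enc, None, text_enc, None)  # x: (B, T, 256)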
| | | |
| | | @tables.register("encoder_classes", "ConvBiasPredictor") |
| | | class ConvPredictor(nn.Module): |
| | | def __init__(self, size=256, l_order=3, r_order=3, attention_heads=4, attention_dropout_rate=0.1, linear_units=2048): |
| | | super().__init__() |
| | | self.atten = MultiHeadedAttention(attention_heads, size, attention_dropout_rate) |
| | | self.norm1 = LayerNorm(size) |
| | | self.feed_forward = PositionwiseFeedForward(size, linear_units, attention_dropout_rate) |
| | | self.norm2 = LayerNorm(size) |
| | | self.pad = nn.ConstantPad1d((l_order, r_order), 0) |
| | | self.conv1d = nn.Conv1d(size, size, l_order + r_order + 1, groups=size) |
| | | self.output_linear = nn.Linear(size, 1) |
| | | |
| | | |
| | | def forward(self, text_enc, asr_enc): |
| | | # stage1 cross-attention |
| | | residual = text_enc |
| | | text_enc = residual + self.atten(text_enc, asr_enc, asr_enc, None) |
| | | |
| | | # stage2 FFN |
| | | residual = text_enc |
| | | text_enc = self.norm1(text_enc) |
| | | text_enc = residual + self.feed_forward(text_enc) |
| | | |
| | | # stage Conv predictor |
| | | text_enc = self.norm2(text_enc) |
| | | context = text_enc.transpose(1, 2) |
| | | queries = self.pad(context) |
| | | memory = self.conv1d(queries) |
| | | output = memory + context |
| | | output = output.transpose(1, 2) |
| | | output = torch.relu(output) |
| | | output = self.output_linear(output) |
| | | if output.dim() == 3:
| | | output = output.squeeze(2) |
| | | return output |
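| | | 
| | | # A hedged sketch of the bias predictor (hypothetical shapes): cross-attend
| | | # the encoded bias text over the ASR features, refine with an FFN and a
| | | # depthwise conv, and emit one logit per bias-text position.
| | | #
| | | #   pred = ConvPredictor(size=256)
| | | #   scores = pred(text_enc, asr_enc)   # text_enc: (B, L, 256) -> scores: (B, L)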
| New file |
| | |
| | | #!/usr/bin/env python3 |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import logging |
| | | from typing import Union, Dict, List, Tuple, Optional |
| | | |
| | | import time |
| | | import torch |
| | | import torch.nn as nn |
| | | from torch.cuda.amp import autocast |
| | | |
| | | from funasr.losses.label_smoothing_loss import LabelSmoothingLoss |
| | | from funasr.models.ctc.ctc import CTC |
| | | from funasr.models.transformer.utils.add_sos_eos import add_sos_eos |
| | | from funasr.metrics.compute_acc import th_accuracy |
| | | # from funasr.models.e2e_asr_common import ErrorCalculator |
| | | from funasr.train_utils.device_funcs import force_gatherable |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | from funasr.utils import postprocess_utils |
| | | from funasr.utils.datadir_writer import DatadirWriter |
| | | from funasr.register import tables |
| | | |
| | | @tables.register("model_classes", "LCBNet") |
| | | class LCBNet(nn.Module): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | LCB-NET: LONG-CONTEXT BIASING FOR AUDIO-VISUAL SPEECH RECOGNITION |
| | | https://arxiv.org/abs/2401.06390 |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | specaug: str = None, |
| | | specaug_conf: dict = None, |
| | | normalize: str = None, |
| | | normalize_conf: dict = None, |
| | | encoder: str = None, |
| | | encoder_conf: dict = None, |
| | | decoder: str = None, |
| | | decoder_conf: dict = None, |
| | | text_encoder: str = None, |
| | | text_encoder_conf: dict = None, |
| | | bias_predictor: str = None, |
| | | bias_predictor_conf: dict = None, |
| | | fusion_encoder: str = None, |
| | | fusion_encoder_conf: dict = None, |
| | | ctc: str = None, |
| | | ctc_conf: dict = None, |
| | | ctc_weight: float = 0.5, |
| | | interctc_weight: float = 0.0, |
| | | select_num: int = 2, |
| | | select_length: int = 3, |
| | | insert_blank: bool = True, |
| | | input_size: int = 80, |
| | | vocab_size: int = -1, |
| | | ignore_id: int = -1, |
| | | blank_id: int = 0, |
| | | sos: int = 1, |
| | | eos: int = 2, |
| | | lsm_weight: float = 0.0, |
| | | length_normalized_loss: bool = False, |
| | | report_cer: bool = True, |
| | | report_wer: bool = True, |
| | | sym_space: str = "<space>", |
| | | sym_blank: str = "<blank>", |
| | | # extract_feats_in_collect_stats: bool = True, |
| | | share_embedding: bool = False, |
| | | # preencoder: Optional[AbsPreEncoder] = None, |
| | | # postencoder: Optional[AbsPostEncoder] = None, |
| | | **kwargs, |
| | | ): |
| | | |
| | | super().__init__() |
| | | |
| | | if specaug is not None: |
| | | specaug_class = tables.specaug_classes.get(specaug) |
| | | specaug = specaug_class(**specaug_conf) |
| | | if normalize is not None: |
| | | normalize_class = tables.normalize_classes.get(normalize) |
| | | normalize = normalize_class(**normalize_conf) |
| | | encoder_class = tables.encoder_classes.get(encoder) |
| | | encoder = encoder_class(input_size=input_size, **encoder_conf) |
| | | encoder_output_size = encoder.output_size() |
| | | |
| | | # lcbnet modules: text encoder, fusion encoder and bias predictor |
| | | text_encoder_class = tables.encoder_classes.get(text_encoder) |
| | | text_encoder = text_encoder_class(input_size=vocab_size, **text_encoder_conf) |
| | | fusion_encoder_class = tables.encoder_classes.get(fusion_encoder) |
| | | fusion_encoder = fusion_encoder_class(**fusion_encoder_conf) |
| | | bias_predictor_class = tables.encoder_classes.get(bias_predictor) |
| | | bias_predictor = bias_predictor_class(**bias_predictor_conf) |
| | | |
| | | |
| | | if decoder is not None: |
| | | decoder_class = tables.decoder_classes.get(decoder) |
| | | decoder = decoder_class( |
| | | vocab_size=vocab_size, |
| | | encoder_output_size=encoder_output_size, |
| | | **decoder_conf, |
| | | ) |
| | | if ctc_weight > 0.0: |
| | | |
| | | if ctc_conf is None: |
| | | ctc_conf = {} |
| | | |
| | | ctc = CTC( |
| | | odim=vocab_size, encoder_output_size=encoder_output_size, **ctc_conf |
| | | ) |
| | | |
| | | self.blank_id = blank_id |
| | | self.sos = vocab_size - 1 |
| | | self.eos = vocab_size - 1 |
| | | self.vocab_size = vocab_size |
| | | self.ignore_id = ignore_id |
| | | self.ctc_weight = ctc_weight |
| | | self.specaug = specaug |
| | | self.normalize = normalize |
| | | self.encoder = encoder |
| | | # lcbnet |
| | | self.text_encoder = text_encoder |
| | | self.fusion_encoder = fusion_encoder |
| | | self.bias_predictor = bias_predictor |
| | | self.select_num = select_num |
| | | self.select_length = select_length |
| | | self.insert_blank = insert_blank |
| | | |
| | | if not hasattr(self.encoder, "interctc_use_conditioning"): |
| | | self.encoder.interctc_use_conditioning = False |
| | | if self.encoder.interctc_use_conditioning: |
| | | self.encoder.conditioning_layer = torch.nn.Linear( |
| | | vocab_size, self.encoder.output_size() |
| | | ) |
| | | self.interctc_weight = interctc_weight |
| | | |
| | | # self.error_calculator = None |
| | | if ctc_weight == 1.0: |
| | | self.decoder = None |
| | | else: |
| | | self.decoder = decoder |
| | | |
| | | self.criterion_att = LabelSmoothingLoss( |
| | | size=vocab_size, |
| | | padding_idx=ignore_id, |
| | | smoothing=lsm_weight, |
| | | normalize_length=length_normalized_loss, |
| | | ) |
| | | # |
| | | # if report_cer or report_wer: |
| | | # self.error_calculator = ErrorCalculator( |
| | | # token_list, sym_space, sym_blank, report_cer, report_wer |
| | | # ) |
| | | # |
| | | self.error_calculator = None |
| | | if ctc_weight == 0.0: |
| | | self.ctc = None |
| | | else: |
| | | self.ctc = ctc |
| | | |
| | | self.share_embedding = share_embedding |
| | | if self.share_embedding: |
| | | self.decoder.embed = None |
| | | |
| | | self.length_normalized_loss = length_normalized_loss |
| | | self.beam_search = None |
| | | |
| | | def forward( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | **kwargs, |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | """Encoder + Decoder + Calc loss |
| | | Args: |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | text: (Batch, Length) |
| | | text_lengths: (Batch,) |
| | | """ |
| | | |
| | | if len(text_lengths.size()) > 1: |
| | | text_lengths = text_lengths[:, 0] |
| | | if len(speech_lengths.size()) > 1: |
| | | speech_lengths = speech_lengths[:, 0] |
| | | |
| | | batch_size = speech.shape[0] |
| | | |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | | encoder_out = encoder_out[0] |
| | | |
| | | loss_att, acc_att, cer_att, wer_att = None, None, None, None |
| | | loss_ctc, cer_ctc = None, None |
| | | stats = dict() |
| | | |
| | | # decoder: CTC branch |
| | | if self.ctc_weight != 0.0: |
| | | loss_ctc, cer_ctc = self._calc_ctc_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | |
| | | # Collect CTC branch stats |
| | | stats["loss_ctc"] = loss_ctc.detach() if loss_ctc is not None else None |
| | | stats["cer_ctc"] = cer_ctc |
| | | |
| | | # Intermediate CTC (optional) |
| | | loss_interctc = 0.0 |
| | | if self.interctc_weight != 0.0 and intermediate_outs is not None: |
| | | for layer_idx, intermediate_out in intermediate_outs: |
| | | # we assume intermediate_out has the same length & padding |
| | | # as those of encoder_out |
| | | loss_ic, cer_ic = self._calc_ctc_loss( |
| | | intermediate_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | loss_interctc = loss_interctc + loss_ic |
| | | |
| | | # Collect intermediate CTC stats
| | | stats["loss_interctc_layer{}".format(layer_idx)] = ( |
| | | loss_ic.detach() if loss_ic is not None else None |
| | | ) |
| | | stats["cer_interctc_layer{}".format(layer_idx)] = cer_ic |
| | | |
| | | loss_interctc = loss_interctc / len(intermediate_outs) |
| | | |
| | | # calculate whole encoder loss |
| | | loss_ctc = ( |
| | | 1 - self.interctc_weight |
| | | ) * loss_ctc + self.interctc_weight * loss_interctc |
| | | |
| | | # decoder: Attention decoder branch |
| | | loss_att, acc_att, cer_att, wer_att = self._calc_att_loss( |
| | | encoder_out, encoder_out_lens, text, text_lengths |
| | | ) |
| | | |
| | | # 3. CTC-Att loss definition |
| | | if self.ctc_weight == 0.0: |
| | | loss = loss_att |
| | | elif self.ctc_weight == 1.0: |
| | | loss = loss_ctc |
| | | else: |
| | | loss = self.ctc_weight * loss_ctc + (1 - self.ctc_weight) * loss_att |
| | | |
| | | # Collect Attn branch stats |
| | | stats["loss_att"] = loss_att.detach() if loss_att is not None else None |
| | | stats["acc"] = acc_att |
| | | stats["cer"] = cer_att |
| | | stats["wer"] = wer_att |
| | | |
| | | # Collect total loss stats |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | if self.length_normalized_loss: |
| | | batch_size = int((text_lengths + 1).sum()) |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
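| | | # Loss recap (a restatement of the combination above, values illustrative):
| | | #   loss_ctc' = (1 - interctc_weight) * loss_ctc + interctc_weight * mean(loss_ic)
| | | #   loss      = ctc_weight * loss_ctc' + (1 - ctc_weight) * loss_att
| | | # e.g. ctc_weight=0.5 weighs the CTC and attention branches equally.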
| | | |
| | | def encode( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor, **kwargs, |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Frontend + Encoder. Note that this method is used by asr_inference.py |
| | | Args: |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | """ |
| | | with autocast(False): |
| | | # Data augmentation |
| | | if self.specaug is not None and self.training: |
| | | speech, speech_lengths = self.specaug(speech, speech_lengths) |
| | | # Normalization for feature: e.g. Global-CMVN, Utterance-CMVN |
| | | if self.normalize is not None: |
| | | speech, speech_lengths = self.normalize(speech, speech_lengths) |
| | | # Forward encoder |
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim2) |
| | | if self.encoder.interctc_use_conditioning: |
| | | encoder_out, encoder_out_lens, _ = self.encoder( |
| | | speech, speech_lengths, ctc=self.ctc |
| | | ) |
| | | else: |
| | | encoder_out, encoder_out_lens, _ = self.encoder(speech, speech_lengths) |
| | | intermediate_outs = None |
| | | if isinstance(encoder_out, tuple): |
| | | intermediate_outs = encoder_out[1] |
| | | encoder_out = encoder_out[0] |
| | | |
| | | if intermediate_outs is not None: |
| | | return (encoder_out, intermediate_outs), encoder_out_lens |
| | | return encoder_out, encoder_out_lens |
| | | |
| | | def _calc_att_loss( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | ): |
| | | ys_in_pad, ys_out_pad = add_sos_eos(ys_pad, self.sos, self.eos, self.ignore_id) |
| | | ys_in_lens = ys_pad_lens + 1 |
| | | |
| | | # 1. Forward decoder |
| | | decoder_out, _ = self.decoder( |
| | | encoder_out, encoder_out_lens, ys_in_pad, ys_in_lens |
| | | ) |
| | | |
| | | # 2. Compute attention loss |
| | | loss_att = self.criterion_att(decoder_out, ys_out_pad) |
| | | acc_att = th_accuracy( |
| | | decoder_out.view(-1, self.vocab_size), |
| | | ys_out_pad, |
| | | ignore_label=self.ignore_id, |
| | | ) |
| | | |
| | | # Compute cer/wer using attention-decoder |
| | | if self.training or self.error_calculator is None: |
| | | cer_att, wer_att = None, None |
| | | else: |
| | | ys_hat = decoder_out.argmax(dim=-1) |
| | | cer_att, wer_att = self.error_calculator(ys_hat.cpu(), ys_pad.cpu()) |
| | | |
| | | return loss_att, acc_att, cer_att, wer_att |
| | | |
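| | | # add_sos_eos recap (illustrative): for a target sequence [A, B, C] with
| | | # sos = eos = vocab_size - 1 and ignore_id used as padding,
| | | #   ys_in_pad  = [sos, A, B, C]   # decoder input (teacher forcing)
| | | #   ys_out_pad = [A, B, C, eos]   # prediction target
| | | # hence ys_in_lens = ys_pad_lens + 1 above.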
| | | def _calc_ctc_loss( |
| | | self, |
| | | encoder_out: torch.Tensor, |
| | | encoder_out_lens: torch.Tensor, |
| | | ys_pad: torch.Tensor, |
| | | ys_pad_lens: torch.Tensor, |
| | | ): |
| | | # Calc CTC loss |
| | | loss_ctc = self.ctc(encoder_out, encoder_out_lens, ys_pad, ys_pad_lens) |
| | | |
| | | # Calc CER using CTC |
| | | cer_ctc = None |
| | | if not self.training and self.error_calculator is not None: |
| | | ys_hat = self.ctc.argmax(encoder_out).data |
| | | cer_ctc = self.error_calculator(ys_hat.cpu(), ys_pad.cpu(), is_ctc=True) |
| | | return loss_ctc, cer_ctc |
| | | |
| | | def init_beam_search(self, |
| | | **kwargs, |
| | | ): |
| | | from funasr.models.transformer.search import BeamSearch |
| | | from funasr.models.transformer.scorers.ctc import CTCPrefixScorer |
| | | from funasr.models.transformer.scorers.length_bonus import LengthBonus |
| | | |
| | | # 1. Build ASR model |
| | | scorers = {} |
| | | |
| | | if self.ctc is not None:
| | | ctc = CTCPrefixScorer(ctc=self.ctc, eos=self.eos) |
| | | scorers.update( |
| | | ctc=ctc |
| | | ) |
| | | token_list = kwargs.get("token_list") |
| | | scorers.update( |
| | | decoder=self.decoder, |
| | | length_bonus=LengthBonus(len(token_list)), |
| | | ) |
| | | |
| | | |
| | | # 3. Build ngram model |
| | | # ngram is not supported now |
| | | ngram = None |
| | | scorers["ngram"] = ngram |
| | | |
| | | weights = dict( |
| | | decoder=1.0 - kwargs.get("decoding_ctc_weight", 0.3), |
| | | ctc=kwargs.get("decoding_ctc_weight", 0.3), |
| | | lm=kwargs.get("lm_weight", 0.0), |
| | | ngram=kwargs.get("ngram_weight", 0.0), |
| | | length_bonus=kwargs.get("penalty", 0.0), |
| | | ) |
| | | beam_search = BeamSearch( |
| | | beam_size=kwargs.get("beam_size", 20), |
| | | weights=weights, |
| | | scorers=scorers, |
| | | sos=self.sos, |
| | | eos=self.eos, |
| | | vocab_size=len(token_list), |
| | | token_list=token_list, |
| | | pre_beam_score_key=None if self.ctc_weight == 1.0 else "full", |
| | | ) |
| | | |
| | | self.beam_search = beam_search |
| | | |
| | | def inference(self, |
| | | data_in, |
| | | data_lengths=None, |
| | | key: list=None, |
| | | tokenizer=None, |
| | | frontend=None, |
| | | **kwargs, |
| | | ): |
| | | |
| | | if kwargs.get("batch_size", 1) > 1: |
| | | raise NotImplementedError("batch decoding is not implemented") |
| | | |
| | | # init beamsearch |
| | | if self.beam_search is None: |
| | | logging.info("enable beam_search") |
| | | self.init_beam_search(**kwargs) |
| | | self.nbest = kwargs.get("nbest", 1) |
| | | |
| | | meta_data = {} |
| | | if isinstance(data_in, torch.Tensor) and kwargs.get("data_type", "sound") == "fbank": # fbank |
| | | speech, speech_lengths = data_in, data_lengths |
| | | if len(speech.shape) < 3: |
| | | speech = speech[None, :, :] |
| | | if speech_lengths is None: |
| | | speech_lengths = speech.shape[1] |
| | | else: |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000), |
| | | data_type=kwargs.get("data_type", "sound"), |
| | | tokenizer=tokenizer) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | audio_sample_list = sample_list[0] |
| | | if len(sample_list) > 1:
| | | ocr_sample_list = sample_list[1] |
| | | else: |
| | | ocr_sample_list = [[294, 0]] |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), |
| | | frontend=frontend) |
| | | time3 = time.perf_counter() |
| | | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" |
| | | frame_shift = 10 |
| | | meta_data["batch_data_time"] = speech_lengths.sum().item() * frame_shift / 1000 |
| | | |
| | | speech = speech.to(device=kwargs["device"]) |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | | # Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | if isinstance(encoder_out, tuple): |
| | | encoder_out = encoder_out[0] |
| | | |
| | | ocr_list_new = [[x + 1 if x != 0 else x for x in sublist] for sublist in ocr_sample_list] |
| | | ocr = torch.tensor(ocr_list_new).to(device=kwargs["device"]) |
| | | ocr_lengths = ocr.new_full([1], dtype=torch.long, fill_value=ocr.size(1)).to(device=kwargs["device"]) |
| | | ocr, ocr_lens, _ = self.text_encoder(ocr, ocr_lengths) |
| | | fusion_out, _, _, _ = self.fusion_encoder(encoder_out, None, ocr, None)
| | | encoder_out = encoder_out + fusion_out |
| | | # c. Pass the encoder output to the beam search
| | | nbest_hyps = self.beam_search( |
| | | x=encoder_out[0], maxlenratio=kwargs.get("maxlenratio", 0.0), minlenratio=kwargs.get("minlenratio", 0.0) |
| | | ) |
| | | |
| | | nbest_hyps = nbest_hyps[: self.nbest] |
| | | |
| | | results = [] |
| | | b, n, d = encoder_out.size() |
| | | for i in range(b): |
| | | |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx + 1}best_recog"] |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | | token_int = hyp.yseq[1:last_pos] |
| | | else: |
| | | token_int = hyp.yseq[1:last_pos].tolist() |
| | | |
| | | # remove blank symbol id, which is assumed to be 0 |
| | | token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int)) |
| | | |
| | | # Change integer-ids to tokens |
| | | token = tokenizer.ids2tokens(token_int) |
| | | text = tokenizer.tokens2text(token) |
| | | |
| | | text_postprocessed, _ = postprocess_utils.sentence_postprocess(token) |
| | | result_i = {"key": key[i], "token": token, "text": text_postprocessed} |
| | | results.append(result_i) |
| | | |
| | | if ibest_writer is not None: |
| | | ibest_writer["token"][key[i]] = " ".join(token) |
| | | ibest_writer["text"][key[i]] = text_postprocessed |
| | | |
| | | return results, meta_data |
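| | | 
| | | # Hedged usage sketch (single-utterance decoding; only the argument names
| | | # come from the signature above, concrete values are hypothetical):
| | | #
| | | #   model.eval()
| | | #   results, meta = model.inference(
| | | #       "asr_example_zh.wav",         # audio path, URL, or tensor
| | | #       key=["utt1"],
| | | #       tokenizer=tokenizer,
| | | #       frontend=frontend,
| | | #       device="cpu",
| | | #       beam_size=10,
| | | #       decoding_ctc_weight=0.3,
| | | #   )
| | | #   print(results[0]["text"])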
| | | |
| | |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | | else: |
| | |
| | | hotword_pad = kwargs.get("hotword_pad") |
| | | hotword_lengths = kwargs.get("hotword_lengths") |
| | | dha_pad = kwargs.get("dha_pad") |
| | | |
| | | |
| | | batch_size = speech.shape[0] |
| | | # for data-parallel |
| | | text = text[:, : text_lengths.max()] |
| | |
| | | nfilter=50, |
| | | seaco_weight=1.0): |
| | | # decoder forward |
| | | |
| | | decoder_out, decoder_hidden, _ = self.decoder(encoder_out, encoder_out_lens, sematic_embeds, ys_pad_lens, return_hidden=True, return_both=True) |
| | | |
| | | decoder_pred = torch.log_softmax(decoder_out, dim=-1) |
| | | if hw_list is not None: |
| | | hw_lengths = [len(i) for i in hw_list] |
| | | hw_list_ = [torch.Tensor(i).long() for i in hw_list] |
| | | hw_list_pad = pad_list(hw_list_, 0).to(encoder_out.device) |
| | | selected = self._hotword_representation(hw_list_pad, torch.Tensor(hw_lengths).int().to(encoder_out.device)) |
| | | |
| | | contextual_info = selected.squeeze(0).repeat(encoder_out.shape[0], 1, 1).to(encoder_out.device) |
| | | num_hot_word = contextual_info.shape[1] |
| | | _contextual_length = torch.Tensor([num_hot_word]).int().repeat(encoder_out.shape[0]).to(encoder_out.device) |
| | | |
| | | |
| | | # ASF Core |
| | | if 0 < nfilter < num_hot_word:
| | | hotword_scores = self.seaco_decoder.forward_asf6(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens) |
| | |
| | | cif_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, sematic_embeds, ys_pad_lens) |
| | | dec_attended, _ = self.seaco_decoder(contextual_info, _contextual_length, decoder_hidden, ys_pad_lens) |
| | | merged = self._merge(cif_attended, dec_attended) |
| | | |
| | | |
| | | dha_output = self.hotword_output_layer(merged) # remove the last token in loss calculation |
| | | dha_pred = torch.log_softmax(dha_output, dim=-1) |
| | | def _merge_res(dec_output, dha_output): |
| | |
| | | # logits = dec_output * dha_mask + dha_output[:,:,:-1] * (1-dha_mask) |
| | | logits = dec_output * dha_mask + dha_output[:, :, :] * (1 - dha_mask)
| | | return logits |
| | | |
| | | merged_pred = _merge_res(decoder_pred, dha_pred) |
| | | return merged_pred |
| | | else: |
| | | return decoder_pred |
| | |
| | | logging.info("enable beam_search") |
| | | self.init_beam_search(**kwargs) |
| | | self.nbest = kwargs.get("nbest", 1) |
| | | |
| | | meta_data = {} |
| | | |
| | | # extract fbank feats |
| | |
| | | if isinstance(encoder_out, tuple): |
| | | encoder_out = encoder_out[0] |
| | | |
| | | |
| | | # predictor |
| | | predictor_outs = self.calc_predictor(encoder_out, encoder_out_lens) |
| | | pre_acoustic_embeds, pre_token_length, _, _ = predictor_outs[0], predictor_outs[1], \ |
| | |
| | | if torch.max(pre_token_length) < 1: |
| | | return [] |
| | | |
| | | |
| | | decoder_out = self._seaco_decode_with_ASF(encoder_out, encoder_out_lens, |
| | | pre_acoustic_embeds, |
| | | pre_token_length, |
| | | hw_list=self.hotword_list) |
| | | |
| | | # decoder_out, _ = decoder_outs[0], decoder_outs[1] |
| | | _, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens, |
| | | pre_token_length) |
| | | |
| | | results = [] |
| | | b, n, d = decoder_out.size() |
| | | for i in range(b): |
| | |
| | | import torch |
| | | import torch.nn |
| | | import torch.optim |
| | | |
| | | |
| | | def filter_state_dict( |
| | | dst_state: Dict[str, Union[float, torch.Tensor]], |
| | |
| | | dst_state = obj.state_dict() |
| | | |
| | | print(f"ckpt: {path}") |
| | | |
| | | if oss_bucket is None: |
| | | src_state = torch.load(path, map_location=map_location) |
| | | else: |
| | |
| | | from funasr.download.file import download_from_url |
| | | except ImportError:
| | | print("urllib is not installed; if you want to infer from a URL, please install it first.")
| | | |
| | | |
| | | |
| | | def load_audio_text_image_video(data_or_path_or_list, fs: int = 16000, audio_fs: int = 16000, data_type="sound", tokenizer=None, **kwargs): |
| | | if isinstance(data_or_path_or_list, (list, tuple)): |
| | | if data_type is not None and isinstance(data_type, (list, tuple)): |
| | | |
| | | data_types = [data_type] * len(data_or_path_or_list) |
| | | data_or_path_or_list_ret = [[] for _ in data_type]
| | | for i, (data_type_i, data_or_path_or_list_i) in enumerate(zip(data_types, data_or_path_or_list)): |
| | | |
| | | for j, (data_type_j, data_or_path_or_list_j) in enumerate(zip(data_type_i, data_or_path_or_list_i)): |
| | | |
| | | data_or_path_or_list_j = load_audio_text_image_video(data_or_path_or_list_j, fs=fs, audio_fs=audio_fs, data_type=data_type_j, tokenizer=tokenizer, **kwargs) |
| | | data_or_path_or_list_ret[j].append(data_or_path_or_list_j) |
| | | |
| | | return data_or_path_or_list_ret |
| | | else: |
| | | return [load_audio_text_image_video(audio, fs=fs, audio_fs=audio_fs, data_type=data_type, **kwargs) for audio in data_or_path_or_list] |
| | | |
| | | if isinstance(data_or_path_or_list, str) and data_or_path_or_list.startswith('http'): # download url to local file |
| | | data_or_path_or_list = download_from_url(data_or_path_or_list) |
| | | |
| | | |
| | | if isinstance(data_or_path_or_list, str) and os.path.exists(data_or_path_or_list): # local file |
| | | if data_type is None or data_type == "sound": |
| | | data_or_path_or_list, audio_fs = torchaudio.load(data_or_path_or_list) |
| | |
| | | data_or_path_or_list = tokenizer.encode(data_or_path_or_list) |
| | | elif isinstance(data_or_path_or_list, np.ndarray): # audio sample point |
| | | data_or_path_or_list = torch.from_numpy(data_or_path_or_list).squeeze() # [n_samples,] |
| | | elif isinstance(data_or_path_or_list, str) and data_type == "kaldi_ark": |
| | | data_mat = kaldiio.load_mat(data_or_path_or_list) |
| | | if isinstance(data_mat, tuple): |
| | | audio_fs, mat = data_mat |
| | | else: |
| | | mat = data_mat |
| | | if mat.dtype == 'int16' or mat.dtype == 'int32': |
| | | mat = mat.astype(np.float64) |
| | | mat = mat / 32768 |
| | | if mat.ndim == 2:
| | | mat = mat[:, 0]
| | | data_or_path_or_list = mat |
| | | else: |
| | | pass |
| | | # print(f"unsupport data type: {data_or_path_or_list}, return raw data") |
| | | |
| | | |
| | | if audio_fs != fs and data_type != "text": |
| | | resampler = torchaudio.transforms.Resample(audio_fs, fs) |
| | | data_or_path_or_list = resampler(data_or_path_or_list[None, :])[0, :] |
| | |
| | | return array |
| | | |
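| | | # Hedged sketch (names from the signatures in this file, values assumed):
| | | # load a wav resampled to 16 kHz, then compute fbank features via a frontend.
| | | #
| | | #   wav = load_audio_text_image_video("example.wav", fs=16000)
| | | #   feats, feat_lens = extract_fbank(wav, data_type="sound", frontend=frontend)
| | | 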
| | | def extract_fbank(data, data_len = None, data_type: str="sound", frontend=None, **kwargs): |
| | | if isinstance(data, np.ndarray): |
| | | data = torch.from_numpy(data) |
| | | if len(data.shape) < 2: |
| | |
| | | data_list.append(data_i) |
| | | data_len.append(data_i.shape[0]) |
| | | data = pad_sequence(data_list, batch_first=True) # data: [batch, N] |
| | | # if data_type == "sound": |
| | | |
| | | data, data_len = frontend(data, data_len, **kwargs) |
| | | |
| | | if isinstance(data_len, (list, tuple)): |