From 87d5f69b819df11969263cf99f7cc2f80bea30da Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期二, 07 五月 2024 13:43:53 +0800
Subject: [PATCH] decoding key
---
funasr/datasets/sense_voice_datasets/datasets.py | 2
funasr/models/sense_voice/model.py | 11 ++++-
funasr/utils/misc.py | 14 +++++++
examples/industrial_data_pretraining/paraformer/demo.py | 10 +----
funasr/auto/auto_model.py | 25 ++++++------
examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py | 4 -
6 files changed, 38 insertions(+), 28 deletions(-)
diff --git a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
index 0f30a37..21ce0cb 100644
--- a/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
+++ b/examples/industrial_data_pretraining/fsmn_vad_streaming/demo.py
@@ -9,11 +9,9 @@
model = AutoModel(model="iic/speech_fsmn_vad_zh-cn-16k-common-pytorch")
-mm = model.model
-for p in mm.parameters():
- print(f"{p.numel()}")
res = model.generate(input=wav_file)
print(res)
+
# [[beg1, end1], [beg2, end2], .., [begN, endN]]
# beg/end: ms
diff --git a/examples/industrial_data_pretraining/paraformer/demo.py b/examples/industrial_data_pretraining/paraformer/demo.py
index f6f4c75..eb7e72f 100644
--- a/examples/industrial_data_pretraining/paraformer/demo.py
+++ b/examples/industrial_data_pretraining/paraformer/demo.py
@@ -14,14 +14,8 @@
)
res = model.generate(
- input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
-)
-res = model.generate(
- input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
-)
-
-res = model.generate(
- input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav"
+ input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_zh.wav",
+ cache={},
)
print(res)
diff --git a/funasr/auto/auto_model.py b/funasr/auto/auto_model.py
index 32fd560..e77f04f 100644
--- a/funasr/auto/auto_model.py
+++ b/funasr/auto/auto_model.py
@@ -26,6 +26,7 @@
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils import export_utils
+from funasr.utils import misc
try:
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
@@ -35,14 +36,7 @@
def prepare_data_iterator(data_in, input_len=None, data_type=None, key=None):
- """
-
- :param input:
- :param input_len:
- :param data_type:
- :param frontend:
- :return:
- """
+ """Build (key_list, data_list) pairs for inference from the given input."""
data_list = []
key_list = []
filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
@@ -73,7 +67,8 @@
key_list.append(key)
else:
if key is None:
- key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ # key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ key = misc.extract_filename_without_extension(data_in)
data_list = [data_in]
key_list = [key]
elif isinstance(data_in, (list, tuple)):
@@ -90,10 +85,14 @@
else:
# [audio sample point, fbank, text]
data_list = data_in
- key_list = [
- "rand_key_" + "".join(random.choice(chars) for _ in range(13))
- for _ in range(len(data_in))
- ]
+ key_list = []
+ for data_i in data_in:
+ if isinstance(data_i, str) and os.path.exists(data_i):
+ key = misc.extract_filename_without_extension(data_i)
+ else:
+ key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
+ key_list.append(key)
+
else: # raw text; audio sample point, fbank; bytes
if isinstance(data_in, bytes): # audio bytes
data_in = load_bytes(data_in)
diff --git a/funasr/datasets/sense_voice_datasets/datasets.py b/funasr/datasets/sense_voice_datasets/datasets.py
index 5d80956..ee2f13d 100644
--- a/funasr/datasets/sense_voice_datasets/datasets.py
+++ b/funasr/datasets/sense_voice_datasets/datasets.py
@@ -112,7 +112,7 @@
eos = self.tokenizer.encode(self.eos, allowed_special="all") # [eos]
- ids = prompt_ids + target_ids + eos
+ ids = prompt_ids + target_ids + eos # [sos, task, lid, text, eos]
ids_lengths = len(ids)
text = torch.tensor(ids, dtype=torch.int64)
diff --git a/funasr/models/sense_voice/model.py b/funasr/models/sense_voice/model.py
index d5e4130..bcaaca3 100644
--- a/funasr/models/sense_voice/model.py
+++ b/funasr/models/sense_voice/model.py
@@ -378,14 +378,19 @@
stats = {}
# 1. Forward decoder
+ # ys_pad: [sos, task, lid, text, eos]
decoder_out = self.model.decoder(
x=ys_pad, xa=encoder_out, hlens=encoder_out_lens, ys_in_lens=ys_pad_lens
)
# 2. Compute attention loss
- mask = torch.ones_like(ys_pad) * (-1)
- ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(torch.int64)
- ys_pad_mask[ys_pad_mask == 0] = -1
+ mask = torch.ones_like(ys_pad) * (-1) # [sos, task, lid, text, eos]: [-1, -1, -1, -1, -1]
+ ys_pad_mask = (ys_pad * target_mask + mask * (1 - target_mask)).to(
+ torch.int64
+ ) # [sos, task, lid, text, eos]: [0, 0, 1, 1, 1] + [-1, -1, 0, 0, 0]
+ ys_pad_mask[ys_pad_mask == 0] = -1 # [-1, -1, lid, text, eos]
+ # decoder_out: [sos, task, lid, text]
+ # ys_pad_mask: [-1, lid, text, eos]
loss_att = self.criterion_att(decoder_out[:, :-1, :], ys_pad_mask[:, 1:])
with torch.no_grad():
diff --git a/funasr/utils/misc.py b/funasr/utils/misc.py
index 5eaa4f8..9f01955 100644
--- a/funasr/utils/misc.py
+++ b/funasr/utils/misc.py
@@ -78,3 +78,17 @@
# config_json = os.path.join(model_path, "configuration.json")
# if os.path.exists(config_json):
# shutil.copy(config_json, os.path.join(kwargs.get("output_dir", "./"), "configuration.json"))
+
+
+def extract_filename_without_extension(file_path):
+    """
+    Extract the file name (without directory path and extension) from the given file path.
+    :param file_path: full path to the file
+    :return: file name without directory path and extension
+    """
+    # First, use os.path.basename to get the file-name part of the path (with extension)
+    filename_with_extension = os.path.basename(file_path)
+    # Then, use os.path.splitext to split the base name from its extension
+    filename, extension = os.path.splitext(filename_with_extension)
+    # Return the file name without its extension
+    return filename
--
Gitblit v1.9.1