From 5a8f37908469d9550f905ba0876c7c4e6f9b8026 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期四, 21 十二月 2023 21:08:46 +0800
Subject: [PATCH] vad + asr
---
funasr/bin/inference.py | 198 ++++++++++++++++++++++--
funasr/models/bici_paraformer/model.py | 88 ++++++----
funasr/utils/vad_utils.py | 15 +
examples/industrial_data_pretraining/paraformer-large-long/infer.sh | 31 +++
funasr/models/paraformer_streaming/__init__.py | 0
funasr/bin/train.py | 4
funasr/download/download_from_hub.py | 3
funasr/models/bici_paraformer/template.yaml | 134 ++++++++++++++++
funasr/models/paraformer_streaming/model.py | 2
funasr/models/paraformer_streaming/sanm_decoder.py | 0
10 files changed, 411 insertions(+), 64 deletions(-)
diff --git a/examples/industrial_data_pretraining/paraformer-large-long/infer.sh b/examples/industrial_data_pretraining/paraformer-large-long/infer.sh
new file mode 100644
index 0000000..d77329e
--- /dev/null
+++ b/examples/industrial_data_pretraining/paraformer-large-long/infer.sh
@@ -0,0 +1,31 @@
+
+cmd="funasr/bin/inference.py"
+
+python $cmd \
++model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
++vad_model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch" \
++input="/Users/zhifu/funasr_github/test_local/vad_example.wav" \
++output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
++device="cpu" \
++batch_size_s=300 \
++batch_size_threshold_s=60 \
++debug="true"
+
+#python $cmd \
+#+model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
+#+input="/Users/zhifu/Downloads/asr_example.wav" \
+#+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
+#+device="cpu" \
+#+"hotword='杈鹃瓟闄� 榄旀惌'"
+
+#+input="/Users/zhifu/funasr_github/test_local/wav.scp"
+#+input="/Users/zhifu/funasr_github/test_local/asr_example.wav" \
+#+input="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len.jsonl" \
+#+input="/Users/zhifu/funasr_github/test_local/aishell2_dev_ios/asr_task_debug_len_10.jsonl" \
+#+model="/Users/zhifu/modelscope_models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+
+#+model="/Users/zhifu/modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+#+model="/Users/zhifu/modelscope_models/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404" \
+#+"hotword='杈鹃瓟闄� 榄旀惌'"
+
+#+vad_model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch" \
diff --git a/funasr/bin/inference.py b/funasr/bin/inference.py
index d7b33e3..fda7abe 100644
--- a/funasr/bin/inference.py
+++ b/funasr/bin/inference.py
@@ -16,7 +16,8 @@
import random
import string
from funasr.register import tables
-
+from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio
+from funasr.utils.vad_utils import slice_padding_audio_samples
def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
"""
@@ -73,15 +74,44 @@
logging.basicConfig(level=log_level)
- import pdb;
- pdb.set_trace()
+ if kwargs.get("debug", False):
+ import pdb; pdb.set_trace()
model = AutoModel(**kwargs)
- res = model.generate(input=kwargs["input"])
+ res = model(input=kwargs["input"])
print(res)
class AutoModel:
+
def __init__(self, **kwargs):
tables.print()
+
+ model, kwargs = self.build_model(**kwargs)
+
+ # if vad_model is not None, build vad model else None
+ vad_model = kwargs.get("vad_model", None)
+ vad_kwargs = kwargs.get("vad_model_revision", None)
+ if vad_model is not None:
+ print("build vad model")
+ vad_kwargs = {"model": vad_model, "model_revision": vad_kwargs}
+ vad_model, vad_kwargs = self.build_model(**vad_kwargs)
+
+ # if punc_model is not None, build punc model else None
+ punc_model = kwargs.get("punc_model", None)
+ punc_kwargs = kwargs.get("punc_model_revision", None)
+ if punc_model is not None:
+ punc_kwargs = {"model": punc_model, "model_revision": punc_kwargs}
+ punc_model, punc_kwargs = self.build_model(**punc_kwargs)
+
+ self.kwargs = kwargs
+ self.model = model
+ self.vad_model = vad_model
+ self.vad_kwargs = vad_kwargs
+ self.punc_model = punc_model
+ self.punc_kwargs = punc_kwargs
+
+
+
+ def build_model(self, **kwargs):
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
@@ -94,7 +124,7 @@
device = "cpu"
kwargs["batch_size"] = 1
kwargs["device"] = device
-
+
# build tokenizer
tokenizer = kwargs.get("tokenizer", None)
if tokenizer is not None:
@@ -113,7 +143,8 @@
# build model
model_class = tables.model_classes.get(kwargs["model"].lower())
- model = model_class(**kwargs, **kwargs["model_conf"], vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
+ model = model_class(**kwargs, **kwargs["model_conf"],
+ vocab_size=len(tokenizer.token_list) if tokenizer is not None else -1)
model.eval()
model.to(device)
@@ -127,23 +158,34 @@
ignore_init_mismatch=kwargs.get("ignore_init_mismatch", False),
oss_bucket=kwargs.get("oss_bucket", None),
)
- self.kwargs = kwargs
- self.model = model
- self.tokenizer = tokenizer
+
+ return model, kwargs
- def generate(self, input, input_len=None, **cfg):
- self.kwargs.update(cfg)
- data_type = self.kwargs.get("data_type", "sound")
- batch_size = self.kwargs.get("batch_size", 1)
- if self.kwargs.get("device", "cpu") == "cpu":
- batch_size = 1
+ def __call__(self, input, input_len=None, **cfg):
+ if self.vad_model is None:
+ return self.generate(input, input_len=input_len, **cfg)
+
+ else:
+ return self.generate_with_vad(input, input_len=input_len, **cfg)
+
+ def generate(self, input, input_len=None, model=None, kwargs=None, **cfg):
+ kwargs = self.kwargs if kwargs is None else kwargs
+ kwargs.update(cfg)
+ model = self.model if model is None else model
+
+ data_type = kwargs.get("data_type", "sound")
+ batch_size = kwargs.get("batch_size", 1)
+ # if kwargs.get("device", "cpu") == "cpu":
+ # batch_size = 1
key_list, data_list = build_iter_for_infer(input, input_len=input_len, data_type=data_type)
speed_stats = {}
asr_result_list = []
num_samples = len(data_list)
- pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True)
+ pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True)
+ time_speech_total = 0.0
+ time_escape_total = 0.0
for beg_idx in range(0, num_samples, batch_size):
end_idx = min(num_samples, beg_idx + batch_size)
data_batch = data_list[beg_idx:end_idx]
@@ -154,25 +196,139 @@
batch["data_lengths"] = input_len
time1 = time.perf_counter()
- results, meta_data = self.model.generate(**batch, **self.kwargs)
+ results, meta_data = model.generate(**batch, **kwargs)
time2 = time.perf_counter()
- asr_result_list.append(results)
+ asr_result_list.extend(results)
pbar.update(1)
# batch_data_time = time_per_frame_s * data_batch_i["speech_lengths"].sum().item()
batch_data_time = meta_data.get("batch_data_time", -1)
+ time_escape = time2 - time1
speed_stats["load_data"] = meta_data.get("load_data", 0.0)
speed_stats["extract_feat"] = meta_data.get("extract_feat", 0.0)
- speed_stats["forward"] = f"{time2 - time1:0.3f}"
- speed_stats["rtf"] = f"{(time2 - time1) / batch_data_time:0.3f}"
+ speed_stats["forward"] = f"{time_escape:0.3f}"
+ speed_stats["batch_size"] = f"{len(results)}"
+ speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}"
description = (
f"{speed_stats}, "
)
pbar.set_description(description)
-
+ time_speech_total += batch_data_time
+ time_escape_total += time_escape
+
+ pbar.update(1)
+ pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
torch.cuda.empty_cache()
return asr_result_list
+
+ def generate_with_vad(self, input, input_len=None, **cfg):
+
+ # step.1: compute the vad model
+ model = self.vad_model
+ kwargs = self.vad_kwargs
+ beg_vad = time.time()
+ res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
+ end_vad = time.time()
+ print(f"time cost vad: {end_vad - beg_vad:0.3f}")
+
+ # step.2 compute asr model
+ model = self.model
+ kwargs = self.kwargs
+ kwargs.update(cfg)
+ batch_size = int(kwargs.get("batch_size_s", 300))*1000
+ batch_size_threshold_ms = int(kwargs.get("batch_size_threshold_s", 60))*1000
+ kwargs["batch_size"] = batch_size
+ data_type = kwargs.get("data_type", "sound")
+ key_list, data_list = build_iter_for_infer(input, input_len=input_len, data_type=data_type)
+ results_ret_list = []
+ time_speech_total_all_samples = 0.0
+
+ beg_total = time.time()
+ pbar_total = tqdm(colour="red", total=len(res) + 1, dynamic_ncols=True)
+ for i in range(len(res)):
+ key = res[i]["key"]
+ vadsegments = res[i]["value"]
+ input_i = data_list[i]
+ speech = load_audio(input_i, fs=kwargs["frontend"].fs, audio_fs=kwargs.get("fs", 16000))
+ speech_lengths = len(speech)
+ n = len(vadsegments)
+ data_with_index = [(vadsegments[i], i) for i in range(n)]
+ sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0])
+ results_sorted = []
+
+ if not len(sorted_data):
+ logging.info("decoding, utt: {}, empty speech".format(key))
+ continue
+
+
+ # if kwargs["device"] == "cpu":
+ # batch_size = 0
+ if len(sorted_data) > 0 and len(sorted_data[0]) > 0:
+ batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0])
+
+ batch_size_ms_cum = 0
+ beg_idx = 0
+ beg_asr_total = time.time()
+ time_speech_total_per_sample = speech_lengths/16000
+ time_speech_total_all_samples += time_speech_total_per_sample
+
+ for j, _ in enumerate(range(0, n)):
+ batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
+ if j < n - 1 and (
+ batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
+ sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size_threshold_ms:
+ continue
+ batch_size_ms_cum = 0
+ end_idx = j + 1
+ speech_j, speech_lengths_j = slice_padding_audio_samples(speech, speech_lengths, sorted_data[beg_idx:end_idx])
+ beg_idx = end_idx
+
+ results = self.generate(speech_j, input_len=None, model=model, kwargs=kwargs, **cfg)
+
+ if len(results) < 1:
+ continue
+ results_sorted.extend(results)
+
+
+ pbar_total.update(1)
+ end_asr_total = time.time()
+ time_escape_total_per_sample = end_asr_total - beg_asr_total
+ pbar_total.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
+ f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
+ f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
+
+ restored_data = [0] * n
+ for j in range(n):
+ index = sorted_data[j][1]
+ restored_data[index] = results_sorted[j]
+ result = {}
+
+ for j in range(n):
+ for k, v in restored_data[j].items():
+ if not k.startswith("timestamp"):
+ if k not in result:
+ result[k] = restored_data[j][k]
+ else:
+ result[k] += restored_data[j][k]
+ else:
+ result[k] = []
+ for t in restored_data[j][k]:
+ t[0] += vadsegments[j][0]
+ t[1] += vadsegments[j][0]
+ result[k] += restored_data[j][k]
+
+ result["key"] = key
+ results_ret_list.append(result)
+ pbar_total.update(1)
+ pbar_total.update(1)
+ end_total = time.time()
+ time_escape_total_all_samples = end_total - beg_total
+ pbar_total.set_description(f"rtf_avg_all_samples: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
+ f"time_speech_total_all_samples: {time_speech_total_all_samples: 0.3f}, "
+ f"time_escape_total_all_samples: {time_escape_total_all_samples:0.3f}")
+ return results_ret_list
+
if __name__ == '__main__':
main_hydra()
\ No newline at end of file
diff --git a/funasr/bin/train.py b/funasr/bin/train.py
index b1f0d06..af3e8af 100644
--- a/funasr/bin/train.py
+++ b/funasr/bin/train.py
@@ -25,7 +25,9 @@
@hydra.main(config_name=None, version_base=None)
def main_hydra(kwargs: DictConfig):
- import pdb; pdb.set_trace()
+ if kwargs.get("debug", False):
+ import pdb; pdb.set_trace()
+
assert "model" in kwargs
if "model_conf" not in kwargs:
logging.info("download models from model hub: {}".format(kwargs.get("model_hub", "ms")))
diff --git a/funasr/download/download_from_hub.py b/funasr/download/download_from_hub.py
index eeb5d0c..47eda9e 100644
--- a/funasr/download/download_from_hub.py
+++ b/funasr/download/download_from_hub.py
@@ -24,11 +24,10 @@
kwargs["init_param"] = init_param
if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
kwargs["tokenizer_conf"]["token_list"] = os.path.join(model_or_path, "tokens.txt")
- if os.path.exists(os.path.join(model_or_path, "tokens.txt")):
+ if os.path.exists(os.path.join(model_or_path, "seg_dict")):
kwargs["tokenizer_conf"]["seg_dict"] = os.path.join(model_or_path, "seg_dict")
if os.path.exists(os.path.join(model_or_path, "bpe.model")):
kwargs["tokenizer_conf"]["bpemodel"] = os.path.join(model_or_path, "bpe.model")
-
kwargs["model"] = cfg["model"]
kwargs["frontend_conf"]["cmvn_file"] = os.path.join(model_or_path, "am.mvn")
diff --git a/funasr/models/bici_paraformer/model.py b/funasr/models/bici_paraformer/model.py
index 52eac87..03c8896 100644
--- a/funasr/models/bici_paraformer/model.py
+++ b/funasr/models/bici_paraformer/model.py
@@ -29,6 +29,7 @@
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.register import tables
from funasr.models.ctc.ctc import CTC
+from funasr.utils.timestamp_tools import time_stamp_sentence
from funasr.models.paraformer.model import Paraformer
@@ -211,10 +212,11 @@
loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
return loss, stats, weight
-
+
+
def generate(self,
- data_in: list,
- data_lengths: list = None,
+ data_in,
+ data_lengths=None,
key: list = None,
tokenizer=None,
frontend=None,
@@ -230,17 +232,23 @@
self.nbest = kwargs.get("nbest", 1)
meta_data = {}
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
- frontend=self.frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data[
- "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
+ if isinstance(data_in, torch.Tensor): # fbank
+ speech, speech_lengths = data_in, data_lengths
+ if len(speech.shape) < 3:
+ speech = speech[None, :, :]
+ if speech_lengths is None:
+ speech_lengths = speech.shape[1]
+ else:
+ # extract fbank feats
+ time1 = time.perf_counter()
+ audio_sample_list = load_audio(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
+ time2 = time.perf_counter()
+ meta_data["load_data"] = f"{time2 - time1:0.3f}"
+ speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
+ frontend=frontend)
+ time3 = time.perf_counter()
+ meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
+ meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
@@ -261,9 +269,8 @@
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
# BiCifParaformer, test no bias cif2
-
_, _, us_alphas, us_peaks = self.calc_predictor_timestamp(encoder_out, encoder_out_lens,
- pre_token_length)
+ pre_token_length)
results = []
b, n, d = decoder_out.size()
@@ -302,27 +309,32 @@
# remove blank symbol id, which is assumed to be 0
token_int = list(filter(lambda x: x != self.eos and x != self.sos and x != self.blank_id, token_int))
- # Change integer-ids to tokens
- token = tokenizer.ids2tokens(token_int)
- text = tokenizer.tokens2text(token)
-
- _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
- us_peaks[i][:encoder_out_lens[i] * 3],
- copy.copy(token),
- vad_offset=kwargs.get("begin_time", 0))
-
- text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(token, timestamp)
-
- result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
- "time_stamp_postprocessed": time_stamp_postprocessed,
- "word_lists": word_lists
- }
- results.append(result_i)
-
- if ibest_writer is not None:
- ibest_writer["token"][key[i]] = " ".join(token)
- ibest_writer["text"][key[i]] = text
- ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+ if tokenizer is not None:
+ # Change integer-ids to tokens
+ token = tokenizer.ids2tokens(token_int)
+ text = tokenizer.tokens2text(token)
+ _, timestamp = ts_prediction_lfr6_standard(us_alphas[i][:encoder_out_lens[i] * 3],
+ us_peaks[i][:encoder_out_lens[i] * 3],
+ copy.copy(token),
+ vad_offset=kwargs.get("begin_time", 0))
+
+ text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
+ token, timestamp)
+ sentences = time_stamp_sentence(None, time_stamp_postprocessed, text_postprocessed)
+ result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
+ "timestamp": time_stamp_postprocessed,
+ "word_lists": word_lists,
+ "sentences": sentences
+ }
+
+ if ibest_writer is not None:
+ ibest_writer["token"][key[i]] = " ".join(token)
+ ibest_writer["text"][key[i]] = text
+ ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
+ ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+ else:
+ result_i = {"key": key[i], "token_int": token_int}
+ results.append(result_i)
- return results, meta_data
+ return results, meta_data
\ No newline at end of file
diff --git a/funasr/models/bici_paraformer/template.yaml b/funasr/models/bici_paraformer/template.yaml
new file mode 100644
index 0000000..d2b0e0a
--- /dev/null
+++ b/funasr/models/bici_paraformer/template.yaml
@@ -0,0 +1,134 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+# network architecture
+#model: funasr.models.paraformer.model:Paraformer
+model: BiCifParaformer
+model_conf:
+ ctc_weight: 0.0
+ lsm_weight: 0.1
+ length_normalized_loss: true
+ predictor_weight: 1.0
+ predictor_bias: 1
+ sampling_ratio: 0.75
+
+# encoder
+encoder: SANMEncoder
+encoder_conf:
+ output_size: 512
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 50
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ attention_dropout_rate: 0.1
+ input_layer: pe
+ pos_enc_class: SinusoidalPositionEncoder
+ normalize_before: true
+ kernel_size: 11
+ sanm_shfit: 0
+ selfattention_layer_type: sanm
+
+# decoder
+decoder: ParaformerSANMDecoder
+decoder_conf:
+ attention_heads: 4
+ linear_units: 2048
+ num_blocks: 16
+ dropout_rate: 0.1
+ positional_dropout_rate: 0.1
+ self_attention_dropout_rate: 0.1
+ src_attention_dropout_rate: 0.1
+ att_layer_num: 16
+ kernel_size: 11
+ sanm_shfit: 0
+
+predictor: CifPredictorV3
+predictor_conf:
+ idim: 512
+ threshold: 1.0
+ l_order: 1
+ r_order: 1
+ tail_threshold: 0.45
+ smooth_factor2: 0.25
+ noise_threshold2: 0.01
+ upsample_times: 3
+ use_cif1_cnn: false
+ upsample_type: cnn_blstm
+
+# frontend related
+frontend: WavFrontend
+frontend_conf:
+ fs: 16000
+ window: hamming
+ n_mels: 80
+ frame_length: 25
+ frame_shift: 10
+ lfr_m: 7
+ lfr_n: 6
+
+specaug: SpecAugLFR
+specaug_conf:
+ apply_time_warp: false
+ time_warp_window: 5
+ time_warp_mode: bicubic
+ apply_freq_mask: true
+ freq_mask_width_range:
+ - 0
+ - 30
+ lfr_rate: 6
+ num_freq_mask: 1
+ apply_time_mask: true
+ time_mask_width_range:
+ - 0
+ - 12
+ num_time_mask: 1
+
+train_conf:
+ accum_grad: 1
+ grad_clip: 5
+ max_epoch: 150
+ val_scheduler_criterion:
+ - valid
+ - acc
+ best_model_criterion:
+ - - valid
+ - acc
+ - max
+ keep_nbest_models: 10
+ log_interval: 50
+
+optim: adam
+optim_conf:
+ lr: 0.0005
+scheduler: warmuplr
+scheduler_conf:
+ warmup_steps: 30000
+
+dataset: AudioDataset
+dataset_conf:
+ index_ds: IndexDSJsonl
+ batch_sampler: DynamicBatchLocalShuffleSampler
+ batch_type: example # example or length
+ batch_size: 1 # if batch_type is example, batch_size is the numbers of samples; if length, batch_size is source_token_len+target_token_len;
+ max_token_length: 2048 # filter samples if source_token_len+target_token_len > max_token_length,
+ buffer_size: 500
+ shuffle: True
+ num_workers: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+ unk_symbol: <unk>
+ split_with_space: true
+
+
+ctc_conf:
+ dropout_rate: 0.0
+ ctc_type: builtin
+ reduce: true
+ ignore_nan_grad: true
+normalize: null
diff --git a/funasr/models/paraformer_online/__init__.py b/funasr/models/paraformer_streaming/__init__.py
similarity index 100%
rename from funasr/models/paraformer_online/__init__.py
rename to funasr/models/paraformer_streaming/__init__.py
diff --git a/funasr/models/paraformer_online/model.py b/funasr/models/paraformer_streaming/model.py
similarity index 99%
rename from funasr/models/paraformer_online/model.py
rename to funasr/models/paraformer_streaming/model.py
index 27871bc..bb24469 100644
--- a/funasr/models/paraformer_online/model.py
+++ b/funasr/models/paraformer_streaming/model.py
@@ -857,7 +857,7 @@
return results, meta_data
-class ParaformerOnline(Paraformer):
+class ParaformerStreaming(Paraformer):
"""
Author: Speech Lab of DAMO Academy, Alibaba Group
Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
diff --git a/funasr/models/paraformer_online/sanm_decoder.py b/funasr/models/paraformer_streaming/sanm_decoder.py
similarity index 100%
rename from funasr/models/paraformer_online/sanm_decoder.py
rename to funasr/models/paraformer_streaming/sanm_decoder.py
diff --git a/funasr/utils/vad_utils.py b/funasr/utils/vad_utils.py
index 9135513..f84e2b9 100644
--- a/funasr/utils/vad_utils.py
+++ b/funasr/utils/vad_utils.py
@@ -15,4 +15,17 @@
feats_pad = pad_sequence(speech_list, batch_first=True, padding_value=0.0)
speech_lengths_pad = torch.Tensor(speech_lengths_list).int()
return feats_pad, speech_lengths_pad
-
+
+
+def slice_padding_audio_samples(speech, speech_lengths, vad_segments):
+ speech_list = []
+ speech_lengths_list = []
+ for i, segment in enumerate(vad_segments):
+ bed_idx = int(segment[0][0] * 16)
+ end_idx = min(int(segment[0][1] * 16), speech_lengths)
+ speech_i = speech[bed_idx: end_idx]
+ speech_lengths_i = end_idx - bed_idx
+ speech_list.append(speech_i)
+ speech_lengths_list.append(speech_lengths_i)
+
+ return speech_list, speech_lengths_list
\ No newline at end of file
--
Gitblit v1.9.1