From 06f937b53e88502e5d254fb6e80a5fb9ee3b25e9 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 13 三月 2023 18:32:38 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/models/e2e_tp.py | 175 ++++++++
egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py | 32 +
funasr/models/e2e_asr_paraformer.py | 4
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.py | 12
funasr/tasks/asr.py | 82 ++++
egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py | 32 +
funasr/bin/tp_inference.py | 59 --
funasr/models/frontend/wav_frontend.py | 282 +++++++++++++-
funasr/bin/asr_inference_paraformer.py | 9
funasr/bin/vad_inference_online.py | 344 +++++++++++++++++
funasr/models/e2e_vad.py | 40 +
funasr/utils/timestamp_tools.py | 56 +-
funasr/bin/build_trainer.py | 4
funasr/bin/asr_inference_paraformer_vad_punc.py | 10
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md | 25 +
15 files changed, 1,056 insertions(+), 110 deletions(-)
diff --git a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md
new file mode 100644
index 0000000..5488aaa
--- /dev/null
+++ b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/README.md
@@ -0,0 +1,25 @@
+# ModelScope Model
+
+## How to finetune and infer using a pretrained ModelScope Model
+
+### Inference
+
+You can use the finetuned model for inference directly.
+
+- Setting parameters in `infer.py`
+ - <strong>audio_in:</strong> # supports wav, url, bytes, and parsed audio formats.
+ - <strong>text_in:</strong> # supports text and text url.
+ - <strong>output_dir:</strong> # must be set if the input format is wav.scp.
+
+- Then you can run the pipeline to infer with:
+```python
+ python infer.py
+```
+
+
+Modify inference-related parameters in vad.yaml.
+
+- max_end_silence_time: The end-point silence duration used to decide that a sentence has ended; the valid range is 500ms~6000ms, and the default value is 800ms.
+- speech_noise_thres: The balance between speech and silence scores; the valid range is (-1,1).
+ - The closer the value is to -1, the more likely noise is to be judged as speech.
+ - The closer the value is to 1, the more likely speech is to be judged as noise.
diff --git a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.py b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.py
new file mode 100644
index 0000000..ff42e68
--- /dev/null
+++ b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/infer.py
@@ -0,0 +1,12 @@
+# Timestamp-prediction demo: align a known transcript to an example recording.
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+
+inference_pipline = pipeline(
+    task=Tasks.speech_timestamp,
+    model='damo/speech_timestamp_prediction-v1-16k-offline',
+    output_dir='./tmp')
+
+# text_in is the transcript of audio_in, written as space-separated characters.
+# The literal was previously mojibake (UTF-8 bytes decoded with the wrong codec).
+rec_result = inference_pipline(
+    audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
+    text_in='一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢')
+print(rec_result)
\ No newline at end of file
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py
new file mode 100644
index 0000000..bcf764b
--- /dev/null
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-16k-common/infer_online.py
@@ -0,0 +1,32 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import soundfile
+
+
+if __name__ == '__main__':
+    # Streaming VAD demo: feed the example recording to the pipeline chunk by
+    # chunk, carrying the model cache between calls through param_dict.
+    inference_pipline = pipeline(
+        task=Tasks.voice_activity_detection,
+        model="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
+        model_revision='v1.1.9',
+        output_dir=None,
+        batch_size=1,
+    )
+    speech, sample_rate = soundfile.read("./vad_example_16k.wav")
+    speech_length = speech.shape[0]
+
+    step = 160 * 10  # samples per chunk
+    param_dict = {'in_cache': dict()}
+    sample_offset = 0
+    # An explicit while-loop keeps the offset in sync with what was actually
+    # consumed: every sample is sent exactly once, is_final is raised exactly
+    # once, and an empty file simply produces no calls.
+    while sample_offset < speech_length:
+        chunk_end = sample_offset + step
+        # Fold a trailing single sample into the last chunk instead of sending
+        # it on its own.
+        is_final = chunk_end >= speech_length - 1
+        if is_final:
+            chunk_end = speech_length
+        param_dict['is_final'] = is_final
+        segments_result = inference_pipline(audio_in=speech[sample_offset:chunk_end],
+                                            param_dict=param_dict)
+        print(segments_result)
+        sample_offset = chunk_end
+
diff --git a/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py
new file mode 100644
index 0000000..9d12b34
--- /dev/null
+++ b/egs_modelscope/vad/speech_fsmn_vad_zh-cn-8k-common/infer_online.py
@@ -0,0 +1,32 @@
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import soundfile
+
+
+if __name__ == '__main__':
+    # Streaming VAD demo: feed the example recording to the pipeline chunk by
+    # chunk, carrying the model cache between calls through param_dict.
+    inference_pipline = pipeline(
+        task=Tasks.voice_activity_detection,
+        model="damo/speech_fsmn_vad_zh-cn-8k-common",
+        model_revision='v1.1.9',
+        output_dir='./output_dir',
+        batch_size=1,
+    )
+    speech, sample_rate = soundfile.read("./vad_example_8k.wav")
+    speech_length = speech.shape[0]
+
+    step = 80 * 10  # samples per chunk
+    param_dict = {'in_cache': dict()}
+    sample_offset = 0
+    # An explicit while-loop keeps the offset in sync with what was actually
+    # consumed: every sample is sent exactly once, is_final is raised exactly
+    # once, and an empty file simply produces no calls.
+    while sample_offset < speech_length:
+        chunk_end = sample_offset + step
+        # Fold a trailing single sample into the last chunk instead of sending
+        # it on its own.
+        is_final = chunk_end >= speech_length - 1
+        if is_final:
+            chunk_end = speech_length
+        param_dict['is_final'] = is_final
+        segments_result = inference_pipline(audio_in=speech[sample_offset:chunk_end],
+                                            param_dict=param_dict)
+        print(segments_result)
+        sample_offset = chunk_end
+
diff --git a/funasr/bin/asr_inference_paraformer.py b/funasr/bin/asr_inference_paraformer.py
index 8265fc5..6413d92 100644
--- a/funasr/bin/asr_inference_paraformer.py
+++ b/funasr/bin/asr_inference_paraformer.py
@@ -42,7 +42,7 @@
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
-from funasr.utils.timestamp_tools import time_stamp_lfr6_pl, time_stamp_sentence
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
class Speech2Text:
@@ -245,7 +245,7 @@
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
if isinstance(self.asr_model, BiCifParaformer):
- _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
+ _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
pre_token_length) # test no bias cif2
results = []
@@ -291,7 +291,10 @@
text = None
if isinstance(self.asr_model, BiCifParaformer):
- timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
+ _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
+ us_peaks[i],
+ copy.copy(token),
+ vad_offset=begin_time)
results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor))
else:
results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor))
diff --git a/funasr/bin/asr_inference_paraformer_vad_punc.py b/funasr/bin/asr_inference_paraformer_vad_punc.py
index 1320877..a0e7b47 100644
--- a/funasr/bin/asr_inference_paraformer_vad_punc.py
+++ b/funasr/bin/asr_inference_paraformer_vad_punc.py
@@ -44,11 +44,10 @@
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.tasks.vad import VADTask
from funasr.bin.vad_inference import Speech2VadSegment
-from funasr.utils.timestamp_tools import time_stamp_lfr6_pl
+from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
from funasr.bin.punctuation_infer import Text2Punc
from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer
-from funasr.utils.timestamp_tools import time_stamp_sentence
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -257,7 +256,7 @@
decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1]
if isinstance(self.asr_model, BiCifParaformer):
- _, _, us_alphas, us_cif_peak = self.asr_model.calc_predictor_timestamp(enc, enc_len,
+ _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
pre_token_length) # test no bias cif2
results = []
@@ -303,7 +302,10 @@
text = None
if isinstance(self.asr_model, BiCifParaformer):
- timestamp = time_stamp_lfr6_pl(us_alphas[i], us_cif_peak[i], copy.copy(token), begin_time, end_time)
+ _, timestamp = ts_prediction_lfr6_standard(us_alphas[i],
+ us_peaks[i],
+ copy.copy(token),
+ vad_offset=begin_time)
results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor))
else:
results.append((text, token, token_int, enc_len_batch_total, lfr_factor))
diff --git a/funasr/bin/build_trainer.py b/funasr/bin/build_trainer.py
index 8dee758..94f7262 100644
--- a/funasr/bin/build_trainer.py
+++ b/funasr/bin/build_trainer.py
@@ -28,7 +28,9 @@
elif mode == "uniasr":
from funasr.tasks.asr import ASRTaskUniASR as ASRTask
elif mode == "mfcca":
- from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
+ from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
+ elif mode == "tp":
+ from funasr.tasks.asr import ASRTaskAligner as ASRTask
else:
raise ValueError("Unknown mode: {}".format(mode))
parser = ASRTask.get_parser()
diff --git a/funasr/bin/tp_inference.py b/funasr/bin/tp_inference.py
index e7a1f1b..e374a22 100644
--- a/funasr/bin/tp_inference.py
+++ b/funasr/bin/tp_inference.py
@@ -28,6 +28,8 @@
from funasr.utils.types import str_or_none
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.text.token_id_converter import TokenIDConverter
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
+
header_colors = '\033[95m'
end_colors = '\033[0m'
@@ -37,61 +39,6 @@
'audio_fs': 16000,
'model_fs': 16000
}
-
-def time_stamp_lfr6_advance(us_alphas, us_cif_peak, char_list):
- START_END_THRESHOLD = 5
- MAX_TOKEN_DURATION = 12
- TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
- if len(us_cif_peak.shape) == 2:
- alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only
- else:
- alphas, cif_peak = us_alphas, us_cif_peak
- num_frames = cif_peak.shape[0]
- if char_list[-1] == '</s>':
- char_list = char_list[:-1]
- # char_list = [i for i in text]
- timestamp_list = []
- new_char_list = []
- # for bicif model trained with large data, cif2 actually fires when a character starts
- # so treat the frames between two peaks as the duration of the former token
- fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 3.2 # total offset
- num_peak = len(fire_place)
- assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
- # begin silence
- if fire_place[0] > START_END_THRESHOLD:
- # char_list.insert(0, '<sil>')
- timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
- new_char_list.append('<sil>')
- # tokens timestamp
- for i in range(len(fire_place)-1):
- new_char_list.append(char_list[i])
- if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] < MAX_TOKEN_DURATION:
- timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE])
- else:
- # cut the duration to token and sil of the 0-weight frames last long
- _split = fire_place[i] + MAX_TOKEN_DURATION
- timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE])
- timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE])
- new_char_list.append('<sil>')
- # tail token and end silence
- # new_char_list.append(char_list[-1])
- if num_frames - fire_place[-1] > START_END_THRESHOLD:
- _end = (num_frames + fire_place[-1]) * 0.5
- # _end = fire_place[-1]
- timestamp_list[-1][1] = _end*TIME_RATE
- timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
- new_char_list.append("<sil>")
- else:
- timestamp_list[-1][1] = num_frames*TIME_RATE
- assert len(new_char_list) == len(timestamp_list)
- res_str = ""
- for char, timestamp in zip(new_char_list, timestamp_list):
- res_str += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
- res = []
- for char, timestamp in zip(new_char_list, timestamp_list):
- if char != '<sil>':
- res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
- return res_str, res
class SpeechText2Timestamp:
@@ -315,7 +262,7 @@
for batch_id in range(_bs):
key = keys[batch_id]
token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id])
- ts_str, ts_list = time_stamp_lfr6_advance(us_alphas[batch_id], us_cif_peak[batch_id], token)
+ ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], us_cif_peak[batch_id], token, force_time_shift=-3.0)
logging.warning(ts_str)
item = {'key': key, 'value': ts_str, 'timestamp':ts_list}
tp_result_list.append(item)
diff --git a/funasr/bin/vad_inference_online.py b/funasr/bin/vad_inference_online.py
new file mode 100644
index 0000000..cee1929
--- /dev/null
+++ b/funasr/bin/vad_inference_online.py
@@ -0,0 +1,344 @@
+import argparse
+import logging
+import sys
+import json
+from pathlib import Path
+from typing import Any
+from typing import List
+from typing import Optional
+from typing import Sequence
+from typing import Tuple
+from typing import Union
+from typing import Dict
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.fileio.datadir_writer import DatadirWriter
+from funasr.tasks.vad import VADTask
+from funasr.torch_utils.device_funcs import to_device
+from funasr.torch_utils.set_all_random_seed import set_all_random_seed
+from funasr.utils import config_argparse
+from funasr.utils.cli_utils import get_commandline_args
+from funasr.utils.types import str2bool
+from funasr.utils.types import str2triple_str
+from funasr.utils.types import str_or_none
+from funasr.models.frontend.wav_frontend import WavFrontendOnline
+from funasr.models.frontend.wav_frontend import WavFrontend
+from funasr.bin.vad_inference import Speech2VadSegment
+
+# ANSI escape sequences (bright magenta / reset); not referenced in the visible part of this module.
+header_colors = '\033[95m'
+end_colors = '\033[0m'
+
+# Module-level defaults; nothing in the visible part of this module reads them.
+global_asr_language: str = 'zh-cn'
+global_sample_rate: Union[int, Dict[Any, int]] = {
+ 'audio_fs': 16000,
+ 'model_fs': 16000
+}
+
+
+class Speech2VadSegmentOnline(Speech2VadSegment):
+    """Streaming variant of ``Speech2VadSegment`` that consumes audio chunk by chunk.
+
+    The constructor forwards all keyword arguments to ``Speech2VadSegment`` and
+    then replaces ``self.frontend`` with a ``WavFrontendOnline`` built from the
+    loaded frontend configuration (``None`` when no frontend is configured).
+
+    Examples:
+        >>> import soundfile
+        >>> speech2segment = Speech2VadSegmentOnline(vad_infer_config="vad_config.yml",
+        ...                                          vad_model_file="vad.pt")
+        >>> audio, rate = soundfile.read("speech.wav")
+        >>> # the first axis of ``speech`` is read as the batch size
+        >>> fbanks, segments, in_cache = speech2segment(audio[None, :], is_final=True)
+
+    """
+
+    def __init__(self, **kwargs):
+        super(Speech2VadSegmentOnline, self).__init__(**kwargs)
+        vad_cmvn_file = kwargs.get('vad_cmvn_file', None)
+        self.frontend = None
+        if self.vad_infer_args.frontend is not None:
+            self.frontend = WavFrontendOnline(cmvn_file=vad_cmvn_file, **self.vad_infer_args.frontend_conf)
+
+    @torch.no_grad()
+    def __call__(
+            self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None,
+            in_cache: Optional[Dict[str, torch.Tensor]] = None, is_final: bool = False
+    ) -> Tuple[torch.Tensor, List[List[List[int]]], Dict[str, torch.Tensor]]:
+        """Run VAD on one chunk of audio.
+
+        Args:
+            speech: Input speech data; its first axis is used as the batch size.
+            speech_lengths: Passed through to the frontend together with ``speech``.
+            in_cache: Model cache carried over from the previous chunk. A fresh
+                dict is created when omitted, so separate calls never share state
+                through the default argument.
+            is_final: Whether this is the last chunk of the stream.
+        Returns:
+            ``(fbanks, segments, in_cache)``: the frontend's fbank features, the
+            segments reported by the VAD model for this chunk (one empty list per
+            batch item when the frontend produced no features), and the cache to
+            pass to the next call.
+        Raises:
+            Exception: If no frontend is configured.
+
+        """
+        assert check_argument_types()
+
+        if in_cache is None:
+            in_cache = dict()
+
+        # Input as audio signal
+        if isinstance(speech, np.ndarray):
+            speech = torch.tensor(speech)
+        batch_size = speech.shape[0]
+        # Independent lists per batch item ([[]] * n would alias a single list).
+        segments = [[] for _ in range(batch_size)]
+        if self.frontend is None:
+            raise Exception("Need to extract feats first, please configure frontend configuration")
+        feats, feats_len = self.frontend.forward(speech, speech_lengths, is_final)
+        fbanks, _ = self.frontend.get_fbank()
+        if feats.shape[0]:
+            feats = to_device(feats, device=self.device)
+            feats_len = feats_len.int()
+            waveforms = self.frontend.get_waveforms()
+
+            batch = {
+                "feats": feats,
+                "waveform": waveforms,
+                "in_cache": in_cache,
+                "is_final": is_final
+            }
+            # a. To device
+            batch = to_device(batch, device=self.device)
+            segments, in_cache = self.vad_model(**batch)
+        return fbanks, segments, in_cache
+
+
+def inference(
+        batch_size: int,
+        ngpu: int,
+        log_level: Union[int, str],
+        data_path_and_name_and_type,
+        vad_infer_config: Optional[str],
+        vad_model_file: Optional[str],
+        vad_cmvn_file: Optional[str] = None,
+        raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+        key_file: Optional[str] = None,
+        allow_variable_data_keys: bool = False,
+        output_dir: Optional[str] = None,
+        dtype: str = "float32",
+        seed: int = 0,
+        num_workers: int = 1,
+        **kwargs,
+):
+    """Build the online VAD pipeline and run it once on the given input.
+
+    Every argument except ``data_path_and_name_and_type`` and ``raw_inputs`` is
+    handed to ``inference_modelscope``; those two are passed to the callable it
+    returns, whose result is returned unchanged.
+    """
+    build_options = dict(
+        batch_size=batch_size,
+        ngpu=ngpu,
+        log_level=log_level,
+        vad_infer_config=vad_infer_config,
+        vad_model_file=vad_model_file,
+        vad_cmvn_file=vad_cmvn_file,
+        key_file=key_file,
+        allow_variable_data_keys=allow_variable_data_keys,
+        output_dir=output_dir,
+        dtype=dtype,
+        seed=seed,
+        num_workers=num_workers,
+    )
+    # Extra keyword arguments can never collide with the named ones above,
+    # because named parameters capture them first.
+    build_options.update(kwargs)
+    vad_pipeline = inference_modelscope(**build_options)
+    return vad_pipeline(data_path_and_name_and_type, raw_inputs)
+
+
+def inference_modelscope(
+        batch_size: int,
+        ngpu: int,
+        log_level: Union[int, str],
+        vad_infer_config: Optional[str],
+        vad_model_file: Optional[str],
+        vad_cmvn_file: Optional[str] = None,
+        key_file: Optional[str] = None,
+        allow_variable_data_keys: bool = False,
+        output_dir: Optional[str] = None,
+        dtype: str = "float32",
+        seed: int = 0,
+        num_workers: int = 1,
+        **kwargs,
+):
+    """Build a streaming VAD model once and return a callable that runs it.
+
+    Args:
+        batch_size: Must be 1; larger values are rejected.
+        ngpu: 0 for CPU, 1 to use CUDA when available; larger values are rejected.
+        log_level: Level passed to ``logging.basicConfig``.
+        vad_infer_config, vad_model_file, vad_cmvn_file: Forwarded to
+            ``Speech2VadSegmentOnline``.
+        key_file, allow_variable_data_keys, num_workers, dtype: Forwarded to
+            ``VADTask.build_streaming_iterator``.
+        output_dir: Default directory for written results (may be overridden
+            per call through ``output_dir_v2``).
+        seed: Random seed.
+        **kwargs: Accepted and ignored.
+    Returns:
+        ``_forward(data_path_and_name_and_type, raw_inputs=None,
+        output_dir_v2=None, fs=None, param_dict=None)``, which returns a list
+        of ``{'key': ..., 'value': <json string>}`` items.
+    Raises:
+        NotImplementedError: If ``batch_size > 1`` or ``ngpu > 1``.
+    """
+    assert check_argument_types()
+    if batch_size > 1:
+        raise NotImplementedError("batch decoding is not implemented")
+    if ngpu > 1:
+        raise NotImplementedError("only single GPU decoding is supported")
+
+    logging.basicConfig(
+        level=log_level,
+        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
+    )
+
+    if ngpu >= 1 and torch.cuda.is_available():
+        device = "cuda"
+    else:
+        device = "cpu"
+
+    # 1. Set random-seed
+    set_all_random_seed(seed)
+
+    # 2. Build speech2vadsegment
+    speech2vadsegment_kwargs = dict(
+        vad_infer_config=vad_infer_config,
+        vad_model_file=vad_model_file,
+        vad_cmvn_file=vad_cmvn_file,
+        device=device,
+        dtype=dtype,
+    )
+    logging.info("speech2vadsegment_kwargs: {}".format(speech2vadsegment_kwargs))
+    speech2vadsegment = Speech2VadSegmentOnline(**speech2vadsegment_kwargs)
+
+    def _forward(
+            data_path_and_name_and_type,
+            raw_inputs: Union[np.ndarray, torch.Tensor] = None,
+            output_dir_v2: Optional[str] = None,
+            fs: dict = None,
+            param_dict: dict = None,
+    ):
+        """Run VAD over the given input; ``param_dict`` carries streaming state.
+
+        ``param_dict['in_cache']`` and ``param_dict['is_final']`` are read before
+        the loop, and ``param_dict['in_cache']`` is updated after every batch.
+        When ``param_dict`` is None an empty cache and ``is_final=False`` are used
+        and no state is written back.
+        """
+        # 3. Build data-iterator
+        if data_path_and_name_and_type is None and raw_inputs is not None:
+            if isinstance(raw_inputs, torch.Tensor):
+                raw_inputs = raw_inputs.numpy()
+            data_path_and_name_and_type = [raw_inputs, "speech", "waveform"]
+        loader = VADTask.build_streaming_iterator(
+            data_path_and_name_and_type,
+            dtype=dtype,
+            batch_size=batch_size,
+            key_file=key_file,
+            num_workers=num_workers,
+            preprocess_fn=VADTask.build_preprocess_fn(speech2vadsegment.vad_infer_args, False),
+            collate_fn=VADTask.build_collate_fn(speech2vadsegment.vad_infer_args, False),
+            allow_variable_data_keys=allow_variable_data_keys,
+            inference=True,
+        )
+
+        # 7 .Start for-loop
+        # FIXME(kamo): The output format should be discussed about
+        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
+        if output_path is not None:
+            writer = DatadirWriter(output_path)
+            ibest_writer = writer["1best_recog"]
+        else:
+            writer = None
+            ibest_writer = None
+
+        vad_results = []
+        if param_dict is None:
+            batch_in_cache, is_final = dict(), False
+        else:
+            batch_in_cache = param_dict.get('in_cache', dict())
+            is_final = param_dict.get('is_final', False)
+        for keys, batch in loader:
+            assert isinstance(batch, dict), type(batch)
+            assert all(isinstance(s, str) for s in keys), keys
+            _bs = len(next(iter(batch.values())))
+            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+            batch['in_cache'] = batch_in_cache
+            batch['is_final'] = is_final
+
+            # do vad segment
+            _, results, out_cache = speech2vadsegment(**batch)
+            # Only write state back when the caller supplied a dict to hold it;
+            # subscripting None here used to raise TypeError.
+            if param_dict is not None:
+                param_dict['in_cache'] = out_cache
+            if results:
+                for i, _ in enumerate(keys):
+                    results[i] = json.dumps(results[i])
+                    item = {'key': keys[i], 'value': results[i]}
+                    vad_results.append(item)
+                    if writer is not None:
+                        results[i] = json.loads(results[i])
+                        ibest_writer["text"][keys[i]] = "{}".format(results[i])
+
+        return vad_results
+
+    return _forward
+
+
+def get_parser():
+    """Create the command-line parser for online VAD decoding."""
+    log_levels = ("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET")
+    parser = config_argparse.ArgumentParser(
+        description="VAD Decoding",
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
+    )
+
+    # Note(kamo): Use '_' instead of '-' as separator.
+    # '-' is confusing if written in yaml.
+    parser.add_argument("--log_level", type=lambda x: x.upper(), default="INFO",
+                        choices=log_levels, help="The verbose level of logging")
+    parser.add_argument("--output_dir", type=str, required=False)
+    parser.add_argument("--ngpu", type=int, default=0,
+                        help="The number of gpus. 0 indicates CPU mode")
+    parser.add_argument("--gpuid_list", type=str, default="", help="The visible gpus")
+    parser.add_argument("--seed", type=int, default=0, help="Random seed")
+    parser.add_argument("--dtype", default="float32",
+                        choices=["float16", "float32", "float64"], help="Data type")
+    parser.add_argument("--num_workers", type=int, default=1,
+                        help="The number of workers used for DataLoader")
+
+    # Options describing where the input comes from.
+    input_group = parser.add_argument_group("Input data related")
+    input_group.add_argument("--data_path_and_name_and_type", type=str2triple_str,
+                             required=False, action="append")
+    input_group.add_argument("--raw_inputs", type=list, default=None)
+    input_group.add_argument("--key_file", type=str_or_none)
+    input_group.add_argument("--allow_variable_data_keys", type=str2bool, default=False)
+
+    # Options pointing at the model artefacts.
+    model_group = parser.add_argument_group("The model configuration related")
+    model_group.add_argument("--vad_infer_config", type=str, help="VAD infer configuration")
+    model_group.add_argument("--vad_model_file", type=str, help="VAD model parameter file")
+    model_group.add_argument("--vad_cmvn_file", type=str, help="Global cmvn file")
+
+    # Decoding options.
+    infer_group = parser.add_argument_group("infer related")
+    infer_group.add_argument("--batch_size", type=int, default=1,
+                             help="The batch size for inference")
+
+    return parser
+
+
+def main(cmd=None):
+    """Parse ``cmd`` (or ``sys.argv`` when None) and run ``inference``."""
+    print(get_commandline_args(), file=sys.stderr)
+    options = vars(get_parser().parse_args(cmd))
+    # "config" is not an argument of inference(); drop it if present.
+    options.pop("config", None)
+    inference(**options)
+
+
+if __name__ == "__main__":
+ main()
diff --git a/funasr/models/e2e_asr_paraformer.py b/funasr/models/e2e_asr_paraformer.py
index 8439f40..44c9de3 100644
--- a/funasr/models/e2e_asr_paraformer.py
+++ b/funasr/models/e2e_asr_paraformer.py
@@ -926,10 +926,10 @@
def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
encoder_out.device)
- ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,
+ ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
encoder_out_mask,
token_num)
- return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
+ return ds_alphas, ds_cif_peak, us_alphas, us_peaks
def forward(
self,
diff --git a/funasr/models/e2e_tp.py b/funasr/models/e2e_tp.py
new file mode 100644
index 0000000..887439c
--- /dev/null
+++ b/funasr/models/e2e_tp.py
@@ -0,0 +1,175 @@
+import logging
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+from typing import Union
+
+import torch
+import numpy as np
+from typeguard import check_argument_types
+
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.predictor.cif import mae_loss
+from funasr.modules.add_sos_eos import add_sos_eos
+from funasr.modules.nets_utils import make_pad_mask, pad_list
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.predictor.cif import CifPredictorV3
+
+
+# Detect autocast by importing it rather than by parsing torch.__version__:
+# this needs no version-string comparison and does not depend on distutils.
+try:
+    from torch.cuda.amp import autocast
+except ImportError:
+    # Nothing to do if torch<1.6.0: provide a no-op context manager instead.
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+class TimestampPredictor(AbsESPnetModel):
+    """Frontend + encoder + CIF predictor model trained with a token-count loss.
+
+    Author: Speech Lab, Alibaba Group, China
+    """
+
+    def __init__(
+            self,
+            frontend: Optional[AbsFrontend],
+            encoder: AbsEncoder,
+            predictor: CifPredictorV3,
+            predictor_bias: int = 0,
+            token_list=None,
+    ):
+        assert check_argument_types()
+
+        super().__init__()
+
+        self.frontend = frontend
+        self.encoder = encoder
+        self.encoder.interctc_use_conditioning = False
+
+        self.predictor = predictor
+        self.predictor_bias = predictor_bias
+        self.criterion_pre = mae_loss()
+        self.token_list = token_list
+
+    def forward(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Predictor + Calc loss
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+            text: (Batch, Length)
+            text_lengths: (Batch,)
+        Returns:
+            ``(loss, stats, weight)`` as produced by ``force_gatherable``; the
+            loss is ``criterion_pre`` applied to the target token counts and the
+            predictor's token-count output.
+        """
+        assert text_lengths.dim() == 1, text_lengths.shape
+        # Check that batch_size is unified
+        assert (
+                speech.shape[0]
+                == speech_lengths.shape[0]
+                == text.shape[0]
+                == text_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape)
+        batch_size = speech.shape[0]
+        # for data-parallel
+        text = text[:, : text_lengths.max()]
+        speech = speech[:, :speech_lengths.max()]
+
+        # 1. Encoder
+        encoder_out, encoder_out_lens = self.encode(speech, speech_lengths)
+
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        if self.predictor_bias == 1:
+            # One extra token is added to every target, so lengths grow by the bias.
+            _, text = add_sos_eos(text, 1, 2, -1)
+            text_lengths = text_lengths + self.predictor_bias
+        # Only the fifth output of the predictor is used for the loss.
+        _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1)
+
+        loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2)
+
+        loss = loss_pre
+        stats = dict()
+
+        # Collect stats (loss_pre is always a tensor here, so no None check is needed)
+        stats["loss_pre"] = loss_pre.detach().cpu()
+        stats["loss"] = torch.clone(loss.detach())
+
+        # force_gatherable: to-device and to-tensor if scalar for DataParallel
+        loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device)
+        return loss, stats, weight
+
+    def encode(
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Frontend + Encoder.
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+        Returns:
+            ``(encoder_out, encoder_out_lens)``; the encoder's third output is discarded.
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+        # 4. Forward encoder
+        # feats: (Batch, Length, Dim)
+        # -> encoder_out: (Batch, Length2, Dim2)
+        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
+
+        return encoder_out, encoder_out_lens
+
+    def _extract_feats(
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """Trim ``speech`` to the longest length and run the frontend if there is one."""
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+        if self.frontend is not None:
+            # Frontend
+            #  e.g. STFT and Feature extract
+            #       data_loader may send time-domain signal in this case
+            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            # No frontend and no feature extract
+            feats, feats_lengths = speech, speech_lengths
+        return feats, feats_lengths
+
+    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
+        """Return the predictor's ``get_upsample_timestamp`` outputs for the given encoder output."""
+        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
+            encoder_out.device)
+        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
+                                                                                            encoder_out_mask,
+                                                                                            token_num)
+        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
+
+    def collect_feats(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+            text: torch.Tensor,
+            text_lengths: torch.Tensor,
+    ) -> Dict[str, torch.Tensor]:
+        """Return ``{"feats", "feats_lengths"}`` for statistics collection."""
+        # __init__ does not set extract_feats_in_collect_stats; fall back to True
+        # so this method does not fail with AttributeError when nothing else sets it.
+        # NOTE(review): confirm True is the intended default.
+        extract_feats = getattr(self, "extract_feats_in_collect_stats", True)
+        if extract_feats:
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        else:
+            # Generate dummy stats if extract_feats_in_collect_stats is False
+            logging.warning(
+                "Generating dummy stats for feats and feats_lengths, "
+                "because encoder_conf.extract_feats_in_collect_stats is "
+                f"{extract_feats}"
+            )
+            feats, feats_lengths = speech, speech_lengths
+        return {"feats": feats, "feats_lengths": feats_lengths}
diff --git a/funasr/models/e2e_vad.py b/funasr/models/e2e_vad.py
index b9be89a..2c5673c 100755
--- a/funasr/models/e2e_vad.py
+++ b/funasr/models/e2e_vad.py
@@ -215,6 +215,7 @@
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
+ self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
@@ -244,6 +245,7 @@
self.sil_pdf_ids = self.vad_opts.sil_pdf_ids
self.noise_average_decibel = -100.0
self.pre_end_silence_detected = False
+ self.next_seg = True
self.output_data_buf = []
self.output_data_buf_offset = 0
@@ -441,7 +443,7 @@
- 1)) / self.vad_opts.noise_frame_num_used_for_snr
return frame_state
-
+
def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
@@ -470,6 +472,42 @@
self.AllResetDetection()
return segments, in_cache
+    def forward_online(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(),
+                       is_final: bool = False
+                       ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]:
+        """Score one chunk and report segment boundaries as soon as they are known.
+
+        Each reported segment is ``[start_ms, end_ms]``. ``start_ms`` is -1 when
+        the start of that segment was already reported by an earlier call, and
+        ``end_ms`` is -1 when its end point has not been detected yet.
+
+        Returns:
+            ``(segments, in_cache)``; ``segments`` holds one list per batch item
+            that produced at least one segment.
+        """
+        # NOTE(review): the default in_cache dict is created once and shared by
+        # every call that omits the argument.
+        self.waveform = waveform  # compute decibel for each frame
+        self.ComputeDecibel()
+        self.ComputeScores(feats, in_cache)
+        if is_final:
+            self.DetectLastFrames()
+        else:
+            self.DetectCommonFrames()
+
+        segments = []
+        for _ in range(feats.shape[0]):  # only support batch_size = 1 now
+            segment_batch = []
+            # The range is fixed before the loop, so advancing the offset inside
+            # it only affects the next call.
+            for idx in range(self.output_data_buf_offset, len(self.output_data_buf)):
+                seg = self.output_data_buf[idx]
+                if not seg.contain_seg_start_point:
+                    continue
+                # A segment whose start was already reported is only emitted
+                # again once its end point is known.
+                if not (self.next_seg or seg.contain_seg_end_point):
+                    continue
+                start_ms = seg.start_ms if self.next_seg else -1
+                if seg.contain_seg_end_point:
+                    end_ms = seg.end_ms
+                    self.next_seg = True
+                    self.output_data_buf_offset += 1
+                else:
+                    end_ms = -1
+                    self.next_seg = False
+                segment_batch.append([start_ms, end_ms])
+            if segment_batch:
+                segments.append(segment_batch)
+        if is_final:
+            # reset class variables and clear the dict for the next query
+            self.AllResetDetection()
+        return segments, in_cache
+
def DetectCommonFrames(self) -> int:
if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected:
return 0
diff --git a/funasr/models/frontend/wav_frontend.py b/funasr/models/frontend/wav_frontend.py
index ed8cb36..445efca 100644
--- a/funasr/models/frontend/wav_frontend.py
+++ b/funasr/models/frontend/wav_frontend.py
@@ -1,6 +1,6 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
-
+from abc import ABC
from typing import Tuple
import numpy as np
@@ -33,9 +33,9 @@
means = np.array(means_list).astype(np.float)
vars = np.array(vars_list).astype(np.float)
cmvn = np.array([means, vars])
- cmvn = torch.as_tensor(cmvn)
- return cmvn
-
+ cmvn = torch.as_tensor(cmvn)
+ return cmvn
+
def apply_cmvn(inputs, cmvn_file): # noqa
"""
@@ -78,21 +78,22 @@
class WavFrontend(AbsFrontend):
"""Conventional frontend structure for ASR.
"""
+
def __init__(
- self,
- cmvn_file: str = None,
- fs: int = 16000,
- window: str = 'hamming',
- n_mels: int = 80,
- frame_length: int = 25,
- frame_shift: int = 10,
- filter_length_min: int = -1,
- filter_length_max: int = -1,
- lfr_m: int = 1,
- lfr_n: int = 1,
- dither: float = 1.0,
- snip_edges: bool = True,
- upsacle_samples: bool = True,
+ self,
+ cmvn_file: str = None,
+ fs: int = 16000,
+ window: str = 'hamming',
+ n_mels: int = 80,
+ frame_length: int = 25,
+ frame_shift: int = 10,
+ filter_length_min: int = -1,
+ filter_length_max: int = -1,
+ lfr_m: int = 1,
+ lfr_n: int = 1,
+ dither: float = 1.0,
+ snip_edges: bool = True,
+ upsacle_samples: bool = True,
):
assert check_argument_types()
super().__init__()
@@ -135,11 +136,11 @@
window_type=self.window,
sample_frequency=self.fs,
snip_edges=self.snip_edges)
-
+
if self.lfr_m != 1 or self.lfr_n != 1:
mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
if self.cmvn_file is not None:
- mat = apply_cmvn(mat, self.cmvn_file)
+ mat = apply_cmvn(mat, self.cmvn_file)
feat_length = mat.size(0)
feats.append(mat)
feats_lens.append(feat_length)
@@ -170,7 +171,6 @@
energy_floor=0.0,
window_type=self.window,
sample_frequency=self.fs)
-
feat_length = mat.size(0)
feats.append(mat)
@@ -204,3 +204,243 @@
batch_first=True,
padding_value=0.0)
return feats_pad, feats_lens
+
+
+class WavFrontendOnline(AbsFrontend):
+ """Conventional frontend structure for streaming ASR/VAD.
+ """
+
+ def __init__(
+ self,
+ cmvn_file: str = None,
+ fs: int = 16000,
+ window: str = 'hamming',
+ n_mels: int = 80,
+ frame_length: int = 25,
+ frame_shift: int = 10,
+ filter_length_min: int = -1,
+ filter_length_max: int = -1,
+ lfr_m: int = 1,
+ lfr_n: int = 1,
+ dither: float = 1.0,
+ snip_edges: bool = True,
+ upsacle_samples: bool = True,
+ ):
+ assert check_argument_types()
+ super().__init__()
+ self.fs = fs
+ self.window = window
+ self.n_mels = n_mels
+ self.frame_length = frame_length
+ self.frame_shift = frame_shift
+ self.frame_sample_length = int(self.frame_length * self.fs / 1000)
+ self.frame_shift_sample_length = int(self.frame_shift * self.fs / 1000)
+ self.filter_length_min = filter_length_min
+ self.filter_length_max = filter_length_max
+ self.lfr_m = lfr_m
+ self.lfr_n = lfr_n
+ self.cmvn_file = cmvn_file
+ self.dither = dither
+ self.snip_edges = snip_edges
+ self.upsacle_samples = upsacle_samples
+ self.waveforms = None
+ self.reserve_waveforms = None
+ self.fbanks = None
+ self.fbanks_lens = None
+ self.cmvn = None if self.cmvn_file is None else load_cmvn(self.cmvn_file)
+ self.input_cache = None
+ self.lfr_splice_cache = []
+
+ def output_size(self) -> int:
+ return self.n_mels * self.lfr_m
+
+ @staticmethod
+ def apply_cmvn(inputs: torch.Tensor, cmvn: torch.Tensor) -> torch.Tensor:
+ """
+ Apply CMVN with mvn data
+ """
+
+ device = inputs.device
+ dtype = inputs.dtype
+ frame, dim = inputs.shape
+
+ means = np.tile(cmvn[0:1, :dim], (frame, 1))
+ vars = np.tile(cmvn[1:2, :dim], (frame, 1))
+ inputs += torch.from_numpy(means).type(dtype).to(device)
+ inputs *= torch.from_numpy(vars).type(dtype).to(device)
+
+ return inputs.type(torch.float32)
+
+ @staticmethod
+ # inputs tensor has catted the cache tensor
+ # def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, inputs_lfr_cache: torch.Tensor = None,
+ # is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ def apply_lfr(inputs: torch.Tensor, lfr_m: int, lfr_n: int, is_final: bool = False) -> Tuple[torch.Tensor, torch.Tensor, int]:
+ """
+ Apply lfr with data
+ """
+
+ LFR_inputs = []
+ # inputs = torch.vstack((inputs_lfr_cache, inputs))
+ T = inputs.shape[0] # include the right context
+ T_lfr = int(np.ceil((T - (lfr_m - 1) // 2) / lfr_n)) # minus the right context: (lfr_m - 1) // 2
+ splice_idx = T_lfr
+ for i in range(T_lfr):
+ if lfr_m <= T - i * lfr_n:
+ LFR_inputs.append((inputs[i * lfr_n:i * lfr_n + lfr_m]).view(1, -1))
+ else: # process last LFR frame
+ if is_final:
+ num_padding = lfr_m - (T - i * lfr_n)
+ frame = (inputs[i * lfr_n:]).view(-1)
+ for _ in range(num_padding):
+ frame = torch.hstack((frame, inputs[-1]))
+ LFR_inputs.append(frame)
+ else:
+ # update splice_idx and break out of the loop
+ splice_idx = i
+ break
+ splice_idx = min(T - 1, splice_idx * lfr_n)
+ lfr_splice_cache = inputs[splice_idx:, :]
+ LFR_outputs = torch.vstack(LFR_inputs)
+ return LFR_outputs.type(torch.float32), lfr_splice_cache, splice_idx
+
+ @staticmethod
+ def compute_frame_num(sample_length: int, frame_sample_length: int, frame_shift_sample_length: int) -> int:
+ frame_num = int((sample_length - frame_sample_length) / frame_shift_sample_length + 1)
+ return frame_num if frame_num >= 1 and sample_length >= frame_sample_length else 0
+
+ def forward_fbank(
+ self,
+ input: torch.Tensor,
+ input_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ batch_size = input.size(0)
+ if self.input_cache is None:
+ self.input_cache = torch.empty(0)
+ input = torch.cat((self.input_cache, input), dim=1)
+ frame_num = self.compute_frame_num(input.shape[-1], self.frame_sample_length, self.frame_shift_sample_length)
+ # update self.input_cache
+ self.input_cache = input[:, -(input.shape[-1] - frame_num * self.frame_shift_sample_length):]
+ waveforms = torch.empty(0)
+ feats_pad = torch.empty(0)
+ feats_lens = torch.empty(0)
+ if frame_num:
+ waveforms = []
+ feats = []
+ feats_lens = []
+ for i in range(batch_size):
+ waveform = input[i]
+ # keep exactly the wave samples that are used for fbank extraction
+ waveforms.append(
+ waveform[:((frame_num - 1) * self.frame_shift_sample_length + self.frame_sample_length)])
+ waveform = waveform * (1 << 15)
+ waveform = waveform.unsqueeze(0)
+ mat = kaldi.fbank(waveform,
+ num_mel_bins=self.n_mels,
+ frame_length=self.frame_length,
+ frame_shift=self.frame_shift,
+ dither=self.dither,
+ energy_floor=0.0,
+ window_type=self.window,
+ sample_frequency=self.fs)
+
+ feat_length = mat.size(0)
+ feats.append(mat)
+ feats_lens.append(feat_length)
+
+ waveforms = torch.stack(waveforms)
+ feats_lens = torch.as_tensor(feats_lens)
+ feats_pad = pad_sequence(feats,
+ batch_first=True,
+ padding_value=0.0)
+ self.fbanks = feats_pad
+ import copy
+ self.fbanks_lens = copy.deepcopy(feats_lens)
+ return waveforms, feats_pad, feats_lens
+
+ def get_fbank(self) -> Tuple[torch.Tensor, torch.Tensor]:
+ return self.fbanks, self.fbanks_lens
+
+ def forward_lfr_cmvn(
+ self,
+ input: torch.Tensor,
+ input_lengths: torch.Tensor,
+ is_final: bool = False
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+ batch_size = input.size(0)
+ feats = []
+ feats_lens = []
+ lfr_splice_frame_idxs = []
+ for i in range(batch_size):
+ mat = input[i, :input_lengths[i], :]
+ if self.lfr_m != 1 or self.lfr_n != 1:
+ # update self.lfr_splice_cache in self.apply_lfr
+ # mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, self.lfr_splice_cache[i],
+ mat, self.lfr_splice_cache[i], lfr_splice_frame_idx = self.apply_lfr(mat, self.lfr_m, self.lfr_n, is_final)
+ if self.cmvn_file is not None:
+ mat = self.apply_cmvn(mat, self.cmvn)
+ feat_length = mat.size(0)
+ feats.append(mat)
+ feats_lens.append(feat_length)
+ lfr_splice_frame_idxs.append(lfr_splice_frame_idx)
+
+ feats_lens = torch.as_tensor(feats_lens)
+ feats_pad = pad_sequence(feats,
+ batch_first=True,
+ padding_value=0.0)
+ lfr_splice_frame_idxs = torch.as_tensor(lfr_splice_frame_idxs)
+ return feats_pad, feats_lens, lfr_splice_frame_idxs
+
+ def forward(
+ self, input: torch.Tensor, input_lengths: torch.Tensor, is_final: bool = False
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ batch_size = input.shape[0]
+ assert batch_size == 1, 'we support to extract feature online only when the batch size is equal to 1 now'
+ waveforms, feats, feats_lengths = self.forward_fbank(input, input_lengths) # input shape: B T D
+ if feats.shape[0]:
+ #if self.reserve_waveforms is None and self.lfr_m > 1:
+ # self.reserve_waveforms = waveforms[:, :(self.lfr_m - 1) // 2 * self.frame_shift_sample_length]
+ self.waveforms = waveforms if self.reserve_waveforms is None else torch.cat((self.reserve_waveforms, waveforms), dim=1)
+ if not self.lfr_splice_cache: # initialize splice_cache
+ for i in range(batch_size):
+ self.lfr_splice_cache.append(feats[i][0, :].unsqueeze(dim=0).repeat((self.lfr_m - 1) // 2, 1))
+ # proceed only when the number of input frames + self.lfr_splice_cache[0].shape[0] >= self.lfr_m
+ if feats_lengths[0] + self.lfr_splice_cache[0].shape[0] >= self.lfr_m:
+ lfr_splice_cache_tensor = torch.stack(self.lfr_splice_cache) # B T D
+ feats = torch.cat((lfr_splice_cache_tensor, feats), dim=1)
+ feats_lengths += lfr_splice_cache_tensor[0].shape[0]
+ frame_from_waveforms = int((self.waveforms.shape[1] - self.frame_sample_length) / self.frame_shift_sample_length + 1)
+ minus_frame = (self.lfr_m - 1) // 2 if self.reserve_waveforms is None else 0
+ feats, feats_lengths, lfr_splice_frame_idxs = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
+ if self.lfr_m == 1:
+ self.reserve_waveforms = None
+ else:
+ reserve_frame_idx = lfr_splice_frame_idxs[0] - minus_frame
+ # print('reserve_frame_idx: ' + str(reserve_frame_idx))
+ # print('frame_frame: ' + str(frame_from_waveforms))
+ self.reserve_waveforms = self.waveforms[:, reserve_frame_idx * self.frame_shift_sample_length:frame_from_waveforms * self.frame_shift_sample_length]
+ sample_length = (frame_from_waveforms - 1) * self.frame_shift_sample_length + self.frame_sample_length
+ self.waveforms = self.waveforms[:, :sample_length]
+ else:
+ # update self.reserve_waveforms and self.lfr_splice_cache
+ self.reserve_waveforms = self.waveforms[:, :-(self.frame_sample_length - self.frame_shift_sample_length)]
+ for i in range(batch_size):
+ self.lfr_splice_cache[i] = torch.cat((self.lfr_splice_cache[i], feats[i]), dim=0)
+ return torch.empty(0), feats_lengths
+ else:
+ if is_final:
+ self.waveforms = waveforms if self.reserve_waveforms is None else self.reserve_waveforms
+ feats = torch.stack(self.lfr_splice_cache)
+ feats_lengths = torch.zeros(batch_size, dtype=torch.int) + feats.shape[1]
+ feats, feats_lengths, _ = self.forward_lfr_cmvn(feats, feats_lengths, is_final)
+ if is_final:
+ self.cache_reset()
+ return feats, feats_lengths
+
+ def get_waveforms(self):
+ return self.waveforms
+
+ def cache_reset(self):
+ self.reserve_waveforms = None
+ self.input_cache = None
+ self.lfr_splice_cache = []
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index bc89744..36499a2 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -40,6 +40,7 @@
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
from funasr.models.e2e_asr import ESPnetASRModel
from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer
+from funasr.models.e2e_tp import TimestampPredictor
from funasr.models.e2e_asr_mfcca import MFCCA
from funasr.models.e2e_uni_asr import UniASR
from funasr.models.encoder.abs_encoder import AbsEncoder
@@ -124,6 +125,7 @@
bicif_paraformer=BiCifParaformer,
contextual_paraformer=ContextualParaformer,
mfcca=MFCCA,
+ timestamp_prediction=TimestampPredictor,
),
type_check=AbsESPnetModel,
default="asr",
@@ -1245,9 +1247,87 @@
class ASRTaskAligner(ASRTaskParaformer):
+ # If you need more than one optimizers, change this value
+ num_optimizers: int = 1
+
+ # Add variable objects configurations
+ class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --model and --model_conf
+ model_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --decoder and --decoder_conf
+ decoder_choices,
+ ]
+
+ # If you need to modify train() or eval() procedures, change Trainer class here
+ trainer = Trainer
+
+ @classmethod
+ def build_model(cls, args: argparse.Namespace):
+ assert check_argument_types()
+ if isinstance(args.token_list, str):
+ with open(args.token_list, encoding="utf-8") as f:
+ token_list = [line.rstrip() for line in f]
+
+ # Overwriting token_list to keep it as "portable".
+ args.token_list = list(token_list)
+ elif isinstance(args.token_list, (tuple, list)):
+ token_list = list(args.token_list)
+ else:
+ raise RuntimeError("token_list must be str or list")
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ if args.frontend == 'wav_frontend':
+ frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
+ else:
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(input_size=input_size, **args.encoder_conf)
+
+ # 3. Predictor
+ predictor_class = predictor_choices.get_class(args.predictor)
+ predictor = predictor_class(**args.predictor_conf)
+
+ # 10. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:
+ model_class = model_choices.get_class("asr")
+
+ # 8. Build model
+ model = model_class(
+ frontend=frontend,
+ encoder=encoder,
+ predictor=predictor,
+ token_list=token_list,
+ **args.model_conf,
+ )
+
+ # 11. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ assert check_return_type(model)
+ return model
+
@classmethod
def required_data_names(
cls, train: bool = True, inference: bool = False
) -> Tuple[str, ...]:
retval = ("speech", "text")
- return retval
\ No newline at end of file
+ return retval
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 4a367f8..f5a238e 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -5,55 +5,69 @@
from typing import Any, List, Tuple, Union
-def time_stamp_lfr6_pl(us_alphas, us_cif_peak, char_list, begin_time=0.0, end_time=None):
+def ts_prediction_lfr6_standard(us_alphas,
+ us_peaks,
+ char_list,
+ vad_offset=0.0,
+ force_time_shift=-1.5
+ ):
if not len(char_list):
return []
START_END_THRESHOLD = 5
+ MAX_TOKEN_DURATION = 12
TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
- if len(us_alphas.shape) == 3:
- alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only
+ if len(us_alphas.shape) == 2:
+ _, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
else:
- alphas, cif_peak = us_alphas, us_cif_peak
- num_frames = cif_peak.shape[0]
+ _, peaks = us_alphas, us_peaks
+ num_frames = peaks.shape[0]
if char_list[-1] == '</s>':
char_list = char_list[:-1]
- # char_list = [i for i in text]
timestamp_list = []
+ new_char_list = []
# for bicif model trained with large data, cif2 actually fires when a character starts
# so treat the frames between two peaks as the duration of the former token
- fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 1.5
+ fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset
num_peak = len(fire_place)
assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1
# begin silence
if fire_place[0] > START_END_THRESHOLD:
- char_list.insert(0, '<sil>')
+ # char_list.insert(0, '<sil>')
timestamp_list.append([0.0, fire_place[0]*TIME_RATE])
+ new_char_list.append('<sil>')
# tokens timestamp
for i in range(len(fire_place)-1):
- # the peak is always a little ahead of the start time
- # timestamp_list.append([(fire_place[i]-1.2)*TIME_RATE, fire_place[i+1]*TIME_RATE])
- timestamp_list.append([(fire_place[i])*TIME_RATE, fire_place[i+1]*TIME_RATE])
- # cut the duration to token and sil of the 0-weight frames last long
+ new_char_list.append(char_list[i])
+ if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] <= MAX_TOKEN_DURATION:
+ timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE])
+ else:
+ # gap exceeds MAX_TOKEN_DURATION: split it into the token and a trailing <sil>
+ _split = fire_place[i] + MAX_TOKEN_DURATION
+ timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE])
+ timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE])
+ new_char_list.append('<sil>')
# tail token and end silence
+ # new_char_list.append(char_list[-1])
if num_frames - fire_place[-1] > START_END_THRESHOLD:
- _end = (num_frames + fire_place[-1]) / 2
+ _end = (num_frames + fire_place[-1]) * 0.5
+ # _end = fire_place[-1]
timestamp_list[-1][1] = _end*TIME_RATE
timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE])
- char_list.append("<sil>")
+ new_char_list.append("<sil>")
else:
timestamp_list[-1][1] = num_frames*TIME_RATE
- if begin_time: # add offset time in model with vad
+ if vad_offset: # add offset time in model with vad
for i in range(len(timestamp_list)):
- timestamp_list[i][0] = timestamp_list[i][0] + begin_time / 1000.0
- timestamp_list[i][1] = timestamp_list[i][1] + begin_time / 1000.0
+ timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0
+ timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0
res_txt = ""
- for char, timestamp in zip(char_list, timestamp_list):
- res_txt += "{} {} {};".format(char, timestamp[0], timestamp[1])
+ for char, timestamp in zip(new_char_list, timestamp_list):
+ res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5])
res = []
- for char, timestamp in zip(char_list, timestamp_list):
+ for char, timestamp in zip(new_char_list, timestamp_list):
if char != '<sil>':
res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)])
- return res
+ return res_txt, res
def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed):
--
Gitblit v1.9.1