Merge pull request #218 from alibaba-damo-academy/dev_ts
Update timestamp-related code and egs_modelscope
| New file |
| | |
| | | # ModelScope Model |
| | | |
| | | ## How to finetune and infer using a pretrained ModelScope Model |
| | | |
| | | ### Inference |
| | | |
| | | You can also use the finetuned model directly for inference.
| | | |
| | | - Set the parameters in `infer.py`:
| | | - <strong>audio_in:</strong> # the input audio; supports wav files, URLs, bytes, and parsed audio formats.
| | | - <strong>text_in:</strong> # the transcript to align; supports plain text and text URLs.
| | | - <strong>output_dir:</strong> # must be set when the input is a wav.scp list.
| | | |
| | | - Then run inference with:
| | | ```shell
| | | python infer.py
| | | ```
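| | | For reference, a minimal `infer.py` for timestamp prediction can mirror the demo script shipped with this PR (the model ID, test audio URL, and transcript below are taken from that demo):
| | | 
| | | ```python
| | | from modelscope.pipelines import pipeline
| | | from modelscope.utils.constant import Tasks
| | | 
| | | inference_pipeline = pipeline(
| | |     task=Tasks.speech_timestamp,
| | |     model='damo/speech_timestamp_prediction-v1-16k-offline',
| | |     output_dir='./tmp')
| | | 
| | | rec_result = inference_pipeline(
| | |     audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav',
| | |     text_in='一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢')
| | | print(rec_result)
| | | ```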
| | | |
| | | |
| | | Modify the inference-related parameters in `vad.yaml`; a small sketch for editing them follows this list.
| | |
| | | - max_end_silence_time: the trailing-silence duration used to decide that a sentence has ended; valid range 500 ms to 6000 ms, default 800 ms.
| | | - speech_noise_thres: the balance between speech and noise scores; valid range (-1, 1).
| | | - The closer the value is to -1, the more likely noise is judged as speech.
| | | - The closer the value is to 1, the more likely speech is judged as noise.
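| | | As a concrete illustration, the sketch below adjusts these two fields programmatically. It is a minimal sketch assuming the knobs live under a `vad_post_conf` section; only the two parameter names come from this README.
| | | 
| | | ```python
| | | # minimal sketch: adjust the two VAD post-processing knobs in vad.yaml
| | | # (the vad_post_conf section name is an assumption; only the parameter
| | | # names max_end_silence_time / speech_noise_thres come from this README)
| | | import yaml
| | | 
| | | with open("vad.yaml") as f:
| | |     conf = yaml.safe_load(f)
| | | 
| | | conf["vad_post_conf"]["max_end_silence_time"] = 500  # ms, valid range 500-6000
| | | conf["vad_post_conf"]["speech_noise_thres"] = 0.8    # (-1, 1); toward 1, more audio judged as noise
| | | 
| | | with open("vad.yaml", "w") as f:
| | |     yaml.safe_dump(conf, f)
| | | ```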
| New file |
| | |
| | | from modelscope.pipelines import pipeline |
| | | from modelscope.utils.constant import Tasks |
| | | |
| | | inference_pipeline = pipeline(
| | | task=Tasks.speech_timestamp, |
| | | model='damo/speech_timestamp_prediction-v1-16k-offline', |
| | | output_dir='./tmp') |
| | | |
| | | rec_result = inference_pipeline(
| | | audio_in='https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_timestamps.wav', |
| | | text_in='一 个 东 太 平 洋 国 家 为 什 么 跑 到 西 太 平 洋 来 了 呢') |
| | | print(rec_result) |
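| | | 
| | | The printed result should carry per-token timestamps. The shape sketched below is an assumption inferred from `ts_prediction_lfr6_standard` later in this PR, not captured output:
| | | 
| | | ```python
| | | # assumed result shape (illustrative, not a real run):
| | | # {'text': '<sil> 0.0 0.376;一 0.376 0.596;个 0.596 0.776;...',
| | | #  'timestamp': [[376, 596], [596, 776], ...]}  # [start_ms, end_ms] per token, <sil> excluded
| | | ```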
| | |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer |
| | | from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export |
| | | from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard
| | | |
| | | |
| | | class Speech2Text: |
| | |
| | | decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] |
| | | |
| | | if isinstance(self.asr_model, BiCifParaformer): |
| | | _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
| | | pre_token_length) # test no bias cif2 |
| | | |
| | | results = [] |
| | |
| | | text = None |
| | | |
| | | if isinstance(self.asr_model, BiCifParaformer): |
| | | _, timestamp = ts_prediction_lfr6_standard(us_alphas[i], |
| | | us_peaks[i], |
| | | copy.copy(token), |
| | | vad_offset=begin_time) |
| | | results.append((text, token, token_int, hyp, timestamp, enc_len_batch_total, lfr_factor)) |
| | | else: |
| | | results.append((text, token, token_int, hyp, enc_len_batch_total, lfr_factor)) |
| | |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.tasks.vad import VADTask |
| | | from funasr.bin.vad_inference import Speech2VadSegment |
| | | from funasr.utils.timestamp_tools import time_stamp_sentence, ts_prediction_lfr6_standard |
| | | from funasr.bin.punctuation_infer import Text2Punc |
| | | from funasr.models.e2e_asr_paraformer import BiCifParaformer, ContextualParaformer |
| | | |
| | | from funasr.utils.timestamp_tools import time_stamp_sentence |
| | | |
| | | header_colors = '\033[95m' |
| | | end_colors = '\033[0m' |
| | |
| | | decoder_out, ys_pad_lens = decoder_outs[0], decoder_outs[1] |
| | | |
| | | if isinstance(self.asr_model, BiCifParaformer): |
| | | _, _, us_alphas, us_peaks = self.asr_model.calc_predictor_timestamp(enc, enc_len,
| | | pre_token_length) # test no bias cif2 |
| | | |
| | | results = [] |
| | |
| | | text = None |
| | | |
| | | if isinstance(self.asr_model, BiCifParaformer): |
| | | _, timestamp = ts_prediction_lfr6_standard(us_alphas[i], |
| | | us_peaks[i], |
| | | copy.copy(token), |
| | | vad_offset=begin_time) |
| | | results.append((text, token, token_int, timestamp, enc_len_batch_total, lfr_factor)) |
| | | else: |
| | | results.append((text, token, token_int, enc_len_batch_total, lfr_factor)) |
| | |
| | | elif mode == "uniasr": |
| | | from funasr.tasks.asr import ASRTaskUniASR as ASRTask |
| | | elif mode == "mfcca": |
| | | from funasr.tasks.asr import ASRTaskMFCCA as ASRTask
| | | elif mode == "tp": |
| | | from funasr.tasks.asr import ASRTaskAligner as ASRTask |
| | | else: |
| | | raise ValueError("Unknown mode: {}".format(mode)) |
| | | parser = ASRTask.get_parser() |
| | |
| | | from funasr.utils.types import str_or_none |
| | | from funasr.models.frontend.wav_frontend import WavFrontend |
| | | from funasr.text.token_id_converter import TokenIDConverter |
| | | from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard |
| | | |
| | | |
| | | header_colors = '\033[95m' |
| | | end_colors = '\033[0m' |
| | |
| | | 'audio_fs': 16000, |
| | | 'model_fs': 16000 |
| | | } |
| | | |
| | | def time_stamp_lfr6_advance(us_alphas, us_cif_peak, char_list): |
| | | START_END_THRESHOLD = 5 |
| | | MAX_TOKEN_DURATION = 12 |
| | | TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled |
| | | if len(us_cif_peak.shape) == 2: |
| | | alphas, cif_peak = us_alphas[0], us_cif_peak[0] # support inference batch_size=1 only |
| | | else: |
| | | alphas, cif_peak = us_alphas, us_cif_peak |
| | | num_frames = cif_peak.shape[0] |
| | | if char_list[-1] == '</s>': |
| | | char_list = char_list[:-1] |
| | | # char_list = [i for i in text] |
| | | timestamp_list = [] |
| | | new_char_list = [] |
| | | # for bicif models trained on large-scale data, cif2 actually fires when a character starts,
| | | # so the frames between two peaks are treated as the duration of the former token
| | | fire_place = torch.where(cif_peak>1.0-1e-4)[0].cpu().numpy() - 3.2 # total offset |
| | | num_peak = len(fire_place) |
| | | assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1 |
| | | # begin silence |
| | | if fire_place[0] > START_END_THRESHOLD: |
| | | # char_list.insert(0, '<sil>') |
| | | timestamp_list.append([0.0, fire_place[0]*TIME_RATE]) |
| | | new_char_list.append('<sil>') |
| | | # tokens timestamp |
| | | for i in range(len(fire_place)-1): |
| | | new_char_list.append(char_list[i]) |
| | | if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] < MAX_TOKEN_DURATION: |
| | | timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE]) |
| | | else: |
| | | # if the 0-weight frames last too long, split the span into the token part and a trailing silence
| | | _split = fire_place[i] + MAX_TOKEN_DURATION |
| | | timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE]) |
| | | timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE]) |
| | | new_char_list.append('<sil>') |
| | | # tail token and end silence |
| | | # new_char_list.append(char_list[-1]) |
| | | if num_frames - fire_place[-1] > START_END_THRESHOLD: |
| | | _end = (num_frames + fire_place[-1]) * 0.5 |
| | | # _end = fire_place[-1] |
| | | timestamp_list[-1][1] = _end*TIME_RATE |
| | | timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE]) |
| | | new_char_list.append("<sil>") |
| | | else: |
| | | timestamp_list[-1][1] = num_frames*TIME_RATE |
| | | assert len(new_char_list) == len(timestamp_list) |
| | | res_str = "" |
| | | for char, timestamp in zip(new_char_list, timestamp_list): |
| | | res_str += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5]) |
| | | res = [] |
| | | for char, timestamp in zip(new_char_list, timestamp_list): |
| | | if char != '<sil>': |
| | | res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)]) |
| | | return res_str, res |
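| | | 
| | | A quick toy check of the peak-to-timestamp logic above, with synthetic values (this assumes `torch` is imported in the module, as the function itself requires):
| | | 
| | | ```python
| | | import torch
| | | 
| | | # four fires -> three tokens + 1, as the assert above requires
| | | us_peaks = torch.zeros(50)
| | | us_peaks[[10, 20, 30, 38]] = 1.0
| | | us_alphas = torch.zeros(50)  # the alphas are not used by the timing logic itself
| | | res_str, res = time_stamp_lfr6_advance(us_alphas, us_peaks, ['一', '个', '东'])
| | | print(res_str)  # "char start end;" string, <sil> spans included
| | | print(res)      # [[start_ms, end_ms], ...] for non-<sil> tokens only
| | | ```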
| | | |
| | | |
| | | class SpeechText2Timestamp: |
| | |
| | | for batch_id in range(_bs): |
| | | key = keys[batch_id] |
| | | token = speechtext2timestamp.converter.ids2tokens(batch['text'][batch_id]) |
| | | ts_str, ts_list = ts_prediction_lfr6_standard(us_alphas[batch_id], us_cif_peak[batch_id], token, force_time_shift=-3.0) |
| | | logging.warning(ts_str) |
| | | item = {'key': key, 'value': ts_str, 'timestamp':ts_list} |
| | | tp_result_list.append(item) |
| | |
| | | def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num): |
| | | encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( |
| | | encoder_out.device) |
| | | ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
| | | encoder_out_mask, |
| | | token_num) |
| | | return ds_alphas, ds_cif_peak, us_alphas, us_peaks
| | | |
| | | def forward( |
| | | self, |
| New file |
| | |
| | | import logging |
| | | from contextlib import contextmanager |
| | | from distutils.version import LooseVersion |
| | | from typing import Dict |
| | | from typing import List |
| | | from typing import Optional |
| | | from typing import Tuple |
| | | from typing import Union |
| | | |
| | | import torch |
| | | import numpy as np |
| | | from typeguard import check_argument_types |
| | | |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | | from funasr.models.frontend.abs_frontend import AbsFrontend |
| | | from funasr.models.predictor.cif import mae_loss |
| | | from funasr.modules.add_sos_eos import add_sos_eos |
| | | from funasr.modules.nets_utils import make_pad_mask, pad_list |
| | | from funasr.torch_utils.device_funcs import force_gatherable |
| | | from funasr.train.abs_espnet_model import AbsESPnetModel |
| | | from funasr.models.predictor.cif import CifPredictorV3 |
| | | |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | | else: |
| | | # Nothing to do if torch<1.6.0 |
| | | @contextmanager |
| | | def autocast(enabled=True): |
| | | yield |
| | | |
| | | |
| | | class TimestampPredictor(AbsESPnetModel): |
| | | """ |
| | | Author: Speech Lab, Alibaba Group, China |
| | | """ |
| | | |
| | | def __init__( |
| | | self, |
| | | frontend: Optional[AbsFrontend], |
| | | encoder: AbsEncoder, |
| | | predictor: CifPredictorV3, |
| | | predictor_bias: int = 0, |
| | | token_list=None, |
| | | ): |
| | | assert check_argument_types() |
| | | |
| | | super().__init__() |
| | | # note that eos is the same as sos (equivalent ID) |
| | | |
| | | self.frontend = frontend |
| | | self.encoder = encoder |
| | | self.encoder.interctc_use_conditioning = False |
| | | |
| | | self.predictor = predictor |
| | | self.predictor_bias = predictor_bias |
| | | self.criterion_pre = mae_loss() |
| | | self.token_list = token_list
| | | # collect_feats() below reads this flag; default to real feature extraction
| | | self.extract_feats_in_collect_stats = True
| | | |
| | | def forward( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]: |
| | | """Frontend + Encoder + Decoder + Calc loss |
| | | |
| | | Args: |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | text: (Batch, Length) |
| | | text_lengths: (Batch,) |
| | | """ |
| | | assert text_lengths.dim() == 1, text_lengths.shape |
| | | # Check that batch_size is unified |
| | | assert ( |
| | | speech.shape[0] |
| | | == speech_lengths.shape[0] |
| | | == text.shape[0] |
| | | == text_lengths.shape[0] |
| | | ), (speech.shape, speech_lengths.shape, text.shape, text_lengths.shape) |
| | | batch_size = speech.shape[0] |
| | | # for data-parallel |
| | | text = text[:, : text_lengths.max()] |
| | | speech = speech[:, :speech_lengths.max()] |
| | | |
| | | # 1. Encoder |
| | | encoder_out, encoder_out_lens = self.encode(speech, speech_lengths) |
| | | |
| | | encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( |
| | | encoder_out.device) |
| | | if self.predictor_bias == 1: |
| | | _, text = add_sos_eos(text, 1, 2, -1) |
| | | text_lengths = text_lengths + self.predictor_bias |
| | | _, _, _, _, pre_token_length2 = self.predictor(encoder_out, text, encoder_out_mask, ignore_id=-1) |
| | | |
| | | # loss_pre = self.criterion_pre(ys_pad_lens.type_as(pre_token_length), pre_token_length) |
| | | loss_pre = self.criterion_pre(text_lengths.type_as(pre_token_length2), pre_token_length2) |
| | | |
| | | loss = loss_pre |
| | | stats = dict() |
| | | |
| | | # Collect Attn branch stats |
| | | stats["loss_pre"] = loss_pre.detach().cpu() if loss_pre is not None else None |
| | | stats["loss"] = torch.clone(loss.detach()) |
| | | |
| | | # force_gatherable: to-device and to-tensor if scalar for DataParallel |
| | | loss, stats, weight = force_gatherable((loss, stats, batch_size), loss.device) |
| | | return loss, stats, weight |
| | | |
| | | def encode( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | """Frontend + Encoder. Note that this method is used by asr_inference.py |
| | | |
| | | Args: |
| | | speech: (Batch, Length, ...) |
| | | speech_lengths: (Batch, ) |
| | | """ |
| | | with autocast(False): |
| | | # 1. Extract feats |
| | | feats, feats_lengths = self._extract_feats(speech, speech_lengths) |
| | | |
| | | # 2. Forward encoder
| | | # feats: (Batch, Length, Dim) |
| | | # -> encoder_out: (Batch, Length2, Dim2) |
| | | encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths) |
| | | |
| | | return encoder_out, encoder_out_lens |
| | | |
| | | def _extract_feats( |
| | | self, speech: torch.Tensor, speech_lengths: torch.Tensor |
| | | ) -> Tuple[torch.Tensor, torch.Tensor]: |
| | | assert speech_lengths.dim() == 1, speech_lengths.shape |
| | | |
| | | # for data-parallel |
| | | speech = speech[:, : speech_lengths.max()] |
| | | if self.frontend is not None: |
| | | # Frontend |
| | | # e.g. STFT and Feature extract |
| | | # data_loader may send time-domain signal in this case |
| | | # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim) |
| | | feats, feats_lengths = self.frontend(speech, speech_lengths) |
| | | else: |
| | | # No frontend and no feature extract |
| | | feats, feats_lengths = speech, speech_lengths |
| | | return feats, feats_lengths |
| | | |
| | | def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num): |
| | | encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to( |
| | | encoder_out.device) |
| | | ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out, |
| | | encoder_out_mask, |
| | | token_num) |
| | | return ds_alphas, ds_cif_peak, us_alphas, us_peaks |
| | | |
| | | def collect_feats( |
| | | self, |
| | | speech: torch.Tensor, |
| | | speech_lengths: torch.Tensor, |
| | | text: torch.Tensor, |
| | | text_lengths: torch.Tensor, |
| | | ) -> Dict[str, torch.Tensor]: |
| | | if self.extract_feats_in_collect_stats: |
| | | feats, feats_lengths = self._extract_feats(speech, speech_lengths) |
| | | else: |
| | | # Generate dummy stats if extract_feats_in_collect_stats is False |
| | | logging.warning( |
| | | "Generating dummy stats for feats and feats_lengths, " |
| | | "because encoder_conf.extract_feats_in_collect_stats is " |
| | | f"{self.extract_feats_in_collect_stats}" |
| | | ) |
| | | feats, feats_lengths = speech, speech_lengths |
| | | return {"feats": feats, "feats_lengths": feats_lengths} |
| | |
| | | from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder |
| | | from funasr.models.e2e_asr import ESPnetASRModel |
| | | from funasr.models.e2e_asr_paraformer import Paraformer, ParaformerBert, BiCifParaformer, ContextualParaformer |
| | | from funasr.models.e2e_tp import TimestampPredictor |
| | | from funasr.models.e2e_asr_mfcca import MFCCA |
| | | from funasr.models.e2e_uni_asr import UniASR |
| | | from funasr.models.encoder.abs_encoder import AbsEncoder |
| | |
| | | bicif_paraformer=BiCifParaformer, |
| | | contextual_paraformer=ContextualParaformer, |
| | | mfcca=MFCCA, |
| | | timestamp_prediction=TimestampPredictor, |
| | | ), |
| | | type_check=AbsESPnetModel, |
| | | default="asr", |
| | |
| | | |
| | | |
| | | class ASRTaskAligner(ASRTaskParaformer): |
| | | # If you need more than one optimizer, change this value
| | | num_optimizers: int = 1 |
| | | |
| | | # Add variable objects configurations |
| | | class_choices_list = [ |
| | | # --frontend and --frontend_conf |
| | | frontend_choices, |
| | | # --model and --model_conf |
| | | model_choices, |
| | | # --encoder and --encoder_conf |
| | | encoder_choices, |
| | | # --decoder and --decoder_conf |
| | | decoder_choices, |
| | | ] |
| | | |
| | | # If you need to modify train() or eval() procedures, change Trainer class here |
| | | trainer = Trainer |
| | | |
| | | @classmethod |
| | | def build_model(cls, args: argparse.Namespace): |
| | | assert check_argument_types() |
| | | if isinstance(args.token_list, str): |
| | | with open(args.token_list, encoding="utf-8") as f: |
| | | token_list = [line.rstrip() for line in f] |
| | | |
| | | # Overwriting token_list to keep it as "portable". |
| | | args.token_list = list(token_list) |
| | | elif isinstance(args.token_list, (tuple, list)): |
| | | token_list = list(args.token_list) |
| | | else: |
| | | raise RuntimeError("token_list must be str or list") |
| | | |
| | | # 1. frontend |
| | | if args.input_size is None: |
| | | # Extract features in the model |
| | | frontend_class = frontend_choices.get_class(args.frontend) |
| | | if args.frontend == 'wav_frontend': |
| | | frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf) |
| | | else: |
| | | frontend = frontend_class(**args.frontend_conf) |
| | | input_size = frontend.output_size() |
| | | else: |
| | | # Give features from data-loader |
| | | args.frontend = None |
| | | args.frontend_conf = {} |
| | | frontend = None |
| | | input_size = args.input_size |
| | | |
| | | # 2. Encoder |
| | | encoder_class = encoder_choices.get_class(args.encoder) |
| | | encoder = encoder_class(input_size=input_size, **args.encoder_conf) |
| | | |
| | | # 3. Predictor |
| | | predictor_class = predictor_choices.get_class(args.predictor) |
| | | predictor = predictor_class(**args.predictor_conf) |
| | | |
| | | # 4. Build model
| | | try: |
| | | model_class = model_choices.get_class(args.model) |
| | | except AttributeError: |
| | | model_class = model_choices.get_class("asr") |
| | | |
| | | model = model_class( |
| | | frontend=frontend, |
| | | encoder=encoder, |
| | | predictor=predictor, |
| | | token_list=token_list, |
| | | **args.model_conf, |
| | | ) |
| | | |
| | | # 5. Initialize
| | | if args.init is not None: |
| | | initialize(model, args.init) |
| | | |
| | | assert check_return_type(model) |
| | | return model |
| | | |
| | | @classmethod |
| | | def required_data_names( |
| | | cls, train: bool = True, inference: bool = False |
| | | ) -> Tuple[str, ...]: |
| | | retval = ("speech", "text") |
| | | return retval
| | |
| | | from typing import Any, List, Tuple, Union |
| | | |
| | | |
| | | def ts_prediction_lfr6_standard(us_alphas, |
| | | us_peaks, |
| | | char_list, |
| | | vad_offset=0.0, |
| | | force_time_shift=-1.5 |
| | | ): |
| | | if not len(char_list):
| | | return "", []
| | | START_END_THRESHOLD = 5 |
| | | MAX_TOKEN_DURATION = 12 |
| | | TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled |
| | | if len(us_alphas.shape) == 2:
| | | _, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
| | | else:
| | | _, peaks = us_alphas, us_peaks
| | | num_frames = peaks.shape[0]
| | | if char_list[-1] == '</s>': |
| | | char_list = char_list[:-1] |
| | | # char_list = [i for i in text] |
| | | timestamp_list = [] |
| | | new_char_list = [] |
| | | # for bicif models trained on large-scale data, cif2 actually fires when a character starts,
| | | # so the frames between two peaks are treated as the duration of the former token
| | | fire_place = torch.where(peaks>1.0-1e-4)[0].cpu().numpy() + force_time_shift # total offset |
| | | num_peak = len(fire_place) |
| | | assert num_peak == len(char_list) + 1 # number of peaks is supposed to be number of tokens + 1 |
| | | # begin silence |
| | | if fire_place[0] > START_END_THRESHOLD: |
| | | # char_list.insert(0, '<sil>') |
| | | timestamp_list.append([0.0, fire_place[0]*TIME_RATE]) |
| | | new_char_list.append('<sil>') |
| | | # tokens timestamp |
| | | for i in range(len(fire_place)-1): |
| | | new_char_list.append(char_list[i]) |
| | | if MAX_TOKEN_DURATION < 0 or fire_place[i+1] - fire_place[i] <= MAX_TOKEN_DURATION: |
| | | timestamp_list.append([fire_place[i]*TIME_RATE, fire_place[i+1]*TIME_RATE]) |
| | | else: |
| | | # if the 0-weight frames last too long, split the span into the token part and a trailing silence
| | | _split = fire_place[i] + MAX_TOKEN_DURATION |
| | | timestamp_list.append([fire_place[i]*TIME_RATE, _split*TIME_RATE]) |
| | | timestamp_list.append([_split*TIME_RATE, fire_place[i+1]*TIME_RATE]) |
| | | new_char_list.append('<sil>') |
| | | # tail token and end silence |
| | | # new_char_list.append(char_list[-1]) |
| | | if num_frames - fire_place[-1] > START_END_THRESHOLD: |
| | | _end = (num_frames + fire_place[-1]) * 0.5 |
| | | # _end = fire_place[-1] |
| | | timestamp_list[-1][1] = _end*TIME_RATE |
| | | timestamp_list.append([_end*TIME_RATE, num_frames*TIME_RATE]) |
| | | new_char_list.append("<sil>") |
| | | else: |
| | | timestamp_list[-1][1] = num_frames*TIME_RATE |
| | | if vad_offset: # add offset time in model with vad |
| | | for i in range(len(timestamp_list)): |
| | | timestamp_list[i][0] = timestamp_list[i][0] + vad_offset / 1000.0 |
| | | timestamp_list[i][1] = timestamp_list[i][1] + vad_offset / 1000.0 |
| | | res_txt = "" |
| | | for char, timestamp in zip(new_char_list, timestamp_list): |
| | | res_txt += "{} {} {};".format(char, str(timestamp[0]+0.0005)[:5], str(timestamp[1]+0.0005)[:5]) |
| | | res = [] |
| | | for char, timestamp in zip(new_char_list, timestamp_list): |
| | | if char != '<sil>': |
| | | res.append([int(timestamp[0] * 1000), int(timestamp[1] * 1000)]) |
| | | return res_txt, res |
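| | | 
| | | To illustrate the renamed `vad_offset` argument, a toy call with synthetic peaks; the reported times are shifted by the VAD segment start:
| | | 
| | | ```python
| | | import torch
| | | 
| | | us_peaks = torch.zeros(50)
| | | us_peaks[[10, 20, 30, 38]] = 1.0   # four fires -> three tokens + 1
| | | us_alphas = torch.zeros(50)
| | | txt, ts = ts_prediction_lfr6_standard(us_alphas, us_peaks, ['一', '个', '东'],
| | |                                       vad_offset=1000)  # this segment starts at 1.0 s
| | | print(txt)  # "char start end;" string, offset applied (seconds)
| | | print(ts)   # [[start_ms, end_ms], ...], first token starts around 1170 ms
| | | ```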
| | | |
| | | |
| | | def time_stamp_sentence(punc_id_list, time_stamp_postprocessed, text_postprocessed): |