| | |
| | | # -*- encoding: utf-8 -*- |
| | | # Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | # MIT License (https://opensource.org/licenses/MIT) |
| | | |
| | | import os.path |
| | | from pathlib import Path |
| | |
| | | |
| | | |
| | | class CT_Transformer(): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection |
| | | https://arxiv.org/pdf/2003.01309.pdf |
| | | """ |
| | | def __init__(self, model_dir: Union[str, Path] = None, |
| | | batch_size: int = 1, |
| | | device_id: Union[str, int] = "-1", |
| | |
| | | mini_sentence = mini_sentences[mini_sentence_i] |
| | | mini_sentence_id = mini_sentences_id[mini_sentence_i] |
| | | mini_sentence = cache_sent + mini_sentence |
| | | mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int64') |
| | | mini_sentence_id = np.array(cache_sent_id + mini_sentence_id, dtype='int32') |
| | | data = { |
| | | "text": mini_sentence_id[None,:], |
| | | "text_lengths": np.array([len(mini_sentence_id)], dtype='int32'), |
| | |
| | | |
| | | |
| | | class CT_Transformer_VadRealtime(CT_Transformer): |
| | | """ |
| | | Author: Speech Lab of DAMO Academy, Alibaba Group |
| | | CT-Transformer: Controllable time-delay transformer for real-time punctuation prediction and disfluency detection |
| | | https://arxiv.org/pdf/2003.01309.pdf |
| | | """ |
| | | def __init__(self, model_dir: Union[str, Path] = None, |
| | | batch_size: int = 1, |
| | | device_id: Union[str, int] = "-1", |
| | |
| | | else: |
| | | precache = "" |
| | | cache = [] |
| | | full_text = precache + text |
| | | full_text = precache + " " + text |
| | | split_text = code_mix_split_words(full_text) |
| | | split_text_id = self.converter.tokens2ids(split_text) |
| | | mini_sentences = split_to_mini_sentence(split_text, split_size) |
| | |
| | | mini_sentence = mini_sentences[mini_sentence_i] |
| | | mini_sentence_id = mini_sentences_id[mini_sentence_i] |
| | | mini_sentence = cache_sent + mini_sentence |
| | | mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0) |
| | | mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32') |
| | | text_length = len(mini_sentence_id) |
| | | data = { |
| | | "input": mini_sentence_id[None,:], |
| | | "text_lengths": np.array([text_length], dtype='int32'), |
| | | "vad_mask": self.vad_mask(text_length, len(cache) - 1)[None, None, :, :].astype(np.float32), |
| | | "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32), |
| | | "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32) |
| | | } |
| | | try: |