Changed files:

examples/industrial_data_pretraining/paraformer-large-long/infer.sh
examples/industrial_data_pretraining/punc/infer.sh
funasr/bin/inference.py
funasr/models/bici_paraformer/model.py
funasr/models/ct_transformer/model.py
funasr/models/ct_transformer/template.yaml
funasr/models/ct_transformer/utils.py
funasr/models/paraformer/model.py

examples/industrial_data_pretraining/paraformer-large-long/infer.sh
@@ -4,6 +4,7 @@
 python $cmd \
+model="/Users/zhifu/Downloads/modelscope_models/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch" \
+vad_model="/Users/zhifu/Downloads/modelscope_models/speech_fsmn_vad_zh-cn-16k-common-pytorch" \
+punc_model="/Users/zhifu/Downloads/modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" \
+input="/Users/zhifu/funasr_github/test_local/vad_example.wav" \
+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2" \
+device="cpu" \

examples/industrial_data_pretraining/punc/infer.sh

@@ -2,8 +2,17 @@
 cmd="funasr/bin/inference.py"
 python $cmd \
+model="/Users/zhifu/Downloads/modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" \
+input="/Users/zhifu/FunASR/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt" \
+model="/Users/zhifu/Downloads/modelscope_models/punc_ct-transformer_zh-cn-common-vocab272727-pytorch" \
+output_dir="/Users/zhifu/Downloads/ckpt/funasr2/exp2_punc" \
+device="cpu" \
+debug="true"
+#input="/Users/zhifu/FunASR/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt" \
+#"input='跨境河流是养育沿岸人民的生命之源长期以来为帮助下游地区防灾减灾中方技术人员在上游地区极为恶劣的自然条件下克服巨大困难甚至冒着生命危险向印方提供汛期水文资料处理紧急事件中方重视印方在跨境河流问题上的关切愿意进一步完善双方联合工作机制凡是中方能做的我们都会去做而且会做得更好我请印度朋友们放心中国在上游的任何开发利用都会经过科学规划和论证兼顾上下游的利益'" \
+#input="/Users/zhifu/FunASR/egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vocab272727-pytorch/data/punc_example.txt" \
+#"input='那今天的会就到这里吧 happy new year 明年见'" \

funasr/bin/inference.py

@@ -18,6 +18,7 @@
 from funasr.register import tables
 from funasr.datasets.audio_datasets.load_audio_extract_fbank import load_audio
 from funasr.utils.vad_utils import slice_padding_audio_samples
+from funasr.utils.timestamp_tools import time_stamp_sentence

 def build_iter_for_infer(data_in, input_len=None, data_type="sound"):
     """
@@ -46,7 +47,7 @@
                 data = lines["source"]
                 key = data["key"] if "key" in data else key
             else: # filelist, wav.scp, text.txt: id \t data or data
-                lines = line.strip().split()
+                lines = line.strip().split(maxsplit=1)
                 data = lines[1] if len(lines)>1 else lines[0]
                 key = lines[0] if len(lines)>1 else key
@@ -227,6 +228,7 @@
        # step.1: compute the vad model
        model = self.vad_model
        kwargs = self.vad_kwargs
+       kwargs.update(cfg)
        beg_vad = time.time()
        res = self.generate(input, input_len=input_len, model=model, kwargs=kwargs, **cfg)
        end_vad = time.time()
@@ -322,6 +324,23 @@
                result["key"] = key
                results_ret_list.append(result)
                pbar_total.update(1)
+
+           # step.3 compute punc model
+           model = self.punc_model
+           kwargs = self.punc_kwargs
+           kwargs.update(cfg)
+           for i, result in enumerate(results_ret_list):
+               beg_punc = time.time()
+               res = self.generate(result["text"], model=model, kwargs=kwargs, **cfg)
+               end_punc = time.time()
+               print(f"time punc: {end_punc - beg_punc:0.3f}")
+               # sentences = time_stamp_sentence(model.punc_list, model.sentence_end_id, results_ret_list[i]["timestamp"], res[i]["text"])
+               # results_ret_list[i]["time_stamp"] = res[0]["text_postprocessed_punc"]
+               # results_ret_list[i]["sentences"] = sentences
+               # results_ret_list[i]["text_with_punc"] = res[i]["text"]
+               pbar_total.update(1)
        end_total = time.time()
        time_escape_total_all_samples = end_total - beg_total
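
A note on the split(maxsplit=1) change above (not part of the patch): for "id \t data or data" lines in a filelist, wav.scp, or text.txt, an unrestricted split breaks the payload apart as soon as it contains spaces, which is common for raw punctuation text. A minimal stdlib sketch with a made-up example line:

# Hypothetical text.txt line: a key followed by a sentence that itself contains spaces.
line = "utt1 那今天的会就到这里吧 happy new year 明年见"

# Old behaviour: splitting on every whitespace fragments the payload.
parts_old = line.strip().split()
assert parts_old[1] == "那今天的会就到这里吧"  # the English words end up in later fields

# New behaviour: split only once, so key and payload stay intact.
key, data = line.strip().split(maxsplit=1)
assert key == "utt1"
assert data == "那今天的会就到这里吧 happy new year 明年见"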

funasr/models/bici_paraformer/model.py

@@ -29,7 +29,7 @@
 from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
 from funasr.register import tables
 from funasr.models.ctc.ctc import CTC
+from funasr.utils.timestamp_tools import time_stamp_sentence
 from funasr.models.paraformer.model import Paraformer
@@ -321,18 +321,16 @@
            text_postprocessed, time_stamp_postprocessed, word_lists = postprocess_utils.sentence_postprocess(
                token, timestamp)
+           sentences = time_stamp_sentence(None, time_stamp_postprocessed, text_postprocessed)
-           result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed,
+           result_i = {"key": key[i], "text": text_postprocessed,
                        "timestamp": time_stamp_postprocessed, "word_lists": word_lists,
+                       "sentences": sentences
                        }
            if ibest_writer is not None:
                ibest_writer["token"][key[i]] = " ".join(token)
-               ibest_writer["text"][key[i]] = text
+               # ibest_writer["text"][key[i]] = text
                ibest_writer["timestamp"][key[i]] = time_stamp_postprocessed
-               ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+               ibest_writer["text"][key[i]] = text_postprocessed
        else:
            result_i = {"key": key[i], "token_int": token_int}
        results.append(result_i)

funasr/models/ct_transformer/model.py

@@ -10,7 +10,7 @@
 from funasr.train_utils.device_funcs import to_device
 import torch
 import torch.nn as nn
-from funasr.models.ct_transformer.utils import split_to_mini_sentence
+from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words
 from funasr.register import tables
@@ -34,6 +34,7 @@
        ignore_id: int = -1,
        sos: int = 1,
        eos: int = 2,
+       sentence_end_id: int = 3,
        **kwargs,
    ):
        super().__init__()
@@ -54,10 +55,11 @@
        self.ignore_id = ignore_id
        self.sos = sos
        self.eos = eos
+       self.sentence_end_id = sentence_end_id

-   def punc_forward(self, input: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
+   def punc_forward(self, text: torch.Tensor, text_lengths: torch.Tensor) -> Tuple[torch.Tensor, None]:
        """Compute loss value from buffer sequences.

        Args:
@@ -65,7 +67,7 @@
            hidden (torch.Tensor): Target ids. (batch, len)
        """
-       x = self.embed(input)
+       x = self.embed(text)
        # mask = self._target_mask(input)
        h, _, _ = self.encoder(x, text_lengths)
        y = self.decoder(h)
@@ -216,22 +218,26 @@
        frontend=None,
        **kwargs,
    ):
+       assert len(data_in) == 1
        vad_indexes = kwargs.get("vad_indexes", None)
-       text = data_in
-       text_lengths = data_lengths
+       text = data_in[0]
+       text_lengths = data_lengths[0] if data_lengths is not None else None
        split_size = kwargs.get("split_size", 20)
-       data = {"text": text}
-       result = self.preprocessor(data=data, uid="12938712838719")
-       split_text = self.preprocessor.pop_split_text_data(result)
-       mini_sentences = split_to_mini_sentence(split_text, split_size)
-       mini_sentences_id = split_to_mini_sentence(data["text"], split_size)
+       tokens = split_words(text)
+       tokens_int = tokenizer.encode(tokens)
+       mini_sentences = split_to_mini_sentence(tokens, split_size)
+       mini_sentences_id = split_to_mini_sentence(tokens_int, split_size)
        assert len(mini_sentences) == len(mini_sentences_id)
        cache_sent = []
        cache_sent_id = torch.from_numpy(np.array([], dtype='int32'))
        new_mini_sentence = ""
        new_mini_sentence_punc = []
        cache_pop_trigger_limit = 200
+       results = []
+       meta_data = {}
        for mini_sentence_i in range(len(mini_sentences)):
            mini_sentence = mini_sentences[mini_sentence_i]
            mini_sentence_id = mini_sentences_id[mini_sentence_i]
@@ -241,9 +247,9 @@
                "text": torch.unsqueeze(torch.from_numpy(mini_sentence_id), 0),
                "text_lengths": torch.from_numpy(np.array([len(mini_sentence_id)], dtype='int32')),
            }
-           data = to_device(data, self.device)
+           data = to_device(data, kwargs["device"])
            # y, _ = self.wrapped_model(**data)
-           y, _ = self.punc_forward(text, text_lengths)
+           y, _ = self.punc_forward(**data)
            _, indices = y.view(-1, y.shape[-1]).topk(1, dim=1)
            punctuations = indices
            if indices.size()[0] != 1:
@@ -264,7 +270,7 @@
            if sentenceEnd < 0 and len(mini_sentence) > cache_pop_trigger_limit and last_comma_index >= 0:
                # The sentence it too long, cut off at a comma.
                sentenceEnd = last_comma_index
-               punctuations[sentenceEnd] = self.period
+               punctuations[sentenceEnd] = self.sentence_end_id
            cache_sent = mini_sentence[sentenceEnd + 1:]
            cache_sent_id = mini_sentence_id[sentenceEnd + 1:]
            mini_sentence = mini_sentence[0:sentenceEnd + 1]
@@ -303,21 +309,19 @@
            if mini_sentence_i == len(mini_sentences) - 1:
                if new_mini_sentence[-1] == "," or new_mini_sentence[-1] == "、":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "。"
-                   new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                   new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                elif new_mini_sentence[-1] == ",":
                    new_mini_sentence_out = new_mini_sentence[:-1] + "."
-                   new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                   new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                elif new_mini_sentence[-1] != "。" and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==0:
                    new_mini_sentence_out = new_mini_sentence + "。"
-                   new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                   new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
                elif new_mini_sentence[-1] != "." and new_mini_sentence[-1] != "?" and len(new_mini_sentence[-1].encode())==1:
                    new_mini_sentence_out = new_mini_sentence + "."
-                   new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.period]
+                   new_mini_sentence_punc_out = new_mini_sentence_punc[:-1] + [self.sentence_end_id]
-           return new_mini_sentence_out, new_mini_sentence_punc_out
+       result_i = {"key": key[0], "text": new_mini_sentence_out}
+       results.append(result_i)
        # if self.with_vad():
        #     assert vad_indexes is not None
        #     return self.punc_forward(text, text_lengths, vad_indexes)
        # else:
        #     return self.punc_forward(text, text_lengths)
+       return results, meta_data
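
For context (not part of the patch): the rewritten inference path above drops the old preprocessor round-trip. The raw string is tokenized with split_words, encoded to ids with the model's tokenizer, and both lists are chunked into mini-sentences of split_size tokens (default 20) before each punc_forward call. A minimal sketch of the chunking step, assuming the patched funasr package is importable; the tokenizer.encode step is left out because it needs the model's vocabulary:

from funasr.models.ct_transformer.utils import split_to_mini_sentence, split_words

# 16 tokens repeated three times -> 48 tokens in total.
tokens = split_words("那今天的会就到这里吧 happy new year 明年见") * 3

# With split_size=20 the tokens are grouped as 20 + 20 + 8; the id list is
# chunked the same way, which is what the assert in the patch checks.
mini_sentences = split_to_mini_sentence(tokens, 20)
assert [len(m) for m in mini_sentences] == [20, 20, 8]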

funasr/models/ct_transformer/template.yaml

New file
@@ -0,0 +1,52 @@
+# This is an example that demonstrates how to configure a model file.
+# You can modify the configuration according to your own requirements.
+
+# to print the register_table:
+# from funasr.register import tables
+# tables.print()
+
+model: CTTransformer
+model_conf:
+    ignore_id: 0
+    embed_unit: 256
+    att_unit: 256
+    dropout_rate: 0.1
+    punc_list:
+    - <unk>
+    - _
+    - ','
+    - 。
+    - '?'
+    - 、
+    punc_weight:
+    - 1.0
+    - 1.0
+    - 1.0
+    - 1.0
+    - 1.0
+    - 1.0
+
+encoder: SANMEncoder
+encoder_conf:
+    input_size: 256
+    output_size: 256
+    attention_heads: 8
+    linear_units: 1024
+    num_blocks: 4
+    dropout_rate: 0.1
+    positional_dropout_rate: 0.1
+    attention_dropout_rate: 0.0
+    input_layer: pe
+    pos_enc_class: SinusoidalPositionEncoder
+    normalize_before: true
+    kernel_size: 11
+    sanm_shfit: 0
+    selfattention_layer_type: sanm
+    padding_idx: 0
+
+tokenizer: CharTokenizer
+tokenizer_conf:
+    unk_symbol: <unk>
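
The ordering of punc_list above is what the new sentence_end_id default (3, added to CTTransformer.__init__ in the model.py hunk) refers to: index 3 is the full stop 。. A small sketch (not part of the patch) of mapping predicted class indices back to symbols with this list; the index sequence is made up for illustration:

# punc_list copied from template.yaml; index 3 ("。") lines up with the
# sentence_end_id default of 3 used by the sentence-cutting logic in model.py.
punc_list = ["<unk>", "_", ",", "。", "?", "、"]
sentence_end_id = 3

predicted = [1, 1, 2, 1, 1, 3]           # hypothetical per-token punctuation ids
symbols = [punc_list[i] for i in predicted]

assert symbols == ["_", "_", ",", "_", "_", "。"]
assert predicted[-1] == sentence_end_id  # the final token closes the sentence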

funasr/models/ct_transformer/utils.py

@@ -12,3 +12,25 @@
     if length % word_limit > 0:
         sentences.append(words[sentence_len * word_limit:])
     return sentences
+
+
+def split_words(text: str):
+    words = []
+    segs = text.split()
+    for seg in segs:
+        # There is no space in seg.
+        current_word = ""
+        for c in seg:
+            if len(c.encode()) == 1:
+                # This is an ASCII char.
+                current_word += c
+            else:
+                # This is a Chinese char.
+                if len(current_word) > 0:
+                    words.append(current_word)
+                    current_word = ""
+                words.append(c)
+        if len(current_word) > 0:
+            words.append(current_word)
+    return words
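
split_words keeps runs of single-byte (ASCII) characters together as one token and emits every multi-byte character as its own token, which matches the mixed Chinese/English test line commented out in punc/infer.sh. A small usage sketch (not part of the patch), assuming the patched module is importable:

from funasr.models.ct_transformer.utils import split_words

tokens = split_words("那今天的会就到这里吧 happy new year 明年见")
assert tokens == ["那", "今", "天", "的", "会", "就", "到", "这", "里", "吧",
                  "happy", "new", "year", "明", "年", "见"]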

funasr/models/paraformer/model.py

@@ -535,13 +535,13 @@
            text = tokenizer.tokens2text(token)
            text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-           result_i = {"key": key[i], "token": token, "text": text, "text_postprocessed": text_postprocessed}
+           result_i = {"key": key[i], "text_postprocessed": text_postprocessed}
            if ibest_writer is not None:
                ibest_writer["token"][key[i]] = " ".join(token)
-               ibest_writer["text"][key[i]] = text
-               ibest_writer["text_postprocessed"][key[i]] = text_postprocessed
+               # ibest_writer["text"][key[i]] = text
+               ibest_writer["text"][key[i]] = text_postprocessed
        else:
            result_i = {"key": key[i], "token_int": token_int}
        results.append(result_i)