Merge pull request #522 from alibaba-damo-academy/dev_cmz_fromDev_infer
increase vad realtime punc
| | |
| | | inference_pipeline = pipeline( |
| | | task=Tasks.punctuation, |
| | | model='damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727', |
| | | output_dir="./tmp/" |
| | | model_revision = 'v1.0.2' |
| | | ) |
| | | |
| | | ##################text二进制数据##################### |
| | |
| | | return {'inputs': np.ones((1, text_length), dtype=np.int64), |
| | | 'text_lengths': np.array([text_length,], dtype=np.int32), |
| | | 'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32), |
| | | 'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32) |
| | | 'sub_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32), |
| | | } |
| | | |
| | | def _run(feed_dict): |
| | |
| | | mini_sentence = cache_sent + mini_sentence |
| | | mini_sentence_id = np.concatenate((cache_sent_id, mini_sentence_id), axis=0,dtype='int32') |
| | | text_length = len(mini_sentence_id) |
| | | vad_mask = self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32) |
| | | data = { |
| | | "input": mini_sentence_id[None,:], |
| | | "text_lengths": np.array([text_length], dtype='int32'), |
| | | "vad_mask": self.vad_mask(text_length, len(cache))[None, None, :, :].astype(np.float32), |
| | | "sub_masks": np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32) |
| | | "vad_mask": vad_mask |
| | | "sub_masks": vad_mask |
| | | } |
| | | try: |
| | | outputs = self.infer(data['input'], data['text_lengths'], data['vad_mask'], data["sub_masks"]) |