# Output tensor names reported by `sess.get_outputs()`, in the order the session lists them.
# NOTE(review): `sess` is created elsewhere — presumably an onnxruntime InferenceSession; confirm.
output_name = [out.name for out in sess.get_outputs()]
| | | def _get_feed_dict(text_length): |
| | | return {'input': np.ones((1, text_length), dtype=np.int64), |
| | | return {'inputs': np.ones((1, text_length), dtype=np.int64), |
| | | 'text_lengths': np.array([text_length,], dtype=np.int32), |
| | | 'vad_mask': np.ones((1, 1, text_length, text_length), dtype=np.float32), |
| | | 'vad_masks': np.ones((1, 1, text_length, text_length), dtype=np.float32), |
| | | 'sub_masks': np.tril(np.ones((text_length, text_length), dtype=np.float32))[None, None, :, :].astype(np.float32) |
| | | } |
| | | |