Support offline inference for unified streaming/non-streaming RNN-T
| | |
| | | nbest: int = 1, |
| | | streaming: bool = False, |
| | | simu_streaming: bool = False, |
| | | full_utt: bool = False, |
| | | chunk_size: int = 16, |
| | | left_context: int = 32, |
| | | right_context: int = 0, |
| | |
| | | self.beam_search = beam_search |
| | | self.streaming = streaming |
| | | self.simu_streaming = simu_streaming |
| | | self.full_utt = full_utt |
| | | self.chunk_size = max(chunk_size, 0) |
| | | self.left_context = left_context |
| | | self.right_context = max(right_context, 0) |
| | |
| | | self._ctx = self.asr_model.encoder.get_encoder_input_size( |
| | | self.window_size |
| | | ) |
| | | self._right_ctx = right_context |
| | | |
| | | self.last_chunk_length = ( |
| | | self.asr_model.encoder.embed.min_frame_length + self.right_context + 1 |
| | |
| | | return nbest_hyps |
| | | |
@torch.no_grad()
def full_utt_decode(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]:
    """Decode an utterance using the encoder's full-utterance forward pass.

    The input is turned into features (through ``self.frontend`` when one is
    configured), optionally normalized, moved to ``self.device``, encoded with
    ``encoder.full_utt_forward`` and finally handed to ``self.beam_search``.

    Args:
        speech: Speech data. (S)

    Returns:
        nbest_hypothesis: N-best hypothesis, as returned by ``self.beam_search``.

    """
    assert check_argument_types()

    # NumPy input is converted up front so everything below works on tensors.
    if isinstance(speech, np.ndarray):
        speech = torch.tensor(speech)

    # Insert a leading dimension of size 1 in front of the input.
    batched = speech.unsqueeze(0)

    if self.frontend is None:
        # No frontend: the input itself serves as features, cast to the
        # torch dtype named by ``self.dtype``.
        feats = batched.to(getattr(torch, self.dtype))
        feats_lengths = feats.new_full([1], feats.size(1), dtype=torch.long)
    else:
        speech_lengths = batched.new_full([1], batched.size(1), dtype=torch.long)
        feats, feats_lengths = self.frontend(batched, speech_lengths)

    normalizer = self.asr_model.normalize
    if normalizer is not None:
        feats, feats_lengths = normalizer(feats, feats_lengths)

    target_device = self.device
    feats = to_device(feats, device=target_device)
    feats_lengths = to_device(feats_lengths, device=target_device)

    encoder_out = self.asr_model.encoder.full_utt_forward(feats, feats_lengths)

    # Only the first element of the encoder result is passed to the beam search.
    return self.beam_search(encoder_out[0])
| | | |
| | | @torch.no_grad() |
| | | def __call__(self, speech: Union[torch.Tensor, np.ndarray]) -> List[HypothesisTransducer]: |
| | | """Speech2Text call. |
| | | Args: |
| | |
| | | quantize_dtype: Optional[str] = "float16", |
| | | streaming: Optional[bool] = False, |
| | | simu_streaming: Optional[bool] = False, |
| | | full_utt: Optional[bool] = False, |
| | | chunk_size: Optional[int] = 16, |
| | | left_context: Optional[int] = 16, |
| | | right_context: Optional[int] = 0, |
| | |
| | | quantize_dtype=quantize_dtype, |
| | | streaming=streaming, |
| | | simu_streaming=simu_streaming, |
| | | full_utt=full_utt, |
| | | chunk_size=chunk_size, |
| | | left_context=left_context, |
| | | right_context=right_context, |
| | |
| | | _end = (i + 1) * speech2text._ctx |
| | | |
| | | speech2text.streaming_decode( |
| | | speech[i * speech2text._ctx: _end], is_final=False |
| | | speech[i * speech2text._ctx: _end + speech2text._right_ctx], is_final=False |
| | | ) |
| | | |
| | | final_hyps = speech2text.streaming_decode( |
| | |
| | | ) |
| | | elif speech2text.simu_streaming: |
| | | final_hyps = speech2text.simu_streaming_decode(**batch) |
| | | elif speech2text.full_utt: |
| | | final_hyps = speech2text.full_utt_decode(**batch) |
| | | else: |
| | | final_hyps = speech2text(**batch) |
| | | |