Merge pull request #163 from alibaba-damo-academy/dev_zly
update vad inference
| | |
| | | |
| | | @torch.no_grad() |
| | | def __call__( |
| | | self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None |
| | | self, speech: Union[torch.Tensor, np.ndarray], speech_lengths: Union[torch.Tensor, np.ndarray] = None, |
| | | in_cache: Dict[str, torch.Tensor] = dict() |
| | | ) -> Tuple[List[List[int]], Dict[str, torch.Tensor]]: |
| | | """Inference |
| | | |
| | |
| | | batch = { |
| | | "feats": feats[:, t_offset:t_offset + step, :], |
| | | "waveform": speech[:, t_offset * 160:min(speech.shape[-1], (t_offset + step - 1) * 160 + 400)], |
| | | "is_final": is_final |
| | | "is_final": is_final, |
| | | "in_cache": in_cache |
| | | } |
| | | # a. To device |
| | | batch = to_device(batch, device=self.device) |
| | | segments_part = self.vad_model(**batch) |
| | | segments_part, in_cache = self.vad_model(**batch) |
| | | if segments_part: |
| | | for batch_num in range(0, self.batch_size): |
| | | segments[batch_num] += segments_part[batch_num] |
| | |
| | | |
| | | def forward(self, feats: torch.Tensor, waveform: torch.tensor, in_cache: Dict[str, torch.Tensor] = dict(), |
| | | is_final: bool = False |
| | | ) -> List[List[List[int]]]: |
| | | ) -> Tuple[List[List[List[int]]], Dict[str, torch.Tensor]]: |
| | | self.waveform = waveform # compute decibel for each frame |
| | | self.ComputeDecibel() |
| | | self.ComputeScores(feats, in_cache) |
| | |
| | | if is_final: |
| | | # reset class variables and clear the dict for the next query |
| | | self.AllResetDetection() |
| | | in_cache.clear() |
| | | return segments |
| | | return segments, in_cache |
| | | |
| | | def DetectCommonFrames(self) -> int: |
| | | if self.vad_state_machine == VadStateMachine.kVadInStateEndPointDetected: |