| | |
| | | self.device = device |
| | | self.dtype = dtype |
| | | self.frontend = frontend |
| | | self.batch_size = batch_size |
| | | |
| | | @torch.no_grad() |
| | | def __call__( |
| | |
| | | # segments = self.vad_model(**batch) |
| | | |
| | | # b. Forward Encoder streaming |
| | | segments = [] |
| | | segments_tmp = [] |
| | | step = 6000 |
| | | t_offset = 0 |
| | | step = min(feats_len, 6000) |
| | | segments = [[]] * self.batch_size |
| | | for t_offset in range(0, feats_len, min(step, feats_len - t_offset)): |
| | | if t_offset + step >= feats_len - 1: |
| | | step = feats_len - t_offset |
| | |
| | | batch = to_device(batch, device=self.device) |
| | | segments_part = self.vad_model(**batch) |
| | | if segments_part: |
| | | segments_tmp += segments_part[0] |
| | | segments.append(segments_tmp) |
| | | for batch_num in range(0, self.batch_size): |
| | | segments[batch_num] += segments_part[batch_num] |
| | | return segments |
| | | |
| | | |
| | |
| | | assert all(isinstance(s, str) for s in keys), keys |
| | | _bs = len(next(iter(batch.values()))) |
| | | assert len(keys) == _bs, f"{len(keys)} != {_bs}" |
| | | # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")} |
| | | |
| | | # do vad segment |
| | | results = speech2vadsegment(**batch) |