游雁
2023-02-24 4daea3711063c64485be3c00eaa9727404549f51
funasr/bin/vad_inference.py
@@ -81,6 +81,7 @@
        self.device = device
        self.dtype = dtype
        self.frontend = frontend
        self.batch_size = batch_size
    @torch.no_grad()
    def __call__(
@@ -106,14 +107,11 @@
            feats_len = feats_len.int()
        else:
            raise Exception("Need to extract feats first, please configure frontend configuration")
        # batch = {"feats": feats, "waveform": speech, "is_final_send": True}
        # segments = self.vad_model(**batch)
        # b. Forward Encoder sreaming
        segments = []
        segments_tmp = []
        step = 6000
        # b. Forward Encoder streaming
        t_offset = 0
        step = min(feats_len, 6000)
        segments = [[]] * self.batch_size
        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
            if t_offset + step >= feats_len - 1:
                step = feats_len - t_offset
@@ -129,8 +127,8 @@
            batch = to_device(batch, device=self.device)
            segments_part = self.vad_model(**batch)
            if segments_part:
                segments_tmp += segments_part[0]
        segments.append(segments_tmp)
                for batch_num in range(0, self.batch_size):
                    segments[batch_num] += segments_part[batch_num]
        return segments
@@ -254,7 +252,6 @@
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            # do vad segment
            results = speech2vadsegment(**batch)