凌匀
2023-02-16 91027ddab49e5791fc42569b4db9dafca55735e6
funasr/bin/vad_inference.py
@@ -81,6 +81,7 @@
        self.device = device
        self.dtype = dtype
        self.frontend = frontend
        self.batch_size = batch_size
    @torch.no_grad()
    def __call__(
@@ -110,10 +111,9 @@
        # segments = self.vad_model(**batch)
        # b. Forward Encoder sreaming
        segments = []
        segments_tmp = []
        step = 6000
        t_offset = 0
        step = min(feats_len, 6000)
        segments = [[]] * self.batch_size
        for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
            if t_offset + step >= feats_len - 1:
                step = feats_len - t_offset
@@ -129,8 +129,8 @@
            batch = to_device(batch, device=self.device)
            segments_part = self.vad_model(**batch)
            if segments_part:
                segments_tmp += segments_part[0]
        segments.append(segments_tmp)
                for batch_num in range(0, self.batch_size):
                    segments[batch_num] += segments_part[batch_num]
        return segments
@@ -254,7 +254,6 @@
            assert all(isinstance(s, str) for s in keys), keys
            _bs = len(next(iter(batch.values())))
            assert len(keys) == _bs, f"{len(keys)} != {_bs}"
            # batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
            # do vad segment
            results = speech2vadsegment(**batch)