游雁
2024-04-29 8f596af4be1c2e5c4e4b4a7008ba96f412d40fca
funasr/models/sense_voice/whisper_lib/timing.py
@@ -27,9 +27,7 @@
        # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
        x = x[None, None, :]
    assert (
        filter_width > 0 and filter_width % 2 == 1
    ), "`filter_width` should be an odd number"
    assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"
    result = None
    x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
@@ -111,9 +109,7 @@
    M, N = x.shape
    assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
    x_skew = (
        F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
    )
    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
    x_skew = x_skew.T.contiguous()
    cost = torch.ones(N + M + 2, M + 2) * np.inf
    cost[0, 0] = 0
@@ -132,9 +128,7 @@
        BLOCK_SIZE=BLOCK_SIZE,
    )
    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
        :, : N + 1
    ]
    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, : N + 1]
    return backtrace(trace.cpu().numpy())
@@ -228,8 +222,7 @@
    start_times = jump_times[word_boundaries[:-1]]
    end_times = jump_times[word_boundaries[1:]]
    word_probabilities = [
        np.mean(text_token_probs[i:j])
        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
        np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
    ]
    return [
@@ -290,8 +283,7 @@
        return
    text_tokens_per_segment = [
        [token for token in segment["tokens"] if token < tokenizer.eot]
        for segment in segments
        [token for token in segment["tokens"] if token < tokenizer.eot] for segment in segments
    ]
    text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
@@ -346,38 +338,22 @@
            # twice the median word duration.
            if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
                words[0]["end"] - words[0]["start"] > max_duration
                or (
                    len(words) > 1
                    and words[1]["end"] - words[0]["start"] > max_duration * 2
                )
                or (len(words) > 1 and words[1]["end"] - words[0]["start"] > max_duration * 2)
            ):
                if (
                    len(words) > 1
                    and words[1]["end"] - words[1]["start"] > max_duration
                ):
                if len(words) > 1 and words[1]["end"] - words[1]["start"] > max_duration:
                    boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
                    words[0]["end"] = words[1]["start"] = boundary
                words[0]["start"] = max(0, words[0]["end"] - max_duration)
            # prefer the segment-level start timestamp if the first word is too long.
            if (
                segment["start"] < words[0]["end"]
                and segment["start"] - 0.5 > words[0]["start"]
            ):
                words[0]["start"] = max(
                    0, min(words[0]["end"] - median_duration, segment["start"])
                )
            if segment["start"] < words[0]["end"] and segment["start"] - 0.5 > words[0]["start"]:
                words[0]["start"] = max(0, min(words[0]["end"] - median_duration, segment["start"]))
            else:
                segment["start"] = words[0]["start"]
            # prefer the segment-level end timestamp if the last word is too long.
            if (
                segment["end"] > words[-1]["start"]
                and segment["end"] + 0.5 < words[-1]["end"]
            ):
                words[-1]["end"] = max(
                    words[-1]["start"] + median_duration, segment["end"]
                )
            if segment["end"] > words[-1]["start"] and segment["end"] + 0.5 < words[-1]["end"]:
                words[-1]["end"] = max(words[-1]["start"] + median_duration, segment["end"])
            else:
                segment["end"] = words[-1]["end"]