游雁
2024-04-29 8f596af4be1c2e5c4e4b4a7008ba96f412d40fca
funasr/models/sense_voice/whisper_lib/transcribe.py
@@ -146,9 +146,7 @@
            _, probs = model.detect_language(mel_segment)
            decode_options["language"] = max(probs, key=probs.get)
            if verbose is not None:
                print(
                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
                )
                print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
    language: str = decode_options["language"]
    task: str = decode_options.get("task", "transcribe")
@@ -176,9 +174,7 @@
        warnings.warn("Word-level timestamps on translations may not be reliable.")
    def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
        temperatures = (
            [temperature] if isinstance(temperature, (int, float)) else temperature
        )
        temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
        decode_result = None
        for t in temperatures:
@@ -200,10 +196,7 @@
                and decode_result.compression_ratio > compression_ratio_threshold
            ):
                needs_fallback = True  # too repetitive
            if (
                logprob_threshold is not None
                and decode_result.avg_logprob < logprob_threshold
            ):
            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
                needs_fallback = True  # average log probability is too low
            if (
                no_speech_threshold is not None
@@ -217,9 +210,7 @@
    clip_idx = 0
    seek = seek_clips[clip_idx][0]
    input_stride = exact_div(
        N_FRAMES, model.dims.n_audio_ctx
    )  # mel frames per output token: 2
    input_stride = exact_div(N_FRAMES, model.dims.n_audio_ctx)  # mel frames per output token: 2
    time_precision = (
        input_stride * HOP_LENGTH / SAMPLE_RATE
    )  # time per output token: 0.02 (seconds)
@@ -233,9 +224,7 @@
    else:
        initial_prompt_tokens = []
    def new_segment(
        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
    ):
    def new_segment(*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult):
        tokens = tokens.tolist()
        text_tokens = [token for token in tokens if token < tokenizer.eot]
        return {
@@ -251,9 +240,7 @@
        }
    # show the progress bar when verbose is False (if True, transcribed text will be printed)
    with tqdm.tqdm(
        total=content_frames, unit="frames", disable=verbose is not False
    ) as pbar:
    with tqdm.tqdm(total=content_frames, unit="frames", disable=verbose is not False) as pbar:
        last_speech_timestamp = 0.0
        # NOTE: This loop is obscurely flattened to make the diff readable.
        # A later commit should turn this into a simpler nested loop.
@@ -282,10 +269,7 @@
            if no_speech_threshold is not None:
                # no voice activity check
                should_skip = result.no_speech_prob > no_speech_threshold
                if (
                    logprob_threshold is not None
                    and result.avg_logprob > logprob_threshold
                ):
                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
                    # don't skip if the logprob is high enough, despite the no_speech_prob
                    should_skip = False
@@ -334,12 +318,8 @@
                last_slice = 0
                for current_slice in slices:
                    sliced_tokens = tokens[last_slice:current_slice]
                    start_timestamp_pos = (
                        sliced_tokens[0].item() - tokenizer.timestamp_begin
                    )
                    end_timestamp_pos = (
                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    )
                    start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
                    end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
                    current_segments.append(
                        new_segment(
                            start=time_offset + start_timestamp_pos * time_precision,
@@ -355,21 +335,14 @@
                    seek += segment_size
                else:
                    # otherwise, ignore the unfinished segment and seek to the last timestamp
                    last_timestamp_pos = (
                        tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                    )
                    last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                    seek += last_timestamp_pos * input_stride
            else:
                duration = segment_duration
                timestamps = tokens[timestamp_tokens.nonzero().flatten()]
                if (
                    len(timestamps) > 0
                    and timestamps[-1].item() != tokenizer.timestamp_begin
                ):
                if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
                    # no consecutive timestamps but it has a timestamp; use the last one.
                    last_timestamp_pos = (
                        timestamps[-1].item() - tokenizer.timestamp_begin
                    )
                    last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
                    duration = last_timestamp_pos * time_precision
                current_segments.append(
@@ -427,9 +400,7 @@
                        if not segment["words"]:
                            continue
                        if is_segment_anomaly(segment):
                            next_segment = next_words_segment(
                                current_segments[si + 1 :]
                            )
                            next_segment = next_words_segment(current_segments[si + 1 :])
                            if next_segment is not None:
                                hal_next_start = next_segment["words"][0]["start"]
                            else:
@@ -446,8 +417,7 @@
                            )
                            if silence_before and silence_after:
                                seek = round(
                                    max(time_offset + 1, segment["start"])
                                    * FRAMES_PER_SECOND
                                    max(time_offset + 1, segment["start"]) * FRAMES_PER_SECOND
                                )
                                if content_duration - segment["end"] < threshold:
                                    seek = content_frames
@@ -475,9 +445,7 @@
            all_segments.extend(
                [
                    {"id": i, **segment}
                    for i, segment in enumerate(
                        current_segments, start=len(all_segments)
                    )
                    for i, segment in enumerate(current_segments, start=len(all_segments))
                ]
            )
            all_tokens.extend(