From f57b68121a526baea43b2e93f4540d8a2995f633 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 29 Apr 2024 15:15:24 +0800
Subject: [PATCH] batch: re-wrap multi-line expressions in whisper_lib/transcribe.py (formatting only)

---
 funasr/models/sense_voice/whisper_lib/transcribe.py |   62 +++++++-----------------------
 1 file changed, 15 insertions(+), 47 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/transcribe.py b/funasr/models/sense_voice/whisper_lib/transcribe.py
index 1c075a2..5c5f49d 100644
--- a/funasr/models/sense_voice/whisper_lib/transcribe.py
+++ b/funasr/models/sense_voice/whisper_lib/transcribe.py
@@ -146,9 +146,7 @@
             _, probs = model.detect_language(mel_segment)
             decode_options["language"] = max(probs, key=probs.get)
             if verbose is not None:
-                print(
-                    f"Detected language: {LANGUAGES[decode_options['language']].title()}"
-                )
+                print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
 
     language: str = decode_options["language"]
     task: str = decode_options.get("task", "transcribe")
@@ -176,9 +174,7 @@
         warnings.warn("Word-level timestamps on translations may not be reliable.")
 
     def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
-        temperatures = (
-            [temperature] if isinstance(temperature, (int, float)) else temperature
-        )
+        temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
         decode_result = None
 
         for t in temperatures:
@@ -200,10 +196,7 @@
                 and decode_result.compression_ratio > compression_ratio_threshold
             ):
                 needs_fallback = True  # too repetitive
-            if (
-                logprob_threshold is not None
-                and decode_result.avg_logprob < logprob_threshold
-            ):
+            if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
                 needs_fallback = True  # average log probability is too low
             if (
                 no_speech_threshold is not None
@@ -217,9 +210,7 @@
 
     clip_idx = 0
     seek = seek_clips[clip_idx][0]
-    input_stride = exact_div(
-        N_FRAMES, model.dims.n_audio_ctx
-    )  # mel frames per output token: 2
+    input_stride = exact_div(N_FRAMES, model.dims.n_audio_ctx)  # mel frames per output token: 2
     time_precision = (
         input_stride * HOP_LENGTH / SAMPLE_RATE
     )  # time per output token: 0.02 (seconds)
@@ -233,9 +224,7 @@
     else:
         initial_prompt_tokens = []
 
-    def new_segment(
-        *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
-    ):
+    def new_segment(*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult):
         tokens = tokens.tolist()
         text_tokens = [token for token in tokens if token < tokenizer.eot]
         return {
@@ -251,9 +240,7 @@
         }
 
     # show the progress bar when verbose is False (if True, transcribed text will be printed)
-    with tqdm.tqdm(
-        total=content_frames, unit="frames", disable=verbose is not False
-    ) as pbar:
+    with tqdm.tqdm(total=content_frames, unit="frames", disable=verbose is not False) as pbar:
         last_speech_timestamp = 0.0
         # NOTE: This loop is obscurely flattened to make the diff readable.
         # A later commit should turn this into a simpler nested loop.
@@ -282,10 +269,7 @@
             if no_speech_threshold is not None:
                 # no voice activity check
                 should_skip = result.no_speech_prob > no_speech_threshold
-                if (
-                    logprob_threshold is not None
-                    and result.avg_logprob > logprob_threshold
-                ):
+                if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
                     # don't skip if the logprob is high enough, despite the no_speech_prob
                     should_skip = False
 
@@ -334,12 +318,8 @@
                 last_slice = 0
                 for current_slice in slices:
                     sliced_tokens = tokens[last_slice:current_slice]
-                    start_timestamp_pos = (
-                        sliced_tokens[0].item() - tokenizer.timestamp_begin
-                    )
-                    end_timestamp_pos = (
-                        sliced_tokens[-1].item() - tokenizer.timestamp_begin
-                    )
+                    start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
+                    end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
                     current_segments.append(
                         new_segment(
                             start=time_offset + start_timestamp_pos * time_precision,
@@ -355,21 +335,14 @@
                     seek += segment_size
                 else:
                     # otherwise, ignore the unfinished segment and seek to the last timestamp
-                    last_timestamp_pos = (
-                        tokens[last_slice - 1].item() - tokenizer.timestamp_begin
-                    )
+                    last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
                     seek += last_timestamp_pos * input_stride
             else:
                 duration = segment_duration
                 timestamps = tokens[timestamp_tokens.nonzero().flatten()]
-                if (
-                    len(timestamps) > 0
-                    and timestamps[-1].item() != tokenizer.timestamp_begin
-                ):
+                if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
                     # no consecutive timestamps but it has a timestamp; use the last one.
-                    last_timestamp_pos = (
-                        timestamps[-1].item() - tokenizer.timestamp_begin
-                    )
+                    last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
                     duration = last_timestamp_pos * time_precision
 
                 current_segments.append(
@@ -427,9 +400,7 @@
                         if not segment["words"]:
                             continue
                         if is_segment_anomaly(segment):
-                            next_segment = next_words_segment(
-                                current_segments[si + 1 :]
-                            )
+                            next_segment = next_words_segment(current_segments[si + 1 :])
                             if next_segment is not None:
                                 hal_next_start = next_segment["words"][0]["start"]
                             else:
@@ -446,8 +417,7 @@
                             )
                             if silence_before and silence_after:
                                 seek = round(
-                                    max(time_offset + 1, segment["start"])
-                                    * FRAMES_PER_SECOND
+                                    max(time_offset + 1, segment["start"]) * FRAMES_PER_SECOND
                                 )
                                 if content_duration - segment["end"] < threshold:
                                     seek = content_frames
@@ -475,9 +445,7 @@
             all_segments.extend(
                 [
                     {"id": i, **segment}
-                    for i, segment in enumerate(
-                        current_segments, start=len(all_segments)
-                    )
+                    for i, segment in enumerate(current_segments, start=len(all_segments))
                 ]
             )
             all_tokens.extend(

--
Gitblit v1.9.1