From f57b68121a526baea43b2e93f4540d8a2995f633 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 29 Apr 2024 15:15:24 +0800
Subject: [PATCH] batch
---
funasr/models/sense_voice/whisper_lib/transcribe.py | 62 +++++++-----------------------
1 file changed, 15 insertions(+), 47 deletions(-)
diff --git a/funasr/models/sense_voice/whisper_lib/transcribe.py b/funasr/models/sense_voice/whisper_lib/transcribe.py
index 1c075a2..5c5f49d 100644
--- a/funasr/models/sense_voice/whisper_lib/transcribe.py
+++ b/funasr/models/sense_voice/whisper_lib/transcribe.py
@@ -146,9 +146,7 @@
_, probs = model.detect_language(mel_segment)
decode_options["language"] = max(probs, key=probs.get)
if verbose is not None:
- print(
- f"Detected language: {LANGUAGES[decode_options['language']].title()}"
- )
+ print(f"Detected language: {LANGUAGES[decode_options['language']].title()}")
language: str = decode_options["language"]
task: str = decode_options.get("task", "transcribe")
@@ -176,9 +174,7 @@
warnings.warn("Word-level timestamps on translations may not be reliable.")
def decode_with_fallback(segment: torch.Tensor) -> DecodingResult:
- temperatures = (
- [temperature] if isinstance(temperature, (int, float)) else temperature
- )
+ temperatures = [temperature] if isinstance(temperature, (int, float)) else temperature
decode_result = None
for t in temperatures:
@@ -200,10 +196,7 @@
and decode_result.compression_ratio > compression_ratio_threshold
):
needs_fallback = True # too repetitive
- if (
- logprob_threshold is not None
- and decode_result.avg_logprob < logprob_threshold
- ):
+ if logprob_threshold is not None and decode_result.avg_logprob < logprob_threshold:
needs_fallback = True # average log probability is too low
if (
no_speech_threshold is not None
@@ -217,9 +210,7 @@
clip_idx = 0
seek = seek_clips[clip_idx][0]
- input_stride = exact_div(
- N_FRAMES, model.dims.n_audio_ctx
- ) # mel frames per output token: 2
+ input_stride = exact_div(N_FRAMES, model.dims.n_audio_ctx) # mel frames per output token: 2
time_precision = (
input_stride * HOP_LENGTH / SAMPLE_RATE
) # time per output token: 0.02 (seconds)
@@ -233,9 +224,7 @@
else:
initial_prompt_tokens = []
- def new_segment(
- *, start: float, end: float, tokens: torch.Tensor, result: DecodingResult
- ):
+ def new_segment(*, start: float, end: float, tokens: torch.Tensor, result: DecodingResult):
tokens = tokens.tolist()
text_tokens = [token for token in tokens if token < tokenizer.eot]
return {
@@ -251,9 +240,7 @@
}
# show the progress bar when verbose is False (if True, transcribed text will be printed)
- with tqdm.tqdm(
- total=content_frames, unit="frames", disable=verbose is not False
- ) as pbar:
+ with tqdm.tqdm(total=content_frames, unit="frames", disable=verbose is not False) as pbar:
last_speech_timestamp = 0.0
# NOTE: This loop is obscurely flattened to make the diff readable.
# A later commit should turn this into a simpler nested loop.
@@ -282,10 +269,7 @@
if no_speech_threshold is not None:
# no voice activity check
should_skip = result.no_speech_prob > no_speech_threshold
- if (
- logprob_threshold is not None
- and result.avg_logprob > logprob_threshold
- ):
+ if logprob_threshold is not None and result.avg_logprob > logprob_threshold:
# don't skip if the logprob is high enough, despite the no_speech_prob
should_skip = False
@@ -334,12 +318,8 @@
last_slice = 0
for current_slice in slices:
sliced_tokens = tokens[last_slice:current_slice]
- start_timestamp_pos = (
- sliced_tokens[0].item() - tokenizer.timestamp_begin
- )
- end_timestamp_pos = (
- sliced_tokens[-1].item() - tokenizer.timestamp_begin
- )
+ start_timestamp_pos = sliced_tokens[0].item() - tokenizer.timestamp_begin
+ end_timestamp_pos = sliced_tokens[-1].item() - tokenizer.timestamp_begin
current_segments.append(
new_segment(
start=time_offset + start_timestamp_pos * time_precision,
@@ -355,21 +335,14 @@
seek += segment_size
else:
# otherwise, ignore the unfinished segment and seek to the last timestamp
- last_timestamp_pos = (
- tokens[last_slice - 1].item() - tokenizer.timestamp_begin
- )
+ last_timestamp_pos = tokens[last_slice - 1].item() - tokenizer.timestamp_begin
seek += last_timestamp_pos * input_stride
else:
duration = segment_duration
timestamps = tokens[timestamp_tokens.nonzero().flatten()]
- if (
- len(timestamps) > 0
- and timestamps[-1].item() != tokenizer.timestamp_begin
- ):
+ if len(timestamps) > 0 and timestamps[-1].item() != tokenizer.timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
- last_timestamp_pos = (
- timestamps[-1].item() - tokenizer.timestamp_begin
- )
+ last_timestamp_pos = timestamps[-1].item() - tokenizer.timestamp_begin
duration = last_timestamp_pos * time_precision
current_segments.append(
@@ -427,9 +400,7 @@
if not segment["words"]:
continue
if is_segment_anomaly(segment):
- next_segment = next_words_segment(
- current_segments[si + 1 :]
- )
+ next_segment = next_words_segment(current_segments[si + 1 :])
if next_segment is not None:
hal_next_start = next_segment["words"][0]["start"]
else:
@@ -446,8 +417,7 @@
)
if silence_before and silence_after:
seek = round(
- max(time_offset + 1, segment["start"])
- * FRAMES_PER_SECOND
+ max(time_offset + 1, segment["start"]) * FRAMES_PER_SECOND
)
if content_duration - segment["end"] < threshold:
seek = content_frames
@@ -475,9 +445,7 @@
all_segments.extend(
[
{"id": i, **segment}
- for i, segment in enumerate(
- current_segments, start=len(all_segments)
- )
+ for i, segment in enumerate(current_segments, start=len(all_segments))
]
)
all_tokens.extend(
--
Gitblit v1.9.1