From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fix bug hotwords
---
funasr/models/sense_voice/whisper_lib/timing.py | 46 +++++++++++-----------------------------------
 1 file changed, 11 insertions(+), 35 deletions(-)
diff --git a/funasr/models/sense_voice/whisper_lib/timing.py b/funasr/models/sense_voice/whisper_lib/timing.py
index b695ead..ba9cb13 100644
--- a/funasr/models/sense_voice/whisper_lib/timing.py
+++ b/funasr/models/sense_voice/whisper_lib/timing.py
@@ -27,9 +27,7 @@
# `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
x = x[None, None, :]
- assert (
- filter_width > 0 and filter_width % 2 == 1
- ), "`filter_width` should be an odd number"
+ assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"
result = None
x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
@@ -111,9 +109,7 @@
M, N = x.shape
assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
- x_skew = (
- F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
- )
+ x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
x_skew = x_skew.T.contiguous()
cost = torch.ones(N + M + 2, M + 2) * np.inf
cost[0, 0] = 0
@@ -132,9 +128,7 @@
BLOCK_SIZE=BLOCK_SIZE,
)
- trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
- :, : N + 1
- ]
+ trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, : N + 1]
return backtrace(trace.cpu().numpy())
@@ -228,8 +222,7 @@
start_times = jump_times[word_boundaries[:-1]]
end_times = jump_times[word_boundaries[1:]]
word_probabilities = [
- np.mean(text_token_probs[i:j])
- for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+ np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
]
return [
@@ -290,8 +283,7 @@
return
text_tokens_per_segment = [
- [token for token in segment["tokens"] if token < tokenizer.eot]
- for segment in segments
+ [token for token in segment["tokens"] if token < tokenizer.eot] for segment in segments
]
text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
@@ -346,38 +338,22 @@
# twice the median word duration.
if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
words[0]["end"] - words[0]["start"] > max_duration
- or (
- len(words) > 1
- and words[1]["end"] - words[0]["start"] > max_duration * 2
- )
+ or (len(words) > 1 and words[1]["end"] - words[0]["start"] > max_duration * 2)
):
- if (
- len(words) > 1
- and words[1]["end"] - words[1]["start"] > max_duration
- ):
+ if len(words) > 1 and words[1]["end"] - words[1]["start"] > max_duration:
boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
words[0]["end"] = words[1]["start"] = boundary
words[0]["start"] = max(0, words[0]["end"] - max_duration)
# prefer the segment-level start timestamp if the first word is too long.
- if (
- segment["start"] < words[0]["end"]
- and segment["start"] - 0.5 > words[0]["start"]
- ):
- words[0]["start"] = max(
- 0, min(words[0]["end"] - median_duration, segment["start"])
- )
+ if segment["start"] < words[0]["end"] and segment["start"] - 0.5 > words[0]["start"]:
+ words[0]["start"] = max(0, min(words[0]["end"] - median_duration, segment["start"]))
else:
segment["start"] = words[0]["start"]
# prefer the segment-level end timestamp if the last word is too long.
- if (
- segment["end"] > words[-1]["start"]
- and segment["end"] + 0.5 < words[-1]["end"]
- ):
- words[-1]["end"] = max(
- words[-1]["start"] + median_duration, segment["end"]
- )
+ if segment["end"] > words[-1]["start"] and segment["end"] + 0.5 < words[-1]["end"]:
+ words[-1]["end"] = max(words[-1]["start"] + median_duration, segment["end"])
else:
segment["end"] = words[-1]["end"]
--
Gitblit v1.9.1