From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fixbug hotwords

---
 funasr/models/sense_voice/whisper_lib/timing.py |   46 +++++++++++-----------------------------------
 1 file changed, 11 insertions(+), 35 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/timing.py b/funasr/models/sense_voice/whisper_lib/timing.py
index b695ead..ba9cb13 100644
--- a/funasr/models/sense_voice/whisper_lib/timing.py
+++ b/funasr/models/sense_voice/whisper_lib/timing.py
@@ -27,9 +27,7 @@
         # `F.pad` does not support 1D or 2D inputs for reflect padding but supports 3D and 4D
         x = x[None, None, :]
 
-    assert (
-        filter_width > 0 and filter_width % 2 == 1
-    ), "`filter_width` should be an odd number"
+    assert filter_width > 0 and filter_width % 2 == 1, "`filter_width` should be an odd number"
 
     result = None
     x = F.pad(x, (filter_width // 2, filter_width // 2, 0, 0), mode="reflect")
@@ -111,9 +109,7 @@
     M, N = x.shape
     assert M < BLOCK_SIZE, f"M should be smaller than {BLOCK_SIZE=}"
 
-    x_skew = (
-        F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
-    )
+    x_skew = F.pad(x, (0, M + 1), value=np.inf).flatten()[: M * (N + M)].reshape(M, N + M)
     x_skew = x_skew.T.contiguous()
     cost = torch.ones(N + M + 2, M + 2) * np.inf
     cost[0, 0] = 0
@@ -132,9 +128,7 @@
         BLOCK_SIZE=BLOCK_SIZE,
     )
 
-    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[
-        :, : N + 1
-    ]
+    trace = trace.T.flatten()[: (M + 1) * (M + N + 3)].reshape(M + 1, M + N + 3)[:, : N + 1]
     return backtrace(trace.cpu().numpy())
 
 
@@ -228,8 +222,7 @@
     start_times = jump_times[word_boundaries[:-1]]
     end_times = jump_times[word_boundaries[1:]]
     word_probabilities = [
-        np.mean(text_token_probs[i:j])
-        for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
+        np.mean(text_token_probs[i:j]) for i, j in zip(word_boundaries[:-1], word_boundaries[1:])
     ]
 
     return [
@@ -290,8 +283,7 @@
         return
 
     text_tokens_per_segment = [
-        [token for token in segment["tokens"] if token < tokenizer.eot]
-        for segment in segments
+        [token for token in segment["tokens"] if token < tokenizer.eot] for segment in segments
     ]
 
     text_tokens = list(itertools.chain.from_iterable(text_tokens_per_segment))
@@ -346,38 +338,22 @@
             # twice the median word duration.
             if words[0]["end"] - last_speech_timestamp > median_duration * 4 and (
                 words[0]["end"] - words[0]["start"] > max_duration
-                or (
-                    len(words) > 1
-                    and words[1]["end"] - words[0]["start"] > max_duration * 2
-                )
+                or (len(words) > 1 and words[1]["end"] - words[0]["start"] > max_duration * 2)
             ):
-                if (
-                    len(words) > 1
-                    and words[1]["end"] - words[1]["start"] > max_duration
-                ):
+                if len(words) > 1 and words[1]["end"] - words[1]["start"] > max_duration:
                     boundary = max(words[1]["end"] / 2, words[1]["end"] - max_duration)
                     words[0]["end"] = words[1]["start"] = boundary
                 words[0]["start"] = max(0, words[0]["end"] - max_duration)
 
             # prefer the segment-level start timestamp if the first word is too long.
-            if (
-                segment["start"] < words[0]["end"]
-                and segment["start"] - 0.5 > words[0]["start"]
-            ):
-                words[0]["start"] = max(
-                    0, min(words[0]["end"] - median_duration, segment["start"])
-                )
+            if segment["start"] < words[0]["end"] and segment["start"] - 0.5 > words[0]["start"]:
+                words[0]["start"] = max(0, min(words[0]["end"] - median_duration, segment["start"]))
             else:
                 segment["start"] = words[0]["start"]
 
             # prefer the segment-level end timestamp if the last word is too long.
-            if (
-                segment["end"] > words[-1]["start"]
-                and segment["end"] + 0.5 < words[-1]["end"]
-            ):
-                words[-1]["end"] = max(
-                    words[-1]["start"] + median_duration, segment["end"]
-                )
+            if segment["end"] > words[-1]["start"] and segment["end"] + 0.5 < words[-1]["end"]:
+                words[-1]["end"] = max(words[-1]["start"] + median_duration, segment["end"])
             else:
                 segment["end"] = words[-1]["end"]
 

--
Gitblit v1.9.1