From 8f596af4be1c2e5c4e4b4a7008ba96f412d40fca Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 14:32:43 +0800
Subject: [PATCH] batch

---
 funasr/models/sense_voice/whisper_lib/utils.py |   63 +++++++------------------------
 1 files changed, 15 insertions(+), 48 deletions(-)

diff --git a/funasr/models/sense_voice/whisper_lib/utils.py b/funasr/models/sense_voice/whisper_lib/utils.py
index 9b9b138..5fc6125 100644
--- a/funasr/models/sense_voice/whisper_lib/utils.py
+++ b/funasr/models/sense_voice/whisper_lib/utils.py
@@ -47,9 +47,7 @@
     return len(text_bytes) / len(zlib.compress(text_bytes))
 
 
-def format_timestamp(
-    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
-):
+def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
     assert seconds >= 0, "non-negative timestamp expected"
     milliseconds = round(seconds * 1000.0)
 
@@ -63,9 +61,7 @@
     milliseconds -= seconds * 1_000
 
     hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
-    return (
-        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
-    )
+    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
 
 
 def get_start(segments: List[dict]) -> Optional[float]:
@@ -88,30 +84,22 @@
     def __init__(self, output_dir: str):
         self.output_dir = output_dir
 
-    def __call__(
-        self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
-    ):
+    def __call__(self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs):
         audio_basename = os.path.basename(audio_path)
         audio_basename = os.path.splitext(audio_basename)[0]
-        output_path = os.path.join(
-            self.output_dir, audio_basename + "." + self.extension
-        )
+        output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
 
         with open(output_path, "w", encoding="utf-8") as f:
             self.write_result(result, file=f, options=options, **kwargs)
 
-    def write_result(
-        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
         raise NotImplementedError
 
 
 class WriteTXT(ResultWriter):
     extension: str = "txt"
 
-    def write_result(
-        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
         for segment in result["segments"]:
             print(segment["text"].strip(), file=file, flush=True)
 
@@ -156,17 +144,10 @@
                         segment["words"][chunk_index : chunk_index + words_count]
                     ):
                         timing = original_timing.copy()
-                        long_pause = (
-                            not preserve_segments and timing["start"] - last > 3.0
-                        )
+                        long_pause = not preserve_segments and timing["start"] - last > 3.0
                         has_room = line_len + len(timing["word"]) <= max_line_width
                         seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
-                        if (
-                            line_len > 0
-                            and has_room
-                            and not long_pause
-                            and not seg_break
-                        ):
+                        if line_len > 0 and has_room and not long_pause and not seg_break:
                             # line continuation
                             line_len += len(timing["word"])
                         else:
@@ -209,9 +190,7 @@
 
                         yield start, end, "".join(
                             [
-                                re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
-                                if j == i
-                                else word
+                                re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word) if j == i else word
                                 for j, word in enumerate(all_words)
                             ]
                         )
@@ -238,9 +217,7 @@
     always_include_hours: bool = False
     decimal_marker: str = "."
 
-    def write_result(
-        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
         print("WEBVTT\n", file=file)
         for start, end, text in self.iterate_result(result, options, **kwargs):
             print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -251,9 +228,7 @@
     always_include_hours: bool = True
     decimal_marker: str = ","
 
-    def write_result(
-        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
         for i, (start, end, text) in enumerate(
             self.iterate_result(result, options, **kwargs), start=1
         ):
@@ -272,9 +247,7 @@
 
     extension: str = "tsv"
 
-    def write_result(
-        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
         print("start", "end", "text", sep="\t", file=file)
         for segment in result["segments"]:
             print(round(1000 * segment["start"]), file=file, end="\t")
@@ -285,15 +258,11 @@
 class WriteJSON(ResultWriter):
     extension: str = "json"
 
-    def write_result(
-        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-    ):
+    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
         json.dump(result, file)
 
 
-def get_writer(
-    output_format: str, output_dir: str
-) -> Callable[[dict, TextIO, dict], None]:
+def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO, dict], None]:
     writers = {
         "txt": WriteTXT,
         "vtt": WriteVTT,
@@ -305,9 +274,7 @@
     if output_format == "all":
         all_writers = [writer(output_dir) for writer in writers.values()]
 
-        def write_all(
-            result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
-        ):
+        def write_all(result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
             for writer in all_writers:
                 writer(result, file, options, **kwargs)
 

--
Gitblit v1.9.1