From 8f596af4be1c2e5c4e4b4a7008ba96f412d40fca Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期一, 29 四月 2024 14:32:43 +0800
Subject: [PATCH] batch
---
funasr/models/sense_voice/whisper_lib/utils.py | 63 +++++++------------------------
1 files changed, 15 insertions(+), 48 deletions(-)
diff --git a/funasr/models/sense_voice/whisper_lib/utils.py b/funasr/models/sense_voice/whisper_lib/utils.py
index 9b9b138..5fc6125 100644
--- a/funasr/models/sense_voice/whisper_lib/utils.py
+++ b/funasr/models/sense_voice/whisper_lib/utils.py
@@ -47,9 +47,7 @@
return len(text_bytes) / len(zlib.compress(text_bytes))
-def format_timestamp(
- seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
-):
+def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
assert seconds >= 0, "non-negative timestamp expected"
milliseconds = round(seconds * 1000.0)
@@ -63,9 +61,7 @@
milliseconds -= seconds * 1_000
hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
- return (
- f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
- )
+ return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
def get_start(segments: List[dict]) -> Optional[float]:
@@ -88,30 +84,22 @@
def __init__(self, output_dir: str):
self.output_dir = output_dir
- def __call__(
- self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
- ):
+ def __call__(self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs):
audio_basename = os.path.basename(audio_path)
audio_basename = os.path.splitext(audio_basename)[0]
- output_path = os.path.join(
- self.output_dir, audio_basename + "." + self.extension
- )
+ output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
with open(output_path, "w", encoding="utf-8") as f:
self.write_result(result, file=f, options=options, **kwargs)
- def write_result(
- self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
- ):
+ def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
raise NotImplementedError
class WriteTXT(ResultWriter):
extension: str = "txt"
- def write_result(
- self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
- ):
+ def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
for segment in result["segments"]:
print(segment["text"].strip(), file=file, flush=True)
@@ -156,17 +144,10 @@
segment["words"][chunk_index : chunk_index + words_count]
):
timing = original_timing.copy()
- long_pause = (
- not preserve_segments and timing["start"] - last > 3.0
- )
+ long_pause = not preserve_segments and timing["start"] - last > 3.0
has_room = line_len + len(timing["word"]) <= max_line_width
seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
- if (
- line_len > 0
- and has_room
- and not long_pause
- and not seg_break
- ):
+ if line_len > 0 and has_room and not long_pause and not seg_break:
# line continuation
line_len += len(timing["word"])
else:
@@ -209,9 +190,7 @@
yield start, end, "".join(
[
- re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
- if j == i
- else word
+ re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word) if j == i else word
for j, word in enumerate(all_words)
]
)
@@ -238,9 +217,7 @@
always_include_hours: bool = False
decimal_marker: str = "."
- def write_result(
- self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
- ):
+ def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
print("WEBVTT\n", file=file)
for start, end, text in self.iterate_result(result, options, **kwargs):
print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -251,9 +228,7 @@
always_include_hours: bool = True
decimal_marker: str = ","
- def write_result(
- self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
- ):
+ def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
for i, (start, end, text) in enumerate(
self.iterate_result(result, options, **kwargs), start=1
):
@@ -272,9 +247,7 @@
extension: str = "tsv"
- def write_result(
- self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
- ):
+ def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
print("start", "end", "text", sep="\t", file=file)
for segment in result["segments"]:
print(round(1000 * segment["start"]), file=file, end="\t")
@@ -285,15 +258,11 @@
class WriteJSON(ResultWriter):
extension: str = "json"
- def write_result(
- self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
- ):
+ def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
json.dump(result, file)
-def get_writer(
- output_format: str, output_dir: str
-) -> Callable[[dict, TextIO, dict], None]:
+def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO, dict], None]:
writers = {
"txt": WriteTXT,
"vtt": WriteVTT,
@@ -305,9 +274,7 @@
if output_format == "all":
all_writers = [writer(output_dir) for writer in writers.values()]
- def write_all(
- result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
- ):
+ def write_all(result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
for writer in all_writers:
writer(result, file, options, **kwargs)
--
Gitblit v1.9.1