游雁
2024-04-29 8f596af4be1c2e5c4e4b4a7008ba96f412d40fca
funasr/models/sense_voice/whisper_lib/utils.py
@@ -47,9 +47,7 @@
    return len(text_bytes) / len(zlib.compress(text_bytes))
def format_timestamp(
    seconds: float, always_include_hours: bool = False, decimal_marker: str = "."
):
def format_timestamp(seconds: float, always_include_hours: bool = False, decimal_marker: str = "."):
    assert seconds >= 0, "non-negative timestamp expected"
    milliseconds = round(seconds * 1000.0)
@@ -63,9 +61,7 @@
    milliseconds -= seconds * 1_000
    hours_marker = f"{hours:02d}:" if always_include_hours or hours > 0 else ""
    return (
        f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
    )
    return f"{hours_marker}{minutes:02d}:{seconds:02d}{decimal_marker}{milliseconds:03d}"
def get_start(segments: List[dict]) -> Optional[float]:
@@ -88,30 +84,22 @@
    def __init__(self, output_dir: str):
        self.output_dir = output_dir
    def __call__(
        self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs
    ):
    def __call__(self, result: dict, audio_path: str, options: Optional[dict] = None, **kwargs):
        audio_basename = os.path.basename(audio_path)
        audio_basename = os.path.splitext(audio_basename)[0]
        output_path = os.path.join(
            self.output_dir, audio_basename + "." + self.extension
        )
        output_path = os.path.join(self.output_dir, audio_basename + "." + self.extension)
        with open(output_path, "w", encoding="utf-8") as f:
            self.write_result(result, file=f, options=options, **kwargs)
    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        raise NotImplementedError
class WriteTXT(ResultWriter):
    extension: str = "txt"
    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        for segment in result["segments"]:
            print(segment["text"].strip(), file=file, flush=True)
@@ -156,17 +144,10 @@
                        segment["words"][chunk_index : chunk_index + words_count]
                    ):
                        timing = original_timing.copy()
                        long_pause = (
                            not preserve_segments and timing["start"] - last > 3.0
                        )
                        long_pause = not preserve_segments and timing["start"] - last > 3.0
                        has_room = line_len + len(timing["word"]) <= max_line_width
                        seg_break = i == 0 and len(subtitle) > 0 and preserve_segments
                        if (
                            line_len > 0
                            and has_room
                            and not long_pause
                            and not seg_break
                        ):
                        if line_len > 0 and has_room and not long_pause and not seg_break:
                            # line continuation
                            line_len += len(timing["word"])
                        else:
@@ -209,9 +190,7 @@
                        yield start, end, "".join(
                            [
                                re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word)
                                if j == i
                                else word
                                re.sub(r"^(\s*)(.*)$", r"\1<u>\2</u>", word) if j == i else word
                                for j, word in enumerate(all_words)
                            ]
                        )
@@ -238,9 +217,7 @@
    always_include_hours: bool = False
    decimal_marker: str = "."
    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        print("WEBVTT\n", file=file)
        for start, end, text in self.iterate_result(result, options, **kwargs):
            print(f"{start} --> {end}\n{text}\n", file=file, flush=True)
@@ -251,9 +228,7 @@
    always_include_hours: bool = True
    decimal_marker: str = ","
    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        for i, (start, end, text) in enumerate(
            self.iterate_result(result, options, **kwargs), start=1
        ):
@@ -272,9 +247,7 @@
    extension: str = "tsv"
    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        print("start", "end", "text", sep="\t", file=file)
        for segment in result["segments"]:
            print(round(1000 * segment["start"]), file=file, end="\t")
@@ -285,15 +258,11 @@
class WriteJSON(ResultWriter):
    extension: str = "json"
    def write_result(
        self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
    ):
    def write_result(self, result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
        json.dump(result, file)
def get_writer(
    output_format: str, output_dir: str
) -> Callable[[dict, TextIO, dict], None]:
def get_writer(output_format: str, output_dir: str) -> Callable[[dict, TextIO, dict], None]:
    writers = {
        "txt": WriteTXT,
        "vtt": WriteVTT,
@@ -305,9 +274,7 @@
    if output_format == "all":
        all_writers = [writer(output_dir) for writer in writers.values()]
        def write_all(
            result: dict, file: TextIO, options: Optional[dict] = None, **kwargs
        ):
        def write_all(result: dict, file: TextIO, options: Optional[dict] = None, **kwargs):
            for writer in all_writers:
                writer(result, file, options, **kwargs)