维石
2024-06-11 3e9319263835bd018abe2dcd59e029603b714022
English timestamps for vanilla Paraformer
3个文件已修改
115 ■■■■■ 已修改文件
examples/industrial_data_pretraining/paraformer/demo.py 13 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/auto/auto_model.py 17 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/utils/timestamp_tools.py 85 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
examples/industrial_data_pretraining/paraformer/demo.py
@@ -21,6 +21,19 @@
print(res)
""" call english model like below for detailed timestamps
# choose english paraformer model first
# iic/speech_paraformer_asr-en-16k-vocab4199-pytorch
res = model.generate(
    input="https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_audio/asr_example_en.wav",
    cache={},
    pred_timestamp=True,
    return_raw_text=True,
    sentence_timestamp=True,
    en_post_proc=True,
)
"""
""" can not use currently
from funasr import AutoFrontend
funasr/auto/auto_model.py
@@ -19,6 +19,7 @@
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.utils.timestamp_tools import timestamp_sentence_en
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.utils.vad_utils import merge_vad
@@ -513,6 +514,14 @@
                                       and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
                                       can predict timestamp, and speaker diarization relies on timestamps."
                        )
                    if kwargs.get("en_post_proc", False):
                        sentence_list = timestamp_sentence_en(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                    else:
                    sentence_list = timestamp_sentence(
                        punc_res[0]["punc_array"],
                        result["timestamp"],
@@ -525,6 +534,14 @@
                if not len(result["text"].strip()):
                    sentence_list = []
                else:
                    if kwargs.get("en_post_proc", False):
                        sentence_list = timestamp_sentence_en(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                    else:
                    sentence_list = timestamp_sentence(
                        punc_res[0]["punc_array"],
                        result["timestamp"],
funasr/utils/timestamp_tools.py
@@ -185,3 +185,88 @@
            ts_list = []
            sentence_start = sentence_end
    return res
def _en_join_word(sentence_text, word):
    """Append one English ASR token to a partially built sentence.

    A separating space is inserted when the sentence is non-empty and either
    the new token starts with an ASCII letter or the sentence currently ends
    with one; otherwise (e.g. two consecutive digit tokens) the token is
    appended directly.
    """

    def is_letter(ch):
        return "a" <= ch <= "z" or "A" <= ch <= "Z"

    if sentence_text and (is_letter(word[0]) or is_letter(sentence_text[-1])):
        return sentence_text + " " + word
    return sentence_text + word


def _en_sentence(text, start, end, ts_list, words, return_raw_text):
    """Build one sentence dict in the format returned by timestamp_sentence_en."""
    sentence = {"text": text, "start": start, "end": end, "timestamp": ts_list}
    if return_raw_text:
        # Raw text: the sentence's words joined by single spaces, no punctuation.
        sentence["raw_text"] = " ".join(words)
    return sentence


def timestamp_sentence_en(
    punc_id_list, timestamp_postprocessed, text_postprocessed, return_raw_text=False
):
    """Group English words into punctuated sentences with start/end timestamps.

    Args:
        punc_id_list: per-word punctuation ids. Ids 2..5 map to ",", ".", "?"
            and "," respectively (presumably the punctuation model's
            vocabulary -- confirm); ids <= 1, or a missing id, mean "no
            punctuation after this word". If None/empty, the whole text is
            returned as a single sentence.
        timestamp_postprocessed: per-word [start, end] pairs, expected to be
            aligned with the words of ``text_postprocessed``.
        text_postprocessed: whitespace-separated words.
        return_raw_text: when True, each sentence dict also has a "raw_text"
            key holding its words joined by single spaces, without punctuation.

    Returns:
        A list of dicts with keys "text" (str), "start", "end", "timestamp"
        (the list of [start, end] pairs for the sentence) and, if requested,
        "raw_text". Returns [] when the text or the timestamps are missing,
        empty, or (for the text) whitespace-only.

    Notes:
        Inputs of unequal length are tolerated (a warning is logged when the
        punctuation and timestamp lengths differ): missing entries are simply
        skipped. Words left after the last punctuation mark are emitted as a
        final, unpunctuated sentence instead of being dropped. Every sentence
        after the first starts at the end time of the previous sentence.
    """
    punc_list = [",", ".", "?", ","]
    res = []
    if text_postprocessed is None or timestamp_postprocessed is None:
        return res
    if len(timestamp_postprocessed) == 0:
        return res
    texts = text_postprocessed.split()
    if not texts:
        # Empty or whitespace-only text: nothing to segment.
        return res

    if punc_id_list is None or len(punc_id_list) == 0:
        full_text = ""
        for word in texts:
            full_text = _en_join_word(full_text, word)
        res.append(
            _en_sentence(
                full_text,
                timestamp_postprocessed[0][0],
                timestamp_postprocessed[-1][1],
                timestamp_postprocessed,
                texts,
                return_raw_text,
            )
        )
        return res

    if len(punc_id_list) != len(timestamp_postprocessed):
        logging.warning("length mismatch between punc and timestamp")

    sentence_text = ""
    words = []  # raw words of the sentence currently being built
    ts_list = []
    sentence_start = timestamp_postprocessed[0][0]
    sentence_end = timestamp_postprocessed[0][1]

    # zip_longest runs to the longest input; exhausted inputs yield None.
    for punc_id, timestamp, text in zip_longest(
        punc_id_list, timestamp_postprocessed, texts, fillvalue=None
    ):
        if text is not None:
            sentence_text = _en_join_word(sentence_text, text)
            words.append(text)
        if timestamp is not None:
            # Skip padding so "timestamp" only holds real [start, end] pairs.
            ts_list.append(timestamp)
            sentence_end = timestamp[1]
        punc_id = int(punc_id) if punc_id is not None else 1
        # A punctuation mark closes the sentence. A mark with no words before
        # it (only possible once the text is exhausted) is ignored rather than
        # producing a punctuation-only sentence.
        if punc_id > 1 and words:
            res.append(
                _en_sentence(
                    sentence_text + punc_list[punc_id - 2],
                    sentence_start,
                    sentence_end,
                    ts_list,
                    words,
                    return_raw_text,
                )
            )
            sentence_text = ""
            words = []
            ts_list = []
            sentence_start = sentence_end

    if words:
        # Trailing words with no closing punctuation (e.g. punc_id_list shorter
        # than the text) would otherwise be silently lost.
        res.append(
            _en_sentence(
                sentence_text,
                sentence_start,
                sentence_end,
                ts_list,
                words,
                return_raw_text,
            )
        )
    return res