游雁
2024-06-14 59bc02b089f7a626fe67907dcfc695eae6883f82
funasr/auto/auto_model.py
@@ -19,6 +19,7 @@
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.utils.timestamp_tools import timestamp_sentence_en
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.utils.vad_utils import merge_vad
@@ -212,7 +213,6 @@
        deep_update(model_conf, kwargs.get("model_conf", {}))
        deep_update(model_conf, kwargs)
        model = model_class(**model_conf, vocab_size=vocab_size)
        model.to(device)
        # init_param
        init_param = kwargs.get("init_param", None)
@@ -235,6 +235,7 @@
            model.to(torch.float16)
        elif kwargs.get("bf16", False):
            model.to(torch.bfloat16)
        model.to(device)
        return model, kwargs
    def __call__(self, *args, **cfg):
@@ -323,7 +324,7 @@
            input, input_len=input_len, model=self.vad_model, kwargs=self.vad_kwargs, **cfg
        )
        end_vad = time.time()
        #  FIX(gcf): concat the vad clips for sense vocie model for better aed
        if kwargs.get("merge_vad", False):
            for i in range(len(res)):
@@ -519,24 +520,40 @@
                                       and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
                                       can predict timestamp, and speaker diarization relies on timestamps."
                        )
                    sentence_list = timestamp_sentence(
                        punc_res[0]["punc_array"],
                        result["timestamp"],
                        raw_text,
                        return_raw_text=return_raw_text,
                    )
                    if kwargs.get("en_post_proc", False):
                        sentence_list = timestamp_sentence_en(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                    else:
                        sentence_list = timestamp_sentence(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                distribute_spk(sentence_list, sv_output)
                result["sentence_info"] = sentence_list
            elif kwargs.get("sentence_timestamp", False):
                if not len(result["text"].strip()):
                    sentence_list = []
                else:
                    sentence_list = timestamp_sentence(
                        punc_res[0]["punc_array"],
                        result["timestamp"],
                        raw_text,
                        return_raw_text=return_raw_text,
                    )
                    if kwargs.get("en_post_proc", False):
                        sentence_list = timestamp_sentence_en(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                    else:
                        sentence_list = timestamp_sentence(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                result["sentence_info"] = sentence_list
            if "spk_embedding" in result:
                del result["spk_embedding"]