zhifu gao
2024-06-20 e65b1f701abca03bf3a1b5fbb200392aabd38c22
funasr/auto/auto_model.py
@@ -19,6 +19,7 @@
from funasr.utils.load_utils import load_bytes
from funasr.download.file import download_from_url
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.utils.timestamp_tools import timestamp_sentence_en
from funasr.download.download_from_hub import download_model
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.utils.vad_utils import merge_vad
@@ -42,8 +43,9 @@
    filelist = [".scp", ".txt", ".json", ".jsonl", ".text"]
    chars = string.ascii_letters + string.digits
    if isinstance(data_in, str) and data_in.startswith("http"):  # url
        data_in = download_from_url(data_in)
    if isinstance(data_in, str):
        if data_in.startswith("http://") or data_in.startswith("https://"):  # url
            data_in = download_from_url(data_in)
    if isinstance(data_in, str) and os.path.exists(
        data_in
@@ -90,7 +92,8 @@
                if isinstance(data_i, str) and os.path.exists(data_i):
                    key = misc.extract_filename_without_extension(data_i)
                else:
                    key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                    if key is None:
                        key = "rand_key_" + "".join(random.choice(chars) for _ in range(13))
                key_list.append(key)
    else:  # raw text; audio sample point, fbank; bytes
@@ -211,7 +214,6 @@
        deep_update(model_conf, kwargs.get("model_conf", {}))
        deep_update(model_conf, kwargs)
        model = model_class(**model_conf, vocab_size=vocab_size)
        model.to(device)
        # init_param
        init_param = kwargs.get("init_param", None)
@@ -232,6 +234,9 @@
        # fp16
        if kwargs.get("fp16", False):
            model.to(torch.float16)
        elif kwargs.get("bf16", False):
            model.to(torch.bfloat16)
        model.to(device)
        return model, kwargs
    def __call__(self, *args, **cfg):
@@ -284,7 +289,7 @@
            with torch.no_grad():
                res = model.inference(**batch, **kwargs)
                if isinstance(res, (list, tuple)):
                    results = res[0]
                    results = res[0] if len(res) > 0 else [{"text": ""}]
                    meta_data = res[1] if len(res) > 1 else {}
            time2 = time.perf_counter()
@@ -358,6 +363,7 @@
            results_sorted = []
            if not len(sorted_data):
                results_ret_list.append({"key": key, "text": "", "timestamp": []})
                logging.info("decoding, utt: {}, empty speech".format(key))
                continue
@@ -425,6 +431,10 @@
            #                      f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
            #                      f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
            if len(results_sorted) != n:
                results_ret_list.append({"key": key, "text": "", "timestamp": []})
                logging.info("decoding, utt: {}, empty result".format(key))
                continue
            restored_data = [0] * n
            for j in range(n):
                index = sorted_data[j][1]
@@ -458,23 +468,20 @@
                        else:
                            result[k] += restored_data[j][k]
            if not len(result["text"].strip()):
                continue
            return_raw_text = kwargs.get("return_raw_text", False)
            # step.3 compute punc model
            raw_text = None
            if self.punc_model is not None:
                if not len(result["text"].strip()):
                    if return_raw_text:
                        result["raw_text"] = ""
                else:
                    deep_update(self.punc_kwargs, cfg)
                    punc_res = self.inference(
                        result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg
                    )
                    raw_text = copy.copy(result["text"])
                    if return_raw_text:
                        result["raw_text"] = raw_text
                    result["text"] = punc_res[0]["text"]
            else:
                raw_text = None
                deep_update(self.punc_kwargs, cfg)
                punc_res = self.inference(
                    result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg
                )
                raw_text = copy.copy(result["text"])
                if return_raw_text:
                    result["raw_text"] = raw_text
                result["text"] = punc_res[0]["text"]
            # speaker embedding cluster after resorted
            if self.spk_model is not None and kwargs.get("return_spk_res", True):
@@ -511,24 +518,40 @@
                                       and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
                                       can predict timestamp, and speaker diarization relies on timestamps."
                        )
                    sentence_list = timestamp_sentence(
                        punc_res[0]["punc_array"],
                        result["timestamp"],
                        raw_text,
                        return_raw_text=return_raw_text,
                    )
                    if kwargs.get("en_post_proc", False):
                        sentence_list = timestamp_sentence_en(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                    else:
                        sentence_list = timestamp_sentence(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                distribute_spk(sentence_list, sv_output)
                result["sentence_info"] = sentence_list
            elif kwargs.get("sentence_timestamp", False):
                if not len(result["text"].strip()):
                    sentence_list = []
                else:
                    sentence_list = timestamp_sentence(
                        punc_res[0]["punc_array"],
                        result["timestamp"],
                        raw_text,
                        return_raw_text=return_raw_text,
                    )
                    if kwargs.get("en_post_proc", False):
                        sentence_list = timestamp_sentence_en(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                    else:
                        sentence_list = timestamp_sentence(
                            punc_res[0]["punc_array"],
                            result["timestamp"],
                            raw_text,
                            return_raw_text=return_raw_text,
                        )
                result["sentence_info"] = sentence_list
            if "spk_embedding" in result:
                del result["spk_embedding"]
@@ -580,12 +603,6 @@
        )
        with torch.no_grad():
            if type == "onnx":
                export_dir = export_utils.export_onnx(model=model, data_in=data_list, **kwargs)
            else:
                export_dir = export_utils.export_torchscripts(
                    model=model, data_in=data_list, **kwargs
                )
            export_dir = export_utils.export(model=model, data_in=data_list, **kwargs)
        return export_dir