| | |
| | | from funasr.download.file import download_from_url |
| | | from funasr.utils.timestamp_tools import timestamp_sentence |
| | | from funasr.utils.timestamp_tools import timestamp_sentence_en |
| | | from funasr.download.download_from_hub import download_model |
| | | from funasr.download.download_model_from_hub import download_model |
| | | from funasr.utils.vad_utils import slice_padding_audio_samples |
| | | from funasr.utils.vad_utils import merge_vad |
| | | from funasr.utils.load_utils import load_audio_text_image_video |
| | |
| | | |
| | | def __init__(self, **kwargs): |
| | | |
| | | try: |
| | | from funasr.utils.version_checker import check_for_update |
| | | |
| | | check_for_update() |
| | | except: |
| | | pass |
| | | |
| | | log_level = getattr(logging, kwargs.get("log_level", "INFO").upper()) |
| | | logging.basicConfig(level=log_level) |
| | | |
| | | if not kwargs.get("disable_log", True): |
| | | tables.print() |
| | | |
| | | model, kwargs = self.build_model(**kwargs) |
| | | |
| | |
| | | self.spk_kwargs = spk_kwargs |
| | | self.model_path = kwargs.get("model_path") |
| | | |
| | | def build_model(self, **kwargs): |
| | | @staticmethod |
| | | def build_model(**kwargs): |
| | | assert "model" in kwargs |
| | | if "model_conf" not in kwargs: |
| | | logging.info("download models from model hub: {}".format(kwargs.get("hub", "ms"))) |
| | |
| | | kwargs["frontend"] = frontend |
| | | # build model |
| | | model_class = tables.model_classes.get(kwargs["model"]) |
| | | assert model_class is not None, f'{kwargs["model"]} is not registered' |
| | | model_conf = {} |
| | | deep_update(model_conf, kwargs.get("model_conf", {})) |
| | | deep_update(model_conf, kwargs) |
| | |
| | | elif kwargs.get("bf16", False): |
| | | model.to(torch.bfloat16) |
| | | model.to(device) |
| | | |
| | | if not kwargs.get("disable_log", True): |
| | | tables.print() |
| | | |
| | | return model, kwargs |
| | | |
| | | def __call__(self, *args, **cfg): |
| | |
| | | speed_stats["rtf"] = f"{(time_escape) / batch_data_time:0.3f}" |
| | | description = f"{speed_stats}, " |
| | | if pbar: |
| | | pbar.update(1) |
| | | pbar.update(end_idx - beg_idx) |
| | | pbar.set_description(description) |
| | | time_speech_total += batch_data_time |
| | | time_escape_total += time_escape |
| | |
| | | # FIX(gcf): concat the vad clips for SenseVoice model for better aed |
| | | if kwargs.get("merge_vad", False): |
| | | for i in range(len(res)): |
| | | res[i]["value"] = merge_vad(res[i]["value"], kwargs.get("merge_length", 15000)) |
| | | res[i]["value"] = merge_vad( |
| | | res[i]["value"], kwargs.get("merge_length_s", 15) * 1000 |
| | | ) |
| | | |
| | | # step.2 compute asr model |
| | | model = self.model |
| | |
| | | |
| | | if len(sorted_data) > 0 and len(sorted_data[0]) > 0: |
| | | batch_size = max(batch_size, sorted_data[0][0][1] - sorted_data[0][0][0]) |
| | | |
| | | if kwargs["device"] == "cpu": |
| | | batch_size = 0 |
| | | |
| | | beg_idx = 0 |
| | | beg_asr_total = time.time() |
| | |
| | | sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu()) |
| | | if self.spk_mode == "vad_segment": # recover sentence_list |
| | | sentence_list = [] |
| | | for res, vadsegment in zip(restored_data, vadsegments): |
| | | if "timestamp" not in res: |
| | | for rest, vadsegment in zip(restored_data, vadsegments): |
| | | if "timestamp" not in rest: |
| | | logging.error( |
| | | "Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \ |
| | | and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\ |
| | |
| | | { |
| | | "start": vadsegment[0], |
| | | "end": vadsegment[1], |
| | | "sentence": res["text"], |
| | | "timestamp": res["timestamp"], |
| | | "sentence": rest["text"], |
| | | "timestamp": rest["timestamp"], |
| | | } |
| | | ) |
| | | elif self.spk_mode == "punc_segment": |