shixian.shi
2024-02-21 571dc8b55a9b036a5b36f968bb3a5baf5858e395
funasr/auto/auto_model.py
@@ -1,14 +1,13 @@
import json
import time
import copy
import torch
import hydra
import random
import string
import logging
import os.path
import numpy as np
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
@@ -17,7 +16,7 @@
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
try:
@@ -385,11 +384,15 @@
            if self.punc_model is not None:
                self.punc_kwargs.update(cfg)
                punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, disable_pbar=True, **cfg)
                import copy; raw_text = copy.copy(result["text"])
                raw_text = copy.copy(result["text"])
                result["text"] = punc_res[0]["text"]
            else:
                raw_text = None
                
            # speaker embedding cluster after resorted
            if self.spk_model is not None and kwargs.get('return_spk_res', True):
                if raw_text is None:
                    logging.error("Missing punc_model, which is required by spk_model.")
                all_segments = sorted(all_segments, key=lambda x: x[0])
                spk_embedding = result['spk_embedding']
                labels = self.cb_model(spk_embedding.cpu(), oracle_num=kwargs.get('preset_spk_num', None))
@@ -398,20 +401,28 @@
                if self.spk_mode == 'vad_segment':  # recover sentence_list
                    sentence_list = []
                    for res, vadsegment in zip(restored_data, vadsegments):
                        if 'timestamp' not in res:
                            logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
                                and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
                                can predict timestamp, and speaker diarization relies on timestamps.")
                        sentence_list.append({"start": vadsegment[0],\
                                                "end": vadsegment[1],
                                                "sentence": res['raw_text'],
                                                "sentence": res['text'],
                                                "timestamp": res['timestamp']})
                elif self.spk_mode == 'punc_segment':
                    if 'timestamp' not in result:
                        logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
                            and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
                            can predict timestamp, and speaker diarization relies on timestamps.")
                    sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                        result['timestamp'], \
                                                        result['raw_text'])
                                                        raw_text)
                distribute_spk(sentence_list, sv_output)
                result['sentence_info'] = sentence_list
            elif kwargs.get("sentence_timestamp", False):
                sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                        result['timestamp'], \
                                                        result['raw_text'])
                                                        raw_text)
                result['sentence_info'] = sentence_list
            if "spk_embedding" in result: del result['spk_embedding']