语帆
2024-02-21 3c173e6187ea49520b1e946c9d5de66c824f0864
funasr/auto/auto_model.py
@@ -1,14 +1,13 @@
import json
import time
import copy
import torch
import hydra
import random
import string
import logging
import os.path
import numpy as np
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig
from funasr.register import tables
from funasr.utils.load_utils import load_bytes
@@ -17,7 +16,7 @@
from funasr.utils.vad_utils import slice_padding_audio_samples
from funasr.train_utils.set_all_random_seed import set_all_random_seed
from funasr.train_utils.load_pretrained_model import load_pretrained_model
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
from funasr.utils.load_utils import load_audio_text_image_video
from funasr.utils.timestamp_tools import timestamp_sentence
from funasr.models.campplus.utils import sv_chunk, postprocess, distribute_spk
try:
@@ -387,16 +386,22 @@
                            result[k] = restored_data[j][k]
                        else:
                            result[k] += restored_data[j][k]
            return_raw_text = kwargs.get('return_raw_text', False)
            # step.3 compute punc model
            if self.punc_model is not None:
                self.punc_kwargs.update(cfg)
                punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, disable_pbar=True, **cfg)
                import copy; raw_text = copy.copy(result["text"])
                raw_text = copy.copy(result["text"])
                if return_raw_text: result['raw_text'] = raw_text
                result["text"] = punc_res[0]["text"]
            else:
                raw_text = None
                
            # cluster speaker embeddings after the results have been re-sorted
            if self.spk_model is not None and kwargs.get('return_spk_res', True):
                if raw_text is None:
                    logging.error("Missing punc_model, which is required by spk_model.")
                all_segments = sorted(all_segments, key=lambda x: x[0])
                spk_embedding = result['spk_embedding']
                labels = self.cb_model(spk_embedding.cpu(), oracle_num=kwargs.get('preset_spk_num', None))
@@ -405,22 +410,32 @@
                if self.spk_mode == 'vad_segment':  # recover sentence_list
                    sentence_list = []
                    for res, vadsegment in zip(restored_data, vadsegments):
                        sentence_list.append({"start": vadsegment[0],\
                                                "end": vadsegment[1],
                                                "sentence": res['raw_text'],
                                                "timestamp": res['timestamp']})
                        if 'timestamp' not in res:
                            logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
                                           and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
                                           can predict timestamp, and speaker diarization relies on timestamps.")
                        sentence_list.append({"start": vadsegment[0],
                                              "end": vadsegment[1],
                                              "sentence": res['text'],
                                              "timestamp": res['timestamp']})
                elif self.spk_mode == 'punc_segment':
                    sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                        result['timestamp'], \
                                                        result['raw_text'])
                    if 'timestamp' not in result:
                        logging.error("Only 'iic/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-pytorch' \
                                       and 'iic/speech_seaco_paraformer_large_asr_nat-zh-cn-16k-common-vocab8404-pytorch'\
                                       can predict timestamp, and speaker diarization relies on timestamps.")
                    sentence_list = timestamp_sentence(punc_res[0]['punc_array'],
                                                       result['timestamp'],
                                                       raw_text,
                                                       return_raw_text=return_raw_text)
                distribute_spk(sentence_list, sv_output)
                result['sentence_info'] = sentence_list
            elif kwargs.get("sentence_timestamp", False):
                sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                        result['timestamp'], \
                                                        result['raw_text'])
                sentence_list = timestamp_sentence(punc_res[0]['punc_array'],
                                                   result['timestamp'],
                                                   raw_text,
                                                   return_raw_text=return_raw_text)
                result['sentence_info'] = sentence_list
            del result['spk_embedding']
            if "spk_embedding" in result: del result['spk_embedding']
                    
            result["key"] = key
            results_ret_list.append(result)