游雁
2024-01-25 4f078d1cbd4dfd1ffce31a563cc792098174f920
funasr/auto/auto_model.py
@@ -6,6 +6,7 @@
import string
import logging
import os.path
import numpy as np
from tqdm import tqdm
from omegaconf import DictConfig, OmegaConf, ListConfig
@@ -96,7 +97,7 @@
        vad_kwargs = kwargs.get("vad_model_revision", None)
        if vad_model is not None:
            logging.info("Building VAD model.")
            vad_kwargs = {"model": vad_model, "model_revision": vad_kwargs}
            vad_kwargs = {"model": vad_model, "model_revision": vad_kwargs, "device": kwargs["device"]}
            vad_model, vad_kwargs = self.build_model(**vad_kwargs)
        # if punc_model is not None, build punc model else None
@@ -104,7 +105,7 @@
        punc_kwargs = kwargs.get("punc_model_revision", None)
        if punc_model is not None:
            logging.info("Building punc model.")
            punc_kwargs = {"model": punc_model, "model_revision": punc_kwargs}
            punc_kwargs = {"model": punc_model, "model_revision": punc_kwargs, "device": kwargs["device"]}
            punc_model, punc_kwargs = self.build_model(**punc_kwargs)
        # if spk_model is not None, build spk model else None
@@ -112,9 +113,9 @@
        spk_kwargs = kwargs.get("spk_model_revision", None)
        if spk_model is not None:
            logging.info("Building SPK model.")
            spk_kwargs = {"model": spk_model, "model_revision": spk_kwargs}
            spk_kwargs = {"model": spk_model, "model_revision": spk_kwargs, "device": kwargs["device"]}
            spk_model, spk_kwargs = self.build_model(**spk_kwargs)
            self.cb_model = ClusterBackend()
            self.cb_model = ClusterBackend().to(kwargs["device"])
            spk_mode = kwargs.get("spk_mode", 'punc_segment')
            if spk_mode not in ["default", "vad_segment", "punc_segment"]:
                logging.error("spk_mode should be one of default, vad_segment and punc_segment.")
@@ -132,7 +133,8 @@
        self.punc_kwargs = punc_kwargs
        self.spk_model = spk_model
        self.spk_kwargs = spk_kwargs
        self.model_path = kwargs.get("model_path", "./")
        self.model_path = kwargs.get("model_path")
  
        
    def build_model(self, **kwargs):
@@ -144,7 +146,7 @@
        set_all_random_seed(kwargs.get("seed", 0))
        
        device = kwargs.get("device", "cuda")
        if not torch.cuda.is_available() or kwargs.get("ngpu", 0):
        if not torch.cuda.is_available() or kwargs.get("ngpu", 0) == 0:
            device = "cpu"
            kwargs["batch_size"] = 1
        kwargs["device"] = device
@@ -222,7 +224,7 @@
        asr_result_list = []
        num_samples = len(data_list)
        disable_pbar = kwargs.get("disable_pbar", False)
        pbar = tqdm(colour="blue", total=num_samples+1, dynamic_ncols=True) if not disable_pbar else None
        pbar = tqdm(colour="blue", total=num_samples, dynamic_ncols=True) if not disable_pbar else None
        time_speech_total = 0.0
        time_escape_total = 0.0
        for beg_idx in range(0, num_samples, batch_size):
@@ -333,7 +335,7 @@
                    for _b in range(len(speech_j)):
                        vad_segments = [[sorted_data[beg_idx:end_idx][_b][0][0]/1000.0,
                                        sorted_data[beg_idx:end_idx][_b][0][1]/1000.0,
                                        speech_j[_b]]]
                                        np.array(speech_j[_b])]]
                        segments = sv_chunk(vad_segments)
                        all_segments.extend(segments)
                        speech_b = [i[2] for i in segments]
@@ -348,6 +350,7 @@
            
            end_asr_total = time.time()
            time_escape_total_per_sample = end_asr_total - beg_asr_total
            pbar_sample.update(1)
            pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                                 f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
                                 f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
@@ -375,7 +378,7 @@
                            result[k] = restored_data[j][k]
                        else:
                            result[k] = torch.cat([result[k], restored_data[j][k]], dim=0)
                    elif k == 'text':
                    elif k == 'raw_text':
                        if k not in result:
                            result[k] = restored_data[j][k]
                        else:
@@ -390,13 +393,13 @@
            if self.punc_model is not None:
                self.punc_kwargs.update(cfg)
                punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
                result["text_with_punc"] = punc_res[0]["text"]
                result["text"] = punc_res[0]["text"]
                     
            # speaker embedding cluster after resorted
            if self.spk_model is not None:
                all_segments = sorted(all_segments, key=lambda x: x[0])
                spk_embedding = result['spk_embedding']
                labels = self.cb_model(spk_embedding, oracle_num=self.preset_spk_num)
                labels = self.cb_model(spk_embedding.cpu(), oracle_num=self.preset_spk_num)
                del result['spk_embedding']
                sv_output = postprocess(all_segments, None, labels, spk_embedding.cpu())
                if self.spk_mode == 'vad_segment':
@@ -404,12 +407,12 @@
                    for res, vadsegment in zip(restored_data, vadsegments):
                        sentence_list.append({"start": vadsegment[0],\
                                                "end": vadsegment[1],
                                                "sentence": res['text'],
                                                "sentence": res['raw_text'],
                                                "timestamp": res['timestamp']})
                else: # punc_segment
                    sentence_list = timestamp_sentence(punc_res[0]['punc_array'], \
                                                        result['timestamp'], \
                                                        result['text'])
                                                        result['raw_text'])
                distribute_spk(sentence_list, sv_output)
                result['sentence_info'] = sentence_list