游雁
2024-01-30 f1c1cb0773fca5e9d1ee595ef6ca2ff4bad9f2a4
funasr/auto/auto_model.py
@@ -88,7 +88,8 @@
class AutoModel:
    
    def __init__(self, **kwargs):
        tables.print()
        if kwargs.get("disable_log", False):
            tables.print()
        
        model, kwargs = self.build_model(**kwargs)
        
@@ -133,8 +134,6 @@
        self.spk_model = spk_model
        self.spk_kwargs = spk_kwargs
        self.model_path = kwargs.get("model_path")
        
    def build_model(self, **kwargs):
        assert "model" in kwargs
@@ -145,7 +144,7 @@
        set_all_random_seed(kwargs.get("seed", 0))
        
        device = kwargs.get("device", "cuda")
        if not torch.cuda.is_available() or kwargs.get("ngpu", 0) == 0:
        if not torch.cuda.is_available() or kwargs.get("ngpu", 1) == 0:
            device = "cpu"
            kwargs["batch_size"] = 1
        kwargs["device"] = device
@@ -198,8 +197,6 @@
        kwargs.update(cfg)
        res = self.model(*args, kwargs)
        return res
    def generate(self, input, input_len=None, **cfg):
        if self.vad_model is None:
@@ -260,7 +257,7 @@
            time_escape_total += time_escape
        if pbar:
            pbar.update(1)
            # pbar.update(1)
            pbar.set_description(f"rtf_avg: {time_escape_total/time_speech_total:0.3f}")
        torch.cuda.empty_cache()
        return asr_result_list
@@ -285,10 +282,10 @@
        
        key_list, data_list = prepare_data_iterator(input, input_len=input_len, data_type=kwargs.get("data_type", None))
        results_ret_list = []
        time_speech_total_all_samples = 0.0
        time_speech_total_all_samples = 1e-6
        beg_total = time.time()
        pbar_total = tqdm(colour="red", total=len(res) + 1, dynamic_ncols=True)
        pbar_total = tqdm(colour="red", total=len(res), dynamic_ncols=True)
        for i in range(len(res)):
            key = res[i]["key"]
            vadsegments = res[i]["value"]
@@ -310,14 +307,14 @@
            batch_size_ms_cum = 0
            beg_idx = 0
            beg_asr_total = time.time()
            time_speech_total_per_sample = speech_lengths/16000 + 1e-6
            time_speech_total_per_sample = speech_lengths/16000
            time_speech_total_all_samples += time_speech_total_per_sample
            pbar_sample = tqdm(colour="blue", total=n + 1, dynamic_ncols=True)
            # pbar_sample = tqdm(colour="blue", total=n, dynamic_ncols=True)
            all_segments = []
            for j, _ in enumerate(range(0, n)):
                pbar_sample.update(1)
                # pbar_sample.update(1)
                batch_size_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0])
                if j < n - 1 and (
                    batch_size_ms_cum + sorted_data[j + 1][0][1] - sorted_data[j + 1][0][0]) < batch_size and (
@@ -336,19 +333,19 @@
                        segments = sv_chunk(vad_segments)
                        all_segments.extend(segments)
                        speech_b = [i[2] for i in segments]
                        spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, **cfg)
                        spk_res = self.inference(speech_b, input_len=None, model=self.spk_model, kwargs=kwargs, disable_pbar=True, **cfg)
                        results[_b]['spk_embedding'] = spk_res[0]['spk_embedding']
                beg_idx = end_idx
                if len(results) < 1:
                    continue
                results_sorted.extend(results)
            
            end_asr_total = time.time()
            time_escape_total_per_sample = end_asr_total - beg_asr_total
            pbar_sample.update(1)
            pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                                 f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
                                 f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
            # end_asr_total = time.time()
            # time_escape_total_per_sample = end_asr_total - beg_asr_total
            # pbar_sample.update(1)
            # pbar_sample.set_description(f"rtf_avg_per_sample: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
            #                      f"time_speech_total_per_sample: {time_speech_total_per_sample: 0.3f}, "
            #                      f"time_escape_total_per_sample: {time_escape_total_per_sample:0.3f}")
            
            restored_data = [0] * n
            for j in range(n):
@@ -386,7 +383,7 @@
            # step.3 compute punc model
            if self.punc_model is not None:
                self.punc_kwargs.update(cfg)
                punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, **cfg)
                punc_res = self.inference(result["text"], model=self.punc_model, kwargs=self.punc_kwargs, disable_pbar=True, **cfg)
                import copy; raw_text = copy.copy(result["text"])
                result["text"] = punc_res[0]["text"]
                
@@ -418,13 +415,18 @@
                    
            result["key"] = key
            results_ret_list.append(result)
            end_asr_total = time.time()
            time_escape_total_per_sample = end_asr_total - beg_asr_total
            pbar_total.update(1)
        pbar_total.update(1)
        end_total = time.time()
        time_escape_total_all_samples = end_total - beg_total
        pbar_total.set_description(f"rtf_avg_all_samples: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
                             f"time_speech_total_all_samples: {time_speech_total_all_samples: 0.3f}, "
                             f"time_escape_total_all_samples: {time_escape_total_all_samples:0.3f}")
            pbar_total.set_description(f"rtf_avg: {time_escape_total_per_sample / time_speech_total_per_sample:0.3f}, "
                                 f"time_speech: {time_speech_total_per_sample: 0.3f}, "
                                 f"time_escape: {time_escape_total_per_sample:0.3f}")
        # end_total = time.time()
        # time_escape_total_all_samples = end_total - beg_total
        # print(f"rtf_avg_all: {time_escape_total_all_samples / time_speech_total_all_samples:0.3f}, "
        #                      f"time_speech_all: {time_speech_total_all_samples: 0.3f}, "
        #                      f"time_escape_all: {time_escape_total_all_samples:0.3f}")
        return results_ret_list