| | |
| | | from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard |
| | | from funasr.models.transformer.utils.nets_utils import make_pad_mask, pad_list |
| | | from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank |
| | | |
| | | from funasr.train_utils.device_funcs import to_device |
| | | |
| | | if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"): |
| | | from torch.cuda.amp import autocast |
| | |
| | | self.nbest = kwargs.get("nbest", 1) |
| | | |
| | | meta_data = {} |
| | | if isinstance(data_in, torch.Tensor): # fbank |
| | | speech, speech_lengths = data_in, data_lengths |
| | | if len(speech.shape) < 3: |
| | | speech = speech[None, :, :] |
| | | if speech_lengths is None: |
| | | speech_lengths = speech.shape[1] |
| | | else: |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), |
| | | frontend=frontend) |
| | | time3 = time.perf_counter() |
| | | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" |
| | | meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 |
| | | # if isinstance(data_in, torch.Tensor): # fbank |
| | | # speech, speech_lengths = data_in, data_lengths |
| | | # if len(speech.shape) < 3: |
| | | # speech = speech[None, :, :] |
| | | # if speech_lengths is None: |
| | | # speech_lengths = speech.shape[1] |
| | | # else: |
| | | # extract fbank feats |
| | | time1 = time.perf_counter() |
| | | audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000)) |
| | | time2 = time.perf_counter() |
| | | meta_data["load_data"] = f"{time2 - time1:0.3f}" |
| | | speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"), |
| | | frontend=frontend) |
| | | time3 = time.perf_counter() |
| | | meta_data["extract_feat"] = f"{time3 - time2:0.3f}" |
| | | meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000 |
| | | |
| | | speech = speech.to(device=kwargs["device"]) |
| | | speech_lengths = speech_lengths.to(device=kwargs["device"]) |
| | |
| | | nbest_hyps = [Hypothesis(yseq=yseq, score=score)] |
| | | for nbest_idx, hyp in enumerate(nbest_hyps): |
| | | ibest_writer = None |
| | | if ibest_writer is None and kwargs.get("output_dir") is not None: |
| | | writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = writer[f"{nbest_idx + 1}best_recog"] |
| | | if kwargs.get("output_dir") is not None: |
| | | if not hasattr(self, "writer"): |
| | | self.writer = DatadirWriter(kwargs.get("output_dir")) |
| | | ibest_writer = self.writer[f"{nbest_idx+1}best_recog"] |
| | | |
| | | # remove sos/eos and get results |
| | | last_pos = -1 |
| | | if isinstance(hyp.yseq, list): |
| | |
| | | result_i = {"key": key[i], "token_int": token_int} |
| | | results.append(result_i) |
| | | |
| | | return results, meta_data |
| | | return results, meta_data |
| | | |
def export(self, **kwargs):
    """Rebuild this model for export.

    Forwards all keyword arguments to ``export_rebuild_model`` together
    with ``model=self``. When the caller does not pass ``max_seq_len``,
    it is set to 512 before forwarding.

    Args:
        **kwargs: Export options passed through unchanged, apart from
            the ``max_seq_len`` default described above.

    Returns:
        Whatever ``export_rebuild_model`` returns.
    """
    # Imported at call time, as in the original implementation.
    from .export_meta import export_rebuild_model

    # Only fill in the default; an explicit caller value is kept.
    kwargs.setdefault('max_seq_len', 512)
    return export_rebuild_model(model=self, **kwargs)