| | |
| | | if 'hotword' in kwargs: |
| | | hotword_list_or_file = kwargs['hotword'] |
| | | |
| | | batch_size_token = kwargs.get("batch_size_token", 6000) |
| | | print("batch_size_token: ", batch_size_token) |
| | | |
| | | if speech2text.hotword_list is None: |
| | | speech2text.hotword_list = speech2text.generate_hotwords_list(hotword_list_or_file) |
| | | |
| | |
| | | assert all(isinstance(s, str) for s in keys), keys |
| | | _bs = len(next(iter(batch.values()))) |
| | | assert len(keys) == _bs, f"{len(keys)} != {_bs}" |
| | | |
| | | beg_vad = time.time() |
| | | vad_results = speech2vadsegment(**batch) |
| | | end_vad = time.time() |
| | | print("time cost vad: ", end_vad-beg_vad) |
| | | _, vadsegments = vad_results[0], vad_results[1][0] |
| | | |
| | | speech, speech_lengths = batch["speech"], batch["speech_lengths"] |
| | |
| | | data_with_index = [(vadsegments[i], i) for i in range(n)] |
| | | sorted_data = sorted(data_with_index, key=lambda x: x[0][1] - x[0][0]) |
| | | results_sorted = [] |
| | | for j, beg_idx in enumerate(range(0, n, batch_size)): |
| | | end_idx = min(n, beg_idx + batch_size) |
| | | batch_size_token_ms = batch_size_token*60 |
| | | batch_size_token_ms_cum = 0 |
| | | beg_idx = 0 |
| | | for j, _ in enumerate(range(0, n)): |
| | | batch_size_token_ms_cum += (sorted_data[j][0][1] - sorted_data[j][0][0]) |
| | | if j < n-1 and (batch_size_token_ms_cum + sorted_data[j+1][0][1] - sorted_data[j+1][0][0])<batch_size_token_ms: |
| | | continue |
| | | batch_size_token_ms_cum = 0 |
| | | end_idx = j + 1 |
| | | speech_j, speech_lengths_j = slice_padding_fbank(speech, speech_lengths, sorted_data[beg_idx:end_idx]) |
| | | |
| | | beg_idx = end_idx |
| | | batch = {"speech": speech_j, "speech_lengths": speech_lengths_j} |
| | | batch = to_device(batch, device=device) |
| | | print("batch: ", speech_j.shape[0]) |
| | | beg_asr = time.time() |
| | | results = speech2text(**batch) |
| | | end_asr = time.time() |
| | | print("time cost asr: ", end_asr - beg_asr) |
| | | |
| | | if len(results) < 1: |
| | | results = [["", [], [], [], [], [], []]] |
| | | results_sorted.extend(results) |
| | | |
| | | restored_data = [0] * n |
| | | for j in range(n): |
| | | index = sorted_data[j][1] |
| | |
| | | text_postprocessed_punc = text_postprocessed |
| | | punc_id_list = [] |
| | | if len(word_lists) > 0 and text2punc is not None: |
| | | beg_punc = time.time() |
| | | text_postprocessed_punc, punc_id_list = text2punc(word_lists, 20) |
| | | end_punc = time.time() |
| | | print("time cost punc: ", end_punc-beg_punc) |
| | | |
| | | item = {'key': key, 'value': text_postprocessed_punc} |
| | | if text_postprocessed != "": |