| | |
| | | from icefall.utils import store_transcripts, write_error_stats |
| | | |
# Manifest listing the test utterances (one JSON-ish dict per line).
DEFAULT_MANIFEST_FILENAME = "./aishell_test.txt"  # noqa

# Root directory prepended to relative audio paths from the manifest.
# NOTE(review): the original contained four consecutive assignments to
# DEFAULT_ROOT ("./" and a machine-specific path, each twice); only the
# last one ever took effect, so the dead stores are removed here.
# This is still a machine-specific absolute path — consider making it a
# --root CLI argument.
DEFAULT_ROOT = "/mfs/songtao/researchcode/FunASR/data/"
| | | |
| | | |
| | | def get_args(): |
| | | parser = argparse.ArgumentParser( |
| | | formatter_class=argparse.ArgumentDefaultsHelpFormatter |
| | | ) |
| | | parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) |
| | | |
| | | parser.add_argument( |
| | | "--server-addr", |
| | |
| | | with open(fp) as f: |
| | | for i, dp in enumerate(f.readlines()): |
| | | dp = eval(dp) |
| | | dp['id'] = i |
| | | dp["id"] = i |
| | | data.append(dp) |
| | | return data |
| | | |
| | |
| | | # import pdb;pdb.set_trace() |
| | | assert len(dps) > num_tasks |
| | | |
| | | one_task_num = len(dps)//num_tasks |
| | | one_task_num = len(dps) // num_tasks |
| | | for i in range(0, len(dps), one_task_num): |
| | | if i+one_task_num >= len(dps): |
| | | if i + one_task_num >= len(dps): |
| | | for k, j in enumerate(range(i, len(dps))): |
| | | dps_splited[k].append(dps[j]) |
| | | else: |
| | | dps_splited.append(dps[i:i+one_task_num]) |
| | | dps_splited.append(dps[i : i + one_task_num]) |
| | | return dps_splited |
| | | |
| | | |
def load_audio(path):
    """Load a WAV file as 16 kHz mono float32 samples.

    Args:
        path: Path to a WAV file readable by pydub's ``AudioSegment``.

    Returns:
        A tuple ``(samples, duration)`` where ``samples`` is a 1-D
        ``np.float32`` array with values scaled into [-1.0, 1.0)
        (int16 PCM divided by 32768.0) and ``duration`` is the audio
        length in seconds.
    """
    audio = AudioSegment.from_wav(path).set_frame_rate(16000).set_channels(1)
    # Scale int16 PCM to float range; the original computed this twice
    # (duplicated dead-store line), the duplicate is removed here.
    samples = np.array(audio.get_array_of_samples()) / 32768.0
    return samples.astype(np.float32), audio.duration_seconds
| | | |
| | | |
| | |
| | | if i % log_interval == 0: |
| | | print(f"{name}: {i}/{len(dps)}") |
| | | |
| | | waveform, duration = load_audio( |
| | | os.path.join(DEFAULT_ROOT, dp['audio_filepath'])) |
| | | waveform, duration = load_audio(os.path.join(DEFAULT_ROOT, dp["audio_filepath"])) |
| | | sample_rate = 16000 |
| | | |
| | | # padding to nearset 10 seconds |
| | | samples = np.zeros( |
| | | ( |
| | | 1, |
| | | 10 * sample_rate * |
| | | (int(len(waveform) / sample_rate // 10) + 1), |
| | | 10 * sample_rate * (int(len(waveform) / sample_rate // 10) + 1), |
| | | ), |
| | | dtype=np.float32, |
| | | ) |
| | |
| | | lengths = np.array([[len(waveform)]], dtype=np.int32) |
| | | |
| | | inputs = [ |
| | | protocol_client.InferInput( |
| | | "WAV", samples.shape, np_to_triton_dtype(samples.dtype) |
| | | ), |
| | | protocol_client.InferInput("WAV", samples.shape, np_to_triton_dtype(samples.dtype)), |
| | | protocol_client.InferInput( |
| | | "WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype) |
| | | ), |
| | |
| | | total_duration += duration |
| | | |
| | | if compute_cer: |
| | | ref = dp['text'].split() |
| | | ref = dp["text"].split() |
| | | hyp = decoding_results.split() |
| | | ref = list("".join(ref)) |
| | | hyp = list("".join(hyp)) |
| | | results.append((dp['id'], ref, hyp)) |
| | | results.append((dp["id"], ref, hyp)) |
| | | else: |
| | | results.append( |
| | | ( |
| | | dp['id'], |
| | | dp['text'].split(), |
| | | dp["id"], |
| | | dp["text"].split(), |
| | | decoding_results.split(), |
| | | ) |
| | | ) # noqa |
| | |
| | | if i % log_interval == 0: |
| | | print(f"{name}: {i}/{len(dps)}") |
| | | |
| | | waveform, duration = load_audio(dp['audio_filepath']) |
| | | waveform, duration = load_audio(dp["audio_filepath"]) |
| | | sample_rate = 16000 |
| | | |
| | | wav_segs = [] |
| | |
| | | while j < len(waveform): |
| | | if j == 0: |
| | | stride = int(first_chunk_in_secs * sample_rate) |
| | | wav_segs.append(waveform[j: j + stride]) |
| | | wav_segs.append(waveform[j : j + stride]) |
| | | else: |
| | | stride = int(other_chunk_in_secs * sample_rate) |
| | | wav_segs.append(waveform[j: j + stride]) |
| | | wav_segs.append(waveform[j : j + stride]) |
| | | j += len(wav_segs[-1]) |
| | | |
| | | sequence_id = task_index + 10086 |
| | |
| | | decoding_results = b" ".join(decoding_results).decode("utf-8") |
| | | else: |
| | | # For wenet |
| | | decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode( |
| | | "utf-8" |
| | | ) |
| | | decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8") |
| | | chunk_end = time.time() - chunk_start |
| | | latency_data.append((chunk_end, chunk_len / sample_rate)) |
| | | |
| | | total_duration += duration |
| | | |
| | | if compute_cer: |
| | | ref = dp['text'].split() |
| | | ref = dp["text"].split() |
| | | hyp = decoding_results.split() |
| | | ref = list("".join(ref)) |
| | | hyp = list("".join(hyp)) |
| | | results.append((dp['id'], ref, hyp)) |
| | | results.append((dp["id"], ref, hyp)) |
| | | else: |
| | | results.append( |
| | | ( |
| | | dp['id'], |
| | | dp['text'].split(), |
| | | dp["id"], |
| | | dp["text"].split(), |
| | | decoding_results.split(), |
| | | ) |
| | | ) # noqa |
| | |
| | | if args.streaming or args.simulate_streaming: |
| | | frame_shift_ms = 10 |
| | | frame_length_ms = 25 |
| | | add_frames = math.ceil( |
| | | (frame_length_ms - frame_shift_ms) / frame_shift_ms |
| | | ) |
| | | add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms) |
| | | # decode_window_length: input sequence length of streaming encoder |
| | | if args.context > 0: |
| | | # decode window length calculation for wenet |
| | | decode_window_length = ( |
| | | args.chunk_size - 1 |
| | | ) * args.subsampling + args.context |
| | | decode_window_length = (args.chunk_size - 1) * args.subsampling + args.context |
| | | else: |
| | | # decode window length calculation for icefall |
| | | decode_window_length = ( |
| | |
| | | compute_cer=compute_cer, |
| | | model_name=args.model_name, |
| | | first_chunk_in_secs=first_chunk_ms / 1000, |
| | | other_chunk_in_secs=args.chunk_size |
| | | * args.subsampling |
| | | * frame_shift_ms |
| | | / 1000, |
| | | other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000, |
| | | task_index=i, |
| | | ) |
| | | ) |
| | |
| | | compute_cer=compute_cer, |
| | | model_name=args.model_name, |
| | | first_chunk_in_secs=first_chunk_ms / 1000, |
| | | other_chunk_in_secs=args.chunk_size |
| | | * args.subsampling |
| | | * frame_shift_ms |
| | | / 1000, |
| | | other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000, |
| | | task_index=i, |
| | | simulate_mode=True, |
| | | ) |
| | |
| | | s = f"RTF: {rtf:.4f}\n" |
| | | s += f"total_duration: {total_duration:.3f} seconds\n" |
| | | s += f"({total_duration/3600:.2f} hours)\n" |
| | | s += ( |
| | | f"processing time: {elapsed:.3f} seconds " |
| | | f"({elapsed/3600:.2f} hours)\n" |
| | | ) |
| | | s += f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)\n" |
| | | |
| | | if args.streaming or args.simulate_streaming: |
| | | latency_list = [ |
| | | chunk_end for (chunk_end, chunk_duration) in latency_data |
| | | ] |
| | | latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data] |
| | | latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0 |
| | | latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0 |
| | | s += f"latency_variance: {latency_variance:.2f}\n" |
| | |
| | | print(f.readline()) # Detailed errors |
| | | |
| | | if args.stats_file: |
| | | stats = await triton_client.get_inference_statistics( |
| | | model_name="", as_json=True |
| | | ) |
| | | stats = await triton_client.get_inference_statistics(model_name="", as_json=True) |
| | | with open(args.stats_file, "w") as f: |
| | | json.dump(stats, f) |
| | | |