From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: Thu, 18 Jul 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid a naming conflict with line 365
---
runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py | 85 +++++++++++++++---------------------------
 1 file changed, 30 insertions(+), 55 deletions(-)
diff --git a/runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py b/runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py
index ad121c6..379b28d 100644
--- a/runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py
+++ b/runtime/triton_gpu/client/decode_manifest_triton_wo_cuts.py
@@ -60,14 +60,12 @@
from icefall.utils import store_transcripts, write_error_stats
DEFAULT_MANIFEST_FILENAME = "./aishell_test.txt" # noqa
-DEFAULT_ROOT = './'
-DEFAULT_ROOT = '/mfs/songtao/researchcode/FunASR/data/'
+DEFAULT_ROOT = "./"
+DEFAULT_ROOT = "/mfs/songtao/researchcode/FunASR/data/"
def get_args():
- parser = argparse.ArgumentParser(
- formatter_class=argparse.ArgumentDefaultsHelpFormatter
- )
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--server-addr",
@@ -185,7 +183,7 @@
with open(fp) as f:
for i, dp in enumerate(f.readlines()):
dp = eval(dp)
- dp['id'] = i
+ dp["id"] = i
data.append(dp)
return data
@@ -195,19 +193,19 @@
# import pdb;pdb.set_trace()
assert len(dps) > num_tasks
- one_task_num = len(dps)//num_tasks
+ one_task_num = len(dps) // num_tasks
for i in range(0, len(dps), one_task_num):
- if i+one_task_num >= len(dps):
+ if i + one_task_num >= len(dps):
for k, j in enumerate(range(i, len(dps))):
dps_splited[k].append(dps[j])
else:
- dps_splited.append(dps[i:i+one_task_num])
+ dps_splited.append(dps[i : i + one_task_num])
return dps_splited
def load_audio(path):
audio = AudioSegment.from_wav(path).set_frame_rate(16000).set_channels(1)
- audiop_np = np.array(audio.get_array_of_samples())/32768.0
+ audiop_np = np.array(audio.get_array_of_samples()) / 32768.0
return audiop_np.astype(np.float32), audio.duration_seconds
@@ -227,16 +225,14 @@
if i % log_interval == 0:
print(f"{name}: {i}/{len(dps)}")
- waveform, duration = load_audio(
- os.path.join(DEFAULT_ROOT, dp['audio_filepath']))
+ waveform, duration = load_audio(os.path.join(DEFAULT_ROOT, dp["audio_filepath"]))
sample_rate = 16000
# padding to nearset 10 seconds
samples = np.zeros(
(
1,
- 10 * sample_rate *
- (int(len(waveform) / sample_rate // 10) + 1),
+ 10 * sample_rate * (int(len(waveform) / sample_rate // 10) + 1),
),
dtype=np.float32,
)
@@ -245,9 +241,7 @@
lengths = np.array([[len(waveform)]], dtype=np.int32)
inputs = [
- protocol_client.InferInput(
- "WAV", samples.shape, np_to_triton_dtype(samples.dtype)
- ),
+ protocol_client.InferInput("WAV", samples.shape, np_to_triton_dtype(samples.dtype)),
protocol_client.InferInput(
"WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
),
@@ -271,16 +265,16 @@
total_duration += duration
if compute_cer:
- ref = dp['text'].split()
+ ref = dp["text"].split()
hyp = decoding_results.split()
ref = list("".join(ref))
hyp = list("".join(hyp))
- results.append((dp['id'], ref, hyp))
+ results.append((dp["id"], ref, hyp))
else:
results.append(
(
- dp['id'],
- dp['text'].split(),
+ dp["id"],
+ dp["text"].split(),
decoding_results.split(),
)
) # noqa
@@ -309,7 +303,7 @@
if i % log_interval == 0:
print(f"{name}: {i}/{len(dps)}")
- waveform, duration = load_audio(dp['audio_filepath'])
+ waveform, duration = load_audio(dp["audio_filepath"])
sample_rate = 16000
wav_segs = []
@@ -318,10 +312,10 @@
while j < len(waveform):
if j == 0:
stride = int(first_chunk_in_secs * sample_rate)
- wav_segs.append(waveform[j: j + stride])
+ wav_segs.append(waveform[j : j + stride])
else:
stride = int(other_chunk_in_secs * sample_rate)
- wav_segs.append(waveform[j: j + stride])
+ wav_segs.append(waveform[j : j + stride])
j += len(wav_segs[-1])
sequence_id = task_index + 10086
@@ -380,25 +374,23 @@
decoding_results = b" ".join(decoding_results).decode("utf-8")
else:
# For wenet
- decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode(
- "utf-8"
- )
+ decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
chunk_end = time.time() - chunk_start
latency_data.append((chunk_end, chunk_len / sample_rate))
total_duration += duration
if compute_cer:
- ref = dp['text'].split()
+ ref = dp["text"].split()
hyp = decoding_results.split()
ref = list("".join(ref))
hyp = list("".join(hyp))
- results.append((dp['id'], ref, hyp))
+ results.append((dp["id"], ref, hyp))
else:
results.append(
(
- dp['id'],
- dp['text'].split(),
+ dp["id"],
+ dp["text"].split(),
decoding_results.split(),
)
) # noqa
@@ -426,15 +418,11 @@
if args.streaming or args.simulate_streaming:
frame_shift_ms = 10
frame_length_ms = 25
- add_frames = math.ceil(
- (frame_length_ms - frame_shift_ms) / frame_shift_ms
- )
+ add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms)
# decode_window_length: input sequence length of streaming encoder
if args.context > 0:
# decode window length calculation for wenet
- decode_window_length = (
- args.chunk_size - 1
- ) * args.subsampling + args.context
+ decode_window_length = (args.chunk_size - 1) * args.subsampling + args.context
else:
# decode window length calculation for icefall
decode_window_length = (
@@ -457,10 +445,7 @@
compute_cer=compute_cer,
model_name=args.model_name,
first_chunk_in_secs=first_chunk_ms / 1000,
- other_chunk_in_secs=args.chunk_size
- * args.subsampling
- * frame_shift_ms
- / 1000,
+ other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000,
task_index=i,
)
)
@@ -475,10 +460,7 @@
compute_cer=compute_cer,
model_name=args.model_name,
first_chunk_in_secs=first_chunk_ms / 1000,
- other_chunk_in_secs=args.chunk_size
- * args.subsampling
- * frame_shift_ms
- / 1000,
+ other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000,
task_index=i,
simulate_mode=True,
)
@@ -516,15 +498,10 @@
s = f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration/3600:.2f} hours)\n"
- s += (
- f"processing time: {elapsed:.3f} seconds "
- f"({elapsed/3600:.2f} hours)\n"
- )
+ s += f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)\n"
if args.streaming or args.simulate_streaming:
- latency_list = [
- chunk_end for (chunk_end, chunk_duration) in latency_data
- ]
+ latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
s += f"latency_variance: {latency_variance:.2f}\n"
@@ -550,9 +527,7 @@
print(f.readline()) # Detailed errors
if args.stats_file:
- stats = await triton_client.get_inference_statistics(
- model_name="", as_json=True
- )
+ stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
with open(args.stats_file, "w") as f:
json.dump(stats, f)
--
Gitblit v1.9.1