From 1596f6f414f6f41da66506debb1dff19fffeb3ec Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Mon, 24 Jun 2024 11:55:17 +0800
Subject: [PATCH] fix hotwords bug
---
runtime/triton_gpu/client/decode_manifest_triton.py | 43 ++++++++++---------------------------------
 1 file changed, 10 insertions(+), 33 deletions(-)
diff --git a/runtime/triton_gpu/client/decode_manifest_triton.py b/runtime/triton_gpu/client/decode_manifest_triton.py
index 3a8d57f..482c715 100644
--- a/runtime/triton_gpu/client/decode_manifest_triton.py
+++ b/runtime/triton_gpu/client/decode_manifest_triton.py
@@ -78,9 +78,7 @@
def get_args():
- parser = argparse.ArgumentParser(
- formatter_class=argparse.ArgumentDefaultsHelpFormatter
- )
+ parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument(
"--server-addr",
@@ -225,9 +223,7 @@
lengths = np.array([[len(waveform)]], dtype=np.int32)
inputs = [
- protocol_client.InferInput(
- "WAV", samples.shape, np_to_triton_dtype(samples.dtype)
- ),
+ protocol_client.InferInput("WAV", samples.shape, np_to_triton_dtype(samples.dtype)),
protocol_client.InferInput(
"WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
),
@@ -360,9 +356,7 @@
decoding_results = b" ".join(decoding_results).decode("utf-8")
else:
# For wenet
- decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode(
- "utf-8"
- )
+ decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode("utf-8")
chunk_end = time.time() - chunk_start
latency_data.append((chunk_end, chunk_len / sample_rate))
@@ -406,15 +400,11 @@
if args.streaming or args.simulate_streaming:
frame_shift_ms = 10
frame_length_ms = 25
- add_frames = math.ceil(
- (frame_length_ms - frame_shift_ms) / frame_shift_ms
- )
+ add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms)
# decode_window_length: input sequence length of streaming encoder
if args.context > 0:
# decode window length calculation for wenet
- decode_window_length = (
- args.chunk_size - 1
- ) * args.subsampling + args.context
+ decode_window_length = (args.chunk_size - 1) * args.subsampling + args.context
else:
# decode window length calculation for icefall
decode_window_length = (
@@ -437,10 +427,7 @@
compute_cer=compute_cer,
model_name=args.model_name,
first_chunk_in_secs=first_chunk_ms / 1000,
- other_chunk_in_secs=args.chunk_size
- * args.subsampling
- * frame_shift_ms
- / 1000,
+ other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000,
task_index=i,
)
)
@@ -455,10 +442,7 @@
compute_cer=compute_cer,
model_name=args.model_name,
first_chunk_in_secs=first_chunk_ms / 1000,
- other_chunk_in_secs=args.chunk_size
- * args.subsampling
- * frame_shift_ms
- / 1000,
+ other_chunk_in_secs=args.chunk_size * args.subsampling * frame_shift_ms / 1000,
task_index=i,
simulate_mode=True,
)
@@ -496,15 +480,10 @@
s = f"RTF: {rtf:.4f}\n"
s += f"total_duration: {total_duration:.3f} seconds\n"
s += f"({total_duration/3600:.2f} hours)\n"
- s += (
- f"processing time: {elapsed:.3f} seconds "
- f"({elapsed/3600:.2f} hours)\n"
- )
+ s += f"processing time: {elapsed:.3f} seconds " f"({elapsed/3600:.2f} hours)\n"
if args.streaming or args.simulate_streaming:
- latency_list = [
- chunk_end for (chunk_end, chunk_duration) in latency_data
- ]
+ latency_list = [chunk_end for (chunk_end, chunk_duration) in latency_data]
latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
s += f"latency_variance: {latency_variance:.2f}\n"
@@ -530,9 +509,7 @@
print(f.readline()) # Detailed errors
if args.stats_file:
- stats = await triton_client.get_inference_statistics(
- model_name="", as_json=True
- )
+ stats = await triton_client.get_inference_statistics(model_name="", as_json=True)
with open(args.stats_file, "w") as f:
json.dump(stats, f)
--
Gitblit v1.9.1
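
Note: the reflowed hunks above compute the streaming decode window and per-chunk
durations. A minimal sketch of that arithmetic follows, using the script's own
parameter names; the example values, and the derivation of first_chunk_ms (which
is not visible in these hunks), are assumptions for illustration only.

    import math

    # Assumed example values; in the script these come from CLI flags.
    chunk_size = 16       # --chunk_size: encoder output frames per chunk
    subsampling = 4       # --subsampling: encoder subsampling factor
    context = 7           # --context: wenet conv context (> 0 selects the wenet branch)
    frame_shift_ms = 10   # feature frame shift, as in the patch
    frame_length_ms = 25  # feature frame length, as in the patch

    # Extra frames so the last 25 ms feature window is fully covered.
    add_frames = math.ceil((frame_length_ms - frame_shift_ms) / frame_shift_ms)  # 2

    # Input sequence length of the streaming encoder (wenet branch).
    decode_window_length = (chunk_size - 1) * subsampling + context  # 67

    # Assumed derivation of first_chunk_ms from the names above.
    first_chunk_ms = (decode_window_length + add_frames) * frame_shift_ms   # 690
    first_chunk_in_secs = first_chunk_ms / 1000                             # 0.69
    other_chunk_in_secs = chunk_size * subsampling * frame_shift_ms / 1000  # 0.64

With these (assumed) values, the first chunk must cover the full encoder window
(0.69 s of audio) while every later chunk advances by chunk_size * subsampling
feature frames (0.64 s), which is exactly the expression the patch collapses
onto one line.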