From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365
---
runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py | 53 ++++++++++++++++++++++++-----------------------------
1 files changed, 24 insertions(+), 29 deletions(-)
diff --git a/runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py b/runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py
index 7d81a98..716a075 100755
--- a/runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py
+++ b/runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py
@@ -36,7 +36,7 @@
class Feat(object):
- def __init__(self, seqid, offset_ms, sample_rate, frame_stride, device='cpu'):
+ def __init__(self, seqid, offset_ms, sample_rate, frame_stride, device="cpu"):
self.seqid = seqid
self.sample_rate = sample_rate
self.wav = torch.tensor([], device=device)
@@ -45,14 +45,14 @@
self.frame_stride = int(frame_stride)
self.device = device
self.lfr_m = 7
-
+
def add_wavs(self, wav: torch.tensor):
wav = wav.to(self.device)
self.wav = torch.cat((self.wav, wav), axis=0)
-
+
def get_seg_wav(self):
seg = self.wav[:]
- self.wav = self.wav[-self.offset:]
+ self.wav = self.wav[-self.offset :]
return seg
def add_frames(self, frames: torch.tensor):
@@ -60,14 +60,13 @@
frames: seq_len x feat_sz
"""
if self.frames is None:
- self.frames = torch.cat((frames[0, :].repeat((self.lfr_m - 1) // 2, 1),
- frames), axis=0)
+ self.frames = torch.cat((frames[0, :].repeat((self.lfr_m - 1) // 2, 1), frames), axis=0)
else:
self.frames = torch.cat([self.frames, frames], axis=0)
-
+
def get_frames(self, num_frames: int):
- seg = self.frames[0: num_frames]
- self.frames = self.frames[self.frame_stride:]
+ seg = self.frames[0:num_frames]
+ self.frames = self.frames[self.frame_stride :]
return seg
@@ -91,7 +90,7 @@
* model_version: Model version
* model_name: Model name
"""
- self.model_config = model_config = json.loads(args['model_config'])
+ self.model_config = model_config = json.loads(args["model_config"])
self.max_batch_size = max(model_config["max_batch_size"], 1)
if "GPU" in model_config["instance_group"][0]["kind"]:
@@ -100,35 +99,33 @@
self.device = "cpu"
# Get OUTPUT0 configuration
- output0_config = pb_utils.get_output_config_by_name(
- model_config, "speech")
+ output0_config = pb_utils.get_output_config_by_name(model_config, "speech")
# Convert Triton types to numpy types
- self.output0_dtype = pb_utils.triton_string_to_numpy(
- output0_config['data_type'])
+ self.output0_dtype = pb_utils.triton_string_to_numpy(output0_config["data_type"])
if self.output0_dtype == np.float32:
self.dtype = torch.float32
else:
self.dtype = torch.float16
- self.feature_size = output0_config['dims'][-1]
- self.decoding_window = output0_config['dims'][-2]
+ self.feature_size = output0_config["dims"][-1]
+ self.decoding_window = output0_config["dims"][-2]
- params = self.model_config['parameters']
+ params = self.model_config["parameters"]
for li in params.items():
key, value = li
value = value["string_value"]
if key == "config_path":
- with open(str(value), 'rb') as f:
+ with open(str(value), "rb") as f:
config = yaml.load(f, Loader=yaml.Loader)
opts = kaldifeat.FbankOptions()
opts.frame_opts.dither = 0.0
- opts.frame_opts.window_type = config['frontend_conf']['window']
- opts.mel_opts.num_bins = int(config['frontend_conf']['n_mels'])
- opts.frame_opts.frame_shift_ms = float(config['frontend_conf']['frame_shift'])
- opts.frame_opts.frame_length_ms = float(config['frontend_conf']['frame_length'])
- opts.frame_opts.samp_freq = int(config['frontend_conf']['fs'])
+ opts.frame_opts.window_type = config["frontend_conf"]["window"]
+ opts.mel_opts.num_bins = int(config["frontend_conf"]["n_mels"])
+ opts.frame_opts.frame_shift_ms = float(config["frontend_conf"]["frame_shift"])
+ opts.frame_opts.frame_length_ms = float(config["frontend_conf"]["frame_length"])
+ opts.frame_opts.samp_freq = int(config["frontend_conf"]["fs"])
opts.device = torch.device(self.device)
self.opts = opts
self.feature_extractor = Fbank(self.opts)
@@ -177,8 +174,7 @@
# wav_len = from_dlpack(input1.to_dlpack())[0]
wav_len = len(wav)
if wav_len < self.chunk_size:
- temp = torch.zeros(self.chunk_size, dtype=torch.float32,
- device=self.device)
+ temp = torch.zeros(self.chunk_size, dtype=torch.float32, device=self.device)
temp[0:wav_len] = wav[:]
wav = temp
@@ -192,10 +188,9 @@
end = in_end.as_numpy()[0][0]
if start:
- self.seq_feat[corrid] = Feat(corrid, self.offset_ms,
- self.sample_rate,
- self.frame_stride,
- self.device)
+ self.seq_feat[corrid] = Feat(
+ corrid, self.offset_ms, self.sample_rate, self.frame_stride, self.device
+ )
if ready:
self.seq_feat[corrid].add_wavs(wav)
--
Gitblit v1.9.1