From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py |   53 ++++++++++++++++++++++++-----------------------------
 1 files changed, 24 insertions(+), 29 deletions(-)

diff --git a/runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py b/runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py
index 7d81a98..716a075 100755
--- a/runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py
+++ b/runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py
@@ -36,7 +36,7 @@
 
 
 class Feat(object):
-    def __init__(self, seqid, offset_ms, sample_rate, frame_stride, device='cpu'):
+    def __init__(self, seqid, offset_ms, sample_rate, frame_stride, device="cpu"):
         self.seqid = seqid
         self.sample_rate = sample_rate
         self.wav = torch.tensor([], device=device)
@@ -45,14 +45,14 @@
         self.frame_stride = int(frame_stride)
         self.device = device
         self.lfr_m = 7
-    
+
     def add_wavs(self, wav: torch.tensor):
         wav = wav.to(self.device)
         self.wav = torch.cat((self.wav, wav), axis=0)
-    
+
     def get_seg_wav(self):
         seg = self.wav[:]
-        self.wav = self.wav[-self.offset:]
+        self.wav = self.wav[-self.offset :]
         return seg
 
     def add_frames(self, frames: torch.tensor):
@@ -60,14 +60,13 @@
         frames: seq_len x feat_sz
         """
         if self.frames is None:
-            self.frames = torch.cat((frames[0, :].repeat((self.lfr_m - 1) // 2, 1),
-                                     frames), axis=0)
+            self.frames = torch.cat((frames[0, :].repeat((self.lfr_m - 1) // 2, 1), frames), axis=0)
         else:
             self.frames = torch.cat([self.frames, frames], axis=0)
-    
+
     def get_frames(self, num_frames: int):
-        seg = self.frames[0: num_frames]
-        self.frames = self.frames[self.frame_stride:]
+        seg = self.frames[0:num_frames]
+        self.frames = self.frames[self.frame_stride :]
         return seg
 
 
@@ -91,7 +90,7 @@
           * model_version: Model version
           * model_name: Model name
         """
-        self.model_config = model_config = json.loads(args['model_config'])
+        self.model_config = model_config = json.loads(args["model_config"])
         self.max_batch_size = max(model_config["max_batch_size"], 1)
 
         if "GPU" in model_config["instance_group"][0]["kind"]:
@@ -100,35 +99,33 @@
             self.device = "cpu"
 
         # Get OUTPUT0 configuration
-        output0_config = pb_utils.get_output_config_by_name(
-            model_config, "speech")
+        output0_config = pb_utils.get_output_config_by_name(model_config, "speech")
         # Convert Triton types to numpy types
-        self.output0_dtype = pb_utils.triton_string_to_numpy(
-            output0_config['data_type'])
+        self.output0_dtype = pb_utils.triton_string_to_numpy(output0_config["data_type"])
 
         if self.output0_dtype == np.float32:
             self.dtype = torch.float32
         else:
             self.dtype = torch.float16
 
-        self.feature_size = output0_config['dims'][-1]
-        self.decoding_window = output0_config['dims'][-2]
+        self.feature_size = output0_config["dims"][-1]
+        self.decoding_window = output0_config["dims"][-2]
 
-        params = self.model_config['parameters']
+        params = self.model_config["parameters"]
         for li in params.items():
             key, value = li
             value = value["string_value"]
             if key == "config_path":
-                with open(str(value), 'rb') as f:
+                with open(str(value), "rb") as f:
                     config = yaml.load(f, Loader=yaml.Loader)
 
         opts = kaldifeat.FbankOptions()
         opts.frame_opts.dither = 0.0
-        opts.frame_opts.window_type = config['frontend_conf']['window']
-        opts.mel_opts.num_bins = int(config['frontend_conf']['n_mels'])
-        opts.frame_opts.frame_shift_ms = float(config['frontend_conf']['frame_shift'])
-        opts.frame_opts.frame_length_ms = float(config['frontend_conf']['frame_length'])
-        opts.frame_opts.samp_freq = int(config['frontend_conf']['fs'])
+        opts.frame_opts.window_type = config["frontend_conf"]["window"]
+        opts.mel_opts.num_bins = int(config["frontend_conf"]["n_mels"])
+        opts.frame_opts.frame_shift_ms = float(config["frontend_conf"]["frame_shift"])
+        opts.frame_opts.frame_length_ms = float(config["frontend_conf"]["frame_length"])
+        opts.frame_opts.samp_freq = int(config["frontend_conf"]["fs"])
         opts.device = torch.device(self.device)
         self.opts = opts
         self.feature_extractor = Fbank(self.opts)
@@ -177,8 +174,7 @@
             # wav_len = from_dlpack(input1.to_dlpack())[0]
             wav_len = len(wav)
             if wav_len < self.chunk_size:
-                temp = torch.zeros(self.chunk_size, dtype=torch.float32,
-                                   device=self.device)
+                temp = torch.zeros(self.chunk_size, dtype=torch.float32, device=self.device)
                 temp[0:wav_len] = wav[:]
                 wav = temp
 
@@ -192,10 +188,9 @@
             end = in_end.as_numpy()[0][0]
 
             if start:
-                self.seq_feat[corrid] = Feat(corrid, self.offset_ms,
-                                             self.sample_rate,
-                                             self.frame_stride,
-                                             self.device)
+                self.seq_feat[corrid] = Feat(
+                    corrid, self.offset_ms, self.sample_rate, self.frame_stride, self.device
+                )
             if ready:
                 self.seq_feat[corrid].add_wavs(wav)
 

--
Gitblit v1.9.1