From d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99 Mon Sep 17 00:00:00 2001
From: liugz18 <57401541+liugz18@users.noreply.github.com>
Date: 星期四, 18 七月 2024 21:34:55 +0800
Subject: [PATCH] Rename 'res' in line 514 to avoid with naming conflict with line 365

---
 runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py |  150 ++++++++++++++++++++++++-------------------------
 1 files changed, 74 insertions(+), 76 deletions(-)

diff --git a/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py b/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
index c556daf..e6eb3b9 100644
--- a/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
+++ b/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
@@ -25,8 +25,10 @@
 import yaml
 from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
 
+
 class LFR(torch.nn.Module):
-    """Batch LFR: https://github.com/Mddct/devil-asr/blob/main/patch/lfr.py """
+    """Batch LFR: https://github.com/Mddct/devil-asr/blob/main/patch/lfr.py"""
+
     def __init__(self, m: int = 7, n: int = 6) -> None:
         """
         Actually, this implements stacking frames and skipping frames.
@@ -42,8 +44,9 @@
 
         self.left_padding_nums = math.ceil((self.m - 1) // 2)
 
-    def forward(self, input_tensor: torch.Tensor,
-                input_lens: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def forward(
+        self, input_tensor: torch.Tensor, input_lens: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
         B, _, D = input_tensor.size()
         n_lfr = torch.ceil(input_lens / self.n)
 
@@ -70,44 +73,43 @@
         # stack
         input_tensor = torch.cat([head_frames, input_tensor, tail_frames], dim=1)
 
-        index = torch.arange(T_all_max,
-                             device=input_tensor.device,
-                             dtype=input_lens.dtype).unsqueeze(0).repeat(B, 1)  # [B, T_all_max]
-        index_mask = (index <
-                      (self.left_padding_nums + input_lens).unsqueeze(1)
-                      )  #[B, T_all_max]
+        index = (
+            torch.arange(T_all_max, device=input_tensor.device, dtype=input_lens.dtype)
+            .unsqueeze(0)
+            .repeat(B, 1)
+        )  # [B, T_all_max]
+        index_mask = index < (self.left_padding_nums + input_lens).unsqueeze(1)  # [B, T_all_max]
 
-        tail_index_mask = torch.logical_not(
-            index >= (T_all.unsqueeze(1))) & index_mask
-        tail = torch.ones(T_all_max,
-                          dtype=input_lens.dtype,
-                          device=input_tensor.device).unsqueeze(0).repeat(B, 1) * (
-                              T_all_max - 1)  # [B, T_all_max]
-        indices = torch.where(torch.logical_or(index_mask, tail_index_mask),
-                              index, tail)
+        tail_index_mask = torch.logical_not(index >= (T_all.unsqueeze(1))) & index_mask
+        tail = torch.ones(T_all_max, dtype=input_lens.dtype, device=input_tensor.device).unsqueeze(
+            0
+        ).repeat(B, 1) * (
+            T_all_max - 1
+        )  # [B, T_all_max]
+        indices = torch.where(torch.logical_or(index_mask, tail_index_mask), index, tail)
         input_tensor = torch.gather(input_tensor, 1, indices.unsqueeze(2).repeat(1, 1, D))
 
         input_tensor = input_tensor.unfold(1, self.m, step=self.n).transpose(2, 3)
 
         return input_tensor.reshape(B, -1, D * self.m), new_len
 
-class WavFrontend():
-    """Conventional frontend structure for ASR.
-    """
+
+class WavFrontend:
+    """Conventional frontend structure for ASR."""
 
     def __init__(
-            self,
-            cmvn_file: str = None,
-            fs: int = 16000,
-            window: str = 'hamming',
-            n_mels: int = 80,
-            frame_length: int = 25,
-            frame_shift: int = 10,
-            filter_length_min: int = -1,
-            filter_length_max: float = -1,
-            lfr_m: int = 7,
-            lfr_n: int = 6,
-            dither: float = 1.0
+        self,
+        cmvn_file: str = None,
+        fs: int = 16000,
+        window: str = "hamming",
+        n_mels: int = 80,
+        frame_length: int = 25,
+        frame_shift: int = 10,
+        filter_length_min: int = -1,
+        filter_length_max: float = -1,
+        lfr_m: int = 7,
+        lfr_n: int = 6,
+        dither: float = 1.0,
     ) -> None:
 
         self.fs = fs
@@ -133,31 +135,33 @@
         batch, frame, dim = inputs.shape
         means = np.tile(self.cmvn[0:1, :dim], (frame, 1))
         vars = np.tile(self.cmvn[1:2, :dim], (frame, 1))
-        
+
         means = torch.from_numpy(means).to(inputs.device)
         vars = torch.from_numpy(vars).to(inputs.device)
         # print(inputs.shape, means.shape, vars.shape)
         inputs = (inputs + means) * vars
         return inputs
 
-    def load_cmvn(self,) -> np.ndarray:
-        with open(self.cmvn_file, 'r', encoding='utf-8') as f:
+    def load_cmvn(
+        self,
+    ) -> np.ndarray:
+        with open(self.cmvn_file, "r", encoding="utf-8") as f:
             lines = f.readlines()
 
         means_list = []
         vars_list = []
         for i in range(len(lines)):
             line_item = lines[i].split()
-            if line_item[0] == '<AddShift>':
+            if line_item[0] == "<AddShift>":
                 line_item = lines[i + 1].split()
-                if line_item[0] == '<LearnRateCoef>':
-                    add_shift_line = line_item[3:(len(line_item) - 1)]
+                if line_item[0] == "<LearnRateCoef>":
+                    add_shift_line = line_item[3 : (len(line_item) - 1)]
                     means_list = list(add_shift_line)
                     continue
-            elif line_item[0] == '<Rescale>':
+            elif line_item[0] == "<Rescale>":
                 line_item = lines[i + 1].split()
-                if line_item[0] == '<LearnRateCoef>':
-                    rescale_line = line_item[3:(len(line_item) - 1)]
+                if line_item[0] == "<LearnRateCoef>":
+                    rescale_line = line_item[3 : (len(line_item) - 1)]
                     vars_list = list(rescale_line)
                     continue
 
@@ -197,16 +201,14 @@
           * model_version: Model version
           * model_name: Model name
         """
-        self.model_config = model_config = json.loads(args['model_config'])
+        self.model_config = model_config = json.loads(args["model_config"])
         self.max_batch_size = max(model_config["max_batch_size"], 1)
         self.device = "cuda"
 
         # Get OUTPUT0 configuration
-        output0_config = pb_utils.get_output_config_by_name(
-            model_config, "speech")
+        output0_config = pb_utils.get_output_config_by_name(model_config, "speech")
         # Convert Triton types to numpy types
-        output0_dtype = pb_utils.triton_string_to_numpy(
-            output0_config['data_type'])
+        output0_dtype = pb_utils.triton_string_to_numpy(output0_config["data_type"])
 
         if output0_dtype == np.float32:
             self.output0_dtype = torch.float32
@@ -214,42 +216,36 @@
             self.output0_dtype = torch.float16
 
         # Get OUTPUT1 configuration
-        output1_config = pb_utils.get_output_config_by_name(
-            model_config, "speech_lengths")
+        output1_config = pb_utils.get_output_config_by_name(model_config, "speech_lengths")
         # Convert Triton types to numpy types
-        self.output1_dtype = pb_utils.triton_string_to_numpy(
-            output1_config['data_type'])
+        self.output1_dtype = pb_utils.triton_string_to_numpy(output1_config["data_type"])
 
-        params = self.model_config['parameters']
+        params = self.model_config["parameters"]
 
         for li in params.items():
             key, value = li
             value = value["string_value"]
             if key == "config_path":
-                with open(str(value), 'rb') as f:
+                with open(str(value), "rb") as f:
                     config = yaml.load(f, Loader=yaml.Loader)
             if key == "cmvn_path":
                 cmvn_path = str(value)
 
         opts = kaldifeat.FbankOptions()
-        opts.frame_opts.dither = 1.0 # TODO: 0.0 or 1.0
-        opts.frame_opts.window_type = config['frontend_conf']['window']
-        opts.mel_opts.num_bins = int(config['frontend_conf']['n_mels'])
-        opts.frame_opts.frame_shift_ms = float(config['frontend_conf']['frame_shift'])
-        opts.frame_opts.frame_length_ms = float(config['frontend_conf']['frame_length'])
-        opts.frame_opts.samp_freq = int(config['frontend_conf']['fs'])
+        opts.frame_opts.dither = 1.0  # TODO: 0.0 or 1.0
+        opts.frame_opts.window_type = config["frontend_conf"]["window"]
+        opts.mel_opts.num_bins = int(config["frontend_conf"]["n_mels"])
+        opts.frame_opts.frame_shift_ms = float(config["frontend_conf"]["frame_shift"])
+        opts.frame_opts.frame_length_ms = float(config["frontend_conf"]["frame_length"])
+        opts.frame_opts.samp_freq = int(config["frontend_conf"]["fs"])
         opts.device = torch.device(self.device)
         self.opts = opts
         self.feature_extractor = Fbank(self.opts)
         self.feature_size = opts.mel_opts.num_bins
 
-        self.frontend = WavFrontend(
-            cmvn_file=cmvn_path,
-            **config['frontend_conf'])
+        self.frontend = WavFrontend(cmvn_file=cmvn_path, **config["frontend_conf"])
 
-    def extract_feat(self,
-                     waveform_list: List[np.ndarray]
-                     ) -> Tuple[np.ndarray, np.ndarray]:
+    def extract_feat(self, waveform_list: List[np.ndarray]) -> Tuple[np.ndarray, np.ndarray]:
         feats, feats_len = [], []
         wavs = []
         for waveform in waveform_list:
@@ -258,18 +254,21 @@
 
         features = self.feature_extractor(wavs)
         features_len = [feature.shape[0] for feature in features]
-        speech = torch.zeros((len(features), max(features_len), self.opts.mel_opts.num_bins),
-                                dtype=self.output0_dtype, device=self.device)
+        speech = torch.zeros(
+            (len(features), max(features_len), self.opts.mel_opts.num_bins),
+            dtype=self.output0_dtype,
+            device=self.device,
+        )
         for i, feature in enumerate(features):
-            speech[i,:int(features_len[i])] = feature
-        speech_lens = torch.tensor(features_len,dtype=torch.int64).to(self.device)
-      
+            speech[i, : int(features_len[i])] = feature
+        speech_lens = torch.tensor(features_len, dtype=torch.int64).to(self.device)
+
         feats, feats_len = self.frontend.lfr(speech, speech_lens)
         feats_len = feats_len.type(torch.int32)
-        
+
         feats = self.frontend.apply_cmvn_batch(feats)
         feats = feats.type(self.output0_dtype)
-        
+
         return feats, feats_len
 
     def execute(self, requests):
@@ -294,23 +293,22 @@
         batch_len = []
         responses = []
         for request in requests:
-            
+
             input0 = pb_utils.get_input_tensor_by_name(request, "wav")
             input1 = pb_utils.get_input_tensor_by_name(request, "wav_lens")
 
-            cur_b_wav = input0.as_numpy() * (1 << 15) # b x -1
+            cur_b_wav = input0.as_numpy() * (1 << 15)  # b x -1
             total_waves.append(cur_b_wav)
 
         features, feats_len = self.extract_feat(total_waves)
 
         for i in range(features.shape[0]):
-            speech = features[i:i+1][:int(feats_len[i].cpu())]
+            speech = features[i : i + 1][: int(feats_len[i].cpu())]
             speech_lengths = feats_len[i].unsqueeze(0).unsqueeze(0)
 
             speech, speech_lengths = speech.cpu(), speech_lengths.cpu()
             out0 = pb_utils.Tensor.from_dlpack("speech", to_dlpack(speech))
-            out1 = pb_utils.Tensor.from_dlpack("speech_lengths",
-                                               to_dlpack(speech_lengths))
+            out1 = pb_utils.Tensor.from_dlpack("speech_lengths", to_dlpack(speech_lengths))
             inference_response = pb_utils.InferenceResponse(output_tensors=[out0, out1])
             responses.append(inference_response)
         return responses

--
Gitblit v1.9.1