liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py
@@ -36,7 +36,7 @@
class Feat(object):
    def __init__(self, seqid, offset_ms, sample_rate, frame_stride, device='cpu'):
    def __init__(self, seqid, offset_ms, sample_rate, frame_stride, device="cpu"):
        self.seqid = seqid
        self.sample_rate = sample_rate
        self.wav = torch.tensor([], device=device)
@@ -45,14 +45,14 @@
        self.frame_stride = int(frame_stride)
        self.device = device
        self.lfr_m = 7
    def add_wavs(self, wav: torch.tensor):
        wav = wav.to(self.device)
        self.wav = torch.cat((self.wav, wav), axis=0)
    def get_seg_wav(self):
        seg = self.wav[:]
        self.wav = self.wav[-self.offset:]
        self.wav = self.wav[-self.offset :]
        return seg
    def add_frames(self, frames: torch.tensor):
@@ -60,14 +60,13 @@
        frames: seq_len x feat_sz
        """
        if self.frames is None:
            self.frames = torch.cat((frames[0, :].repeat((self.lfr_m - 1) // 2, 1),
                                     frames), axis=0)
            self.frames = torch.cat((frames[0, :].repeat((self.lfr_m - 1) // 2, 1), frames), axis=0)
        else:
            self.frames = torch.cat([self.frames, frames], axis=0)
    def get_frames(self, num_frames: int):
        seg = self.frames[0: num_frames]
        self.frames = self.frames[self.frame_stride:]
        seg = self.frames[0:num_frames]
        self.frames = self.frames[self.frame_stride :]
        return seg
@@ -91,7 +90,7 @@
          * model_version: Model version
          * model_name: Model name
        """
        self.model_config = model_config = json.loads(args['model_config'])
        self.model_config = model_config = json.loads(args["model_config"])
        self.max_batch_size = max(model_config["max_batch_size"], 1)
        if "GPU" in model_config["instance_group"][0]["kind"]:
@@ -100,35 +99,33 @@
            self.device = "cpu"
        # Get OUTPUT0 configuration
        output0_config = pb_utils.get_output_config_by_name(
            model_config, "speech")
        output0_config = pb_utils.get_output_config_by_name(model_config, "speech")
        # Convert Triton types to numpy types
        self.output0_dtype = pb_utils.triton_string_to_numpy(
            output0_config['data_type'])
        self.output0_dtype = pb_utils.triton_string_to_numpy(output0_config["data_type"])
        if self.output0_dtype == np.float32:
            self.dtype = torch.float32
        else:
            self.dtype = torch.float16
        self.feature_size = output0_config['dims'][-1]
        self.decoding_window = output0_config['dims'][-2]
        self.feature_size = output0_config["dims"][-1]
        self.decoding_window = output0_config["dims"][-2]
        params = self.model_config['parameters']
        params = self.model_config["parameters"]
        for li in params.items():
            key, value = li
            value = value["string_value"]
            if key == "config_path":
                with open(str(value), 'rb') as f:
                with open(str(value), "rb") as f:
                    config = yaml.load(f, Loader=yaml.Loader)
        opts = kaldifeat.FbankOptions()
        opts.frame_opts.dither = 0.0
        opts.frame_opts.window_type = config['frontend_conf']['window']
        opts.mel_opts.num_bins = int(config['frontend_conf']['n_mels'])
        opts.frame_opts.frame_shift_ms = float(config['frontend_conf']['frame_shift'])
        opts.frame_opts.frame_length_ms = float(config['frontend_conf']['frame_length'])
        opts.frame_opts.samp_freq = int(config['frontend_conf']['fs'])
        opts.frame_opts.window_type = config["frontend_conf"]["window"]
        opts.mel_opts.num_bins = int(config["frontend_conf"]["n_mels"])
        opts.frame_opts.frame_shift_ms = float(config["frontend_conf"]["frame_shift"])
        opts.frame_opts.frame_length_ms = float(config["frontend_conf"]["frame_length"])
        opts.frame_opts.samp_freq = int(config["frontend_conf"]["fs"])
        opts.device = torch.device(self.device)
        self.opts = opts
        self.feature_extractor = Fbank(self.opts)
@@ -177,8 +174,7 @@
            # wav_len = from_dlpack(input1.to_dlpack())[0]
            wav_len = len(wav)
            if wav_len < self.chunk_size:
                temp = torch.zeros(self.chunk_size, dtype=torch.float32,
                                   device=self.device)
                temp = torch.zeros(self.chunk_size, dtype=torch.float32, device=self.device)
                temp[0:wav_len] = wav[:]
                wav = temp
@@ -192,10 +188,9 @@
            end = in_end.as_numpy()[0][0]
            if start:
                self.seq_feat[corrid] = Feat(corrid, self.offset_ms,
                                             self.sample_rate,
                                             self.frame_stride,
                                             self.device)
                self.seq_feat[corrid] = Feat(
                    corrid, self.offset_ms, self.sample_rate, self.frame_stride, self.device
                )
            if ready:
                self.seq_feat[corrid].add_wavs(wav)