# Created on 2024-01-01
# Author: GuAn Zhu
# Modified from NVIDIA(https://github.com/wenet-e2e/wenet/blob/main/runtime/gpu/
# model_repo_stateful/feature_extractor/1/model.py)

import json
from collections import OrderedDict
from typing import List

import kaldifeat
import numpy as np
import torch
import yaml
from torch.utils.dlpack import from_dlpack

import triton_python_backend_utils as pb_utils


class LimitedDict(OrderedDict):
    """OrderedDict that holds at most ``max_length`` entries.

    Inserting a new key while the dict is full evicts the oldest-inserted
    entry first. Re-assigning an existing key evicts nothing (the size does
    not grow); it only moves that key to the newest position.
    """

    def __init__(self, max_length):
        super().__init__()
        self.max_length = max_length

    def __setitem__(self, key, value):
        if key in self:
            # Overwrite: size is unchanged, so no eviction is needed. Treat the
            # key as freshly inserted so a re-started entry is evicted last.
            self.move_to_end(key)
        else:
            # `while self` keeps popitem() away from an empty dict
            # (possible when max_length <= 0).
            while self and len(self) >= self.max_length:
                self.popitem(last=False)
        super().__setitem__(key, value)


class Fbank(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around ``kaldifeat.Fbank``."""

    def __init__(self, opts):
        super().__init__()
        self.fbank = kaldifeat.Fbank(opts)

    def forward(self, waves: List[torch.Tensor]):
        """Run ``kaldifeat.Fbank`` on a list of waveforms and return its result."""
        return self.fbank(waves)


class Feat(object):
    """Per-sequence streaming state: buffered audio and pending feature frames.

    Args:
        seqid: Sequence (correlation) ID this state belongs to.
        offset_ms: Milliseconds of trailing audio kept in the buffer after each
            ``get_seg_wav`` call, and therefore prepended to the next segment.
        sample_rate: Audio sample rate in Hz.
        frame_stride: Number of frames dropped from the front of the frame
            buffer after each ``get_frames`` call.
        device: Torch device on which the audio buffer lives.
    """

    def __init__(self, seqid, offset_ms, sample_rate, frame_stride, device='cpu'):
        self.seqid = seqid
        self.sample_rate = sample_rate
        self.wav = torch.tensor([], device=device)
        # Round instead of truncating so float error (e.g. 319.99999) cannot
        # lose a sample of carried-over context.
        self.offset = int(round(offset_ms * sample_rate / 1000))
        self.frames = None
        self.frame_stride = int(frame_stride)
        self.device = device
        # Only (lfr_m - 1) // 2 is used here, as the number of left-padding
        # frames; presumably this matches the downstream model's LFR window
        # size -- TODO confirm.
        self.lfr_m = 7

    def add_wavs(self, wav: torch.Tensor):
        """Append ``wav`` (moved to ``self.device``) to the audio buffer."""
        wav = wav.to(self.device)
        self.wav = torch.cat((self.wav, wav), axis=0)

    def get_seg_wav(self):
        """Return the whole audio buffer, keeping only its last ``offset`` samples.

        The retained tail becomes the start of the next segment.
        """
        seg = self.wav
        if self.offset > 0:
            self.wav = self.wav[-self.offset:]
        else:
            # `wav[-0:]` would keep the *entire* buffer, so it would grow forever.
            self.wav = self.wav[:0]
        return seg

    def add_frames(self, frames: torch.Tensor):
        """Append feature frames to the frame buffer.

        Args:
            frames: ``seq_len x feat_sz`` feature matrix.

        On the first call, the buffer is left-padded with ``(lfr_m - 1) // 2``
        copies of the first frame.
        """
        if self.frames is None:
            pad = frames[0, :].repeat((self.lfr_m - 1) // 2, 1)
            self.frames = torch.cat((pad, frames), axis=0)
        else:
            self.frames = torch.cat([self.frames, frames], axis=0)

    def get_frames(self, num_frames: int):
        """Return the first ``num_frames`` buffered frames, then drop
        ``frame_stride`` frames from the front of the buffer.

        Consecutive windows overlap by ``num_frames - frame_stride`` frames
        when that is positive. If fewer than ``num_frames`` frames are
        buffered, fewer rows are returned.
        """
        seg = self.frames[0:num_frames]
        self.frames = self.frames[self.frame_stride:]
        return seg


class TritonPythonModel:
    """Your Python model must use the same class name. Every Python model
    that is created must have "TritonPythonModel" as the class name.
    """

    def initialize(self, args):
        """Parse the model config and build the fbank extractor.

        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name

        Raises
        ------
        ValueError
            If the model config has no ``config_path`` parameter.
        """
        self.model_config = model_config = json.loads(args['model_config'])
        self.max_batch_size = max(model_config["max_batch_size"], 1)

        if "GPU" in model_config["instance_group"][0]["kind"]:
            self.device = "cuda"
        else:
            self.device = "cpu"

        # Get OUTPUT0 configuration
        output0_config = pb_utils.get_output_config_by_name(model_config, "speech")
        # Convert Triton types to numpy types
        self.output0_dtype = pb_utils.triton_string_to_numpy(output0_config['data_type'])
        if self.output0_dtype == np.float32:
            self.dtype = torch.float32
        else:
            self.dtype = torch.float16

        # dims[-1] is the feature size; dims[-2] is the number of frames
        # returned per response (passed to Feat.get_frames in execute()).
        self.feature_size = output0_config['dims'][-1]
        self.decoding_window = output0_config['dims'][-2]

        params = model_config['parameters']
        try:
            config_path = str(params["config_path"]["string_value"])
        except KeyError as err:
            raise ValueError(
                "model config must define the 'config_path' parameter") from err
        # NOTE(review): yaml.Loader can construct arbitrary Python objects. That
        # is tolerable only because config_path comes from the deployment's model
        # config, not from clients; prefer yaml.safe_load if the file allows it.
        with open(config_path, 'rb') as f:
            config = yaml.load(f, Loader=yaml.Loader)
        frontend_conf = config['frontend_conf']

        opts = kaldifeat.FbankOptions()
        opts.frame_opts.dither = 0.0
        opts.frame_opts.window_type = frontend_conf['window']
        opts.mel_opts.num_bins = int(frontend_conf['n_mels'])
        opts.frame_opts.frame_shift_ms = float(frontend_conf['frame_shift'])
        opts.frame_opts.frame_length_ms = float(frontend_conf['frame_length'])
        opts.frame_opts.samp_freq = int(frontend_conf['fs'])
        opts.device = torch.device(self.device)
        self.opts = opts
        self.feature_extractor = Fbank(self.opts)
        # Per-sequence state keyed by correlation ID; at most 1024 are kept,
        # and starting a new one beyond that evicts the oldest.
        self.seq_feat = LimitedDict(1024)

        chunk_size_s = float(params["chunk_size_s"]["string_value"])
        sample_rate = opts.frame_opts.samp_freq
        frame_shift_ms = opts.frame_opts.frame_shift_ms
        frame_length_ms = opts.frame_opts.frame_length_ms
        # Samples per chunk; shorter "wav" inputs are zero-padded to this length.
        self.chunk_size = int(chunk_size_s * sample_rate)
        # Frame shifts per chunk; frames dropped after each emitted window.
        self.frame_stride = (chunk_size_s * 1000) // frame_shift_ms
        self.offset_ms = self.get_offset(frame_length_ms, frame_shift_ms)
        self.sample_rate = sample_rate

    def get_offset(self, frame_length_ms, frame_shift_ms):
        """Return how many ms of audio each sequence carries into its next chunk.

        The result is the smallest non-negative multiple ``m`` of
        ``frame_shift_ms`` such that ``m + frame_shift_ms >= frame_length_ms``.
        Being a multiple of the shift presumably keeps frame boundaries aligned
        across chunks.

        Raises:
            ValueError: if ``frame_shift_ms`` is not positive (the search
                would never terminate).
        """
        if frame_shift_ms <= 0:
            raise ValueError(f"frame_shift_ms must be positive, got {frame_shift_ms}")
        offset_ms = 0
        while offset_ms + frame_shift_ms < frame_length_ms:
            offset_ms += frame_shift_ms
        return offset_ms

    @staticmethod
    def _read_scalar(request, name):
        """Return element ``[0][0]`` of input ``name`` as a numpy scalar."""
        return pb_utils.get_input_tensor_by_name(request, name).as_numpy()[0][0]

    def _read_wav(self, request):
        """Return row 0 of the "wav" input, zero-padded to ``chunk_size`` samples."""
        input0 = pb_utils.get_input_tensor_by_name(request, "wav")
        wav = from_dlpack(input0.to_dlpack())[0]
        # NOTE: a "wav_lens" input is not read; the tensor's own length is used.
        wav_len = len(wav)
        if wav_len < self.chunk_size:
            temp = torch.zeros(self.chunk_size, dtype=torch.float32, device=self.device)
            temp[0:wav_len] = wav[:]
            wav = temp
        return wav

    def execute(self, requests):
        """Compute one window of fbank frames for each streaming request.

        For each request, this method:
        * reads the audio and the START / READY / CORRID / END control inputs;
        * creates fresh ``Feat`` state on START;
        * appends the audio when READY is set.

        It then featurizes the buffered audio of all requests in one batch and
        responds with ``decoding_window`` frames as output "speech", cast to
        the dtype declared in the model config. Sequences flagged END are
        removed after their response is built.

        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest

        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        total_waves = []
        responses = []
        batch_seqid = []
        end_seqid = set()

        for request in requests:
            wav = self._read_wav(request)
            start = self._read_scalar(request, "START")
            ready = self._read_scalar(request, "READY")
            corrid = self._read_scalar(request, "CORRID")
            end = self._read_scalar(request, "END")

            if start:
                self.seq_feat[corrid] = Feat(corrid, self.offset_ms, self.sample_rate,
                                             self.frame_stride, self.device)
            if ready:
                self.seq_feat[corrid].add_wavs(wav)

            batch_seqid.append(corrid)
            if end:
                end_seqid.add(corrid)

            # Presumably the input is normalized to [-1, 1] and the Kaldi-style
            # fbank expects int16-range samples -- TODO confirm.
            total_waves.append(self.seq_feat[corrid].get_seg_wav() * 32768)

        if not total_waves:
            return responses

        features = self.feature_extractor(total_waves)
        for corrid, frames in zip(batch_seqid, features):
            feat = self.seq_feat[corrid]
            feat.add_frames(frames)
            speech = feat.get_frames(self.decoding_window)
            # Cast to the dtype declared for "speech" in the model config.
            speech = torch.unsqueeze(speech, 0).to(device="cpu", dtype=self.dtype)
            out_tensor0 = pb_utils.Tensor("speech", speech.numpy())
            responses.append(pb_utils.InferenceResponse(output_tensors=[out_tensor0]))
            if corrid in end_seqid:
                del self.seq_feat[corrid]
        return responses

    def finalize(self):
        """Triton model-unload hook; only logs a message."""
        print("Remove feature extractor!")