liugz18
2024-07-18 d80ac2fd2df4e7fb8a28acfa512bb11472b5cc99
runtime/triton_gpu/model_repo_paraformer_large_online/cif_search/1/model.py
@@ -23,10 +23,14 @@
class CIFSearch:
    """CIFSearch: https://github.com/alibaba-damo-academy/FunASR/blob/main/runtime/python/onnxruntime/funasr_onnx
    /paraformer_online_bin.py """
    /paraformer_online_bin.py"""
    def __init__(self):
        self.cache = {"cif_hidden": np.zeros((1, 1, 512)).astype(np.float32),
                      "cif_alphas": np.zeros((1, 1)).astype(np.float32), "last_chunk": False}
        self.cache = {
            "cif_hidden": np.zeros((1, 1, 512)).astype(np.float32),
            "cif_alphas": np.zeros((1, 1)).astype(np.float32),
            "last_chunk": False,
        }
        self.chunk_size = [5, 10, 5]
        self.tail_threshold = 0.45
        self.cif_threshold = 1.0
@@ -38,8 +42,8 @@
        list_frames = []
        cache_alphas = []
        cache_hiddens = []
        alphas[:, :self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]):] = 0.0
        alphas[:, : self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]) :] = 0.0
        if self.cache is not None and "cif_alphas" in self.cache and "cif_hidden" in self.cache:
            hidden = np.concatenate((self.cache["cif_hidden"], hidden), axis=1)
@@ -95,7 +99,9 @@
        self.cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
        self.cache["cif_hidden"] = np.expand_dims(self.cache["cif_hidden"], axis=0)
        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(np.int32)
        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(
            np.int32
        )
class TritonPythonModel:
@@ -119,18 +125,16 @@
          * model_version: Model version
          * model_name: Model name
        """
        self.model_config = model_config = json.loads(args['model_config'])
        self.model_config = model_config = json.loads(args["model_config"])
        self.max_batch_size = max(model_config["max_batch_size"], 1)
        # # Get OUTPUT0 configuration
        output0_config = pb_utils.get_output_config_by_name(
            model_config, "transcripts")
        output0_config = pb_utils.get_output_config_by_name(model_config, "transcripts")
        # # Convert Triton types to numpy types
        self.out0_dtype = pb_utils.triton_string_to_numpy(
            output0_config['data_type'])
        self.out0_dtype = pb_utils.triton_string_to_numpy(output0_config["data_type"])
        self.init_vocab(self.model_config['parameters'])
        self.init_vocab(self.model_config["parameters"])
        self.cif_search_cache = LimitedDict(1024)
        self.start = LimitedDict(1024)
@@ -142,9 +146,9 @@
                self.vocab_dict = self.load_vocab(value)
    def load_vocab(self, vocab_file):
        with open(str(vocab_file), 'rb') as f:
        with open(str(vocab_file), "rb") as f:
            config = yaml.load(f, Loader=yaml.Loader)
        return config['token_list']
        return config["token_list"]
    async def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
@@ -187,7 +191,7 @@
            in_start = pb_utils.get_input_tensor_by_name(request, "START")
            start = in_start.as_numpy()[0][0]
            in_corrid = pb_utils.get_input_tensor_by_name(request, "CORRID")
            corrid = in_corrid.as_numpy()[0][0]
@@ -202,19 +206,21 @@
                self.start[corrid] = 1
            if end:
                self.cif_search_cache[corrid].cache["last_chunk"] = True
            acoustic, acoustic_len = self.cif_search_cache[corrid].infer(hidden, alphas)
            batch_result[corrid] = ''
            batch_result[corrid] = ""
            if acoustic.shape[1] == 0:
                continue
                continue
            else:
                qualified_corrid.append(corrid)
                input_tensor0 = pb_utils.Tensor("enc", hidden)
                input_tensor1 = pb_utils.Tensor("enc_len", np.array([hidden_len], dtype=np.int32))
                input_tensor2 = pb_utils.Tensor("acoustic_embeds", acoustic)
                input_tensor3 = pb_utils.Tensor("acoustic_embeds_len", np.array([acoustic_len], dtype=np.int32))
                input_tensor3 = pb_utils.Tensor(
                    "acoustic_embeds_len", np.array([acoustic_len], dtype=np.int32)
                )
                input_tensors = [input_tensor0, input_tensor1, input_tensor2, input_tensor3]
                if self.start[corrid] and end:
                    flag = 3
                elif end:
@@ -225,12 +231,12 @@
                else:
                    flag = 0
                inference_request = pb_utils.InferenceRequest(
                    model_name='decoder',
                    requested_output_names=['sample_ids'],
                    model_name="decoder",
                    requested_output_names=["sample_ids"],
                    inputs=input_tensors,
                    request_id='',
                    request_id="",
                    correlation_id=corrid,
                    flags=flag
                    flags=flag,
                )
                inference_response_awaits.append(inference_request.async_exec())
@@ -240,9 +246,9 @@
            if inference_response.has_error():
                raise pb_utils.TritonModelException(inference_response.error().message())
            else:
                sample_ids = pb_utils.get_output_tensor_by_name(inference_response, 'sample_ids')
                sample_ids = pb_utils.get_output_tensor_by_name(inference_response, "sample_ids")
                token_ids = from_dlpack(sample_ids.to_dlpack()).cpu().numpy()[0]
                # Change integer-ids to tokens
                tokens = [self.vocab_dict[token_id] for token_id in token_ids]
                batch_result[index_corrid] = "".join(tokens)
@@ -252,7 +258,7 @@
            out0 = pb_utils.Tensor("transcripts", sent.astype(self.out0_dtype))
            inference_response = pb_utils.InferenceResponse(output_tensors=[out0])
            responses.append(inference_response)
            if batch_end[i]:
                del self.cif_search_cache[index_corrid]
                del self.start[index_corrid]
@@ -264,5 +270,4 @@
        Implementing `finalize` function is optional. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print('Cleaning up...')
        print("Cleaning up...")