zhu-gu-an
2024-01-13 49e8e9d8fc1209c347aa2c2c65c6eb067b9f79d4
add triton paraformer large online (#1242)

* add triton paraformer large online
12个文件已添加
10109 ■■■■■ 已修改文件
runtime/triton_gpu/README_ONLINE.md 64 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/triton_gpu/model_repo_paraformer_large_online/cif_search/1/model.py 268 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/triton_gpu/model_repo_paraformer_large_online/cif_search/config.pbtxt 111 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/triton_gpu/model_repo_paraformer_large_online/decoder/config.pbtxt 274 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/triton_gpu/model_repo_paraformer_large_online/encoder/config.pbtxt 77 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py 221 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/config.pbtxt 109 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/config.yaml 8639 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/triton_gpu/model_repo_paraformer_large_online/lfr_cmvn_pe/am.mvn 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/triton_gpu/model_repo_paraformer_large_online/lfr_cmvn_pe/config.pbtxt 85 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/triton_gpu/model_repo_paraformer_large_online/lfr_cmvn_pe/export_lfr_cmvn_pe_onnx.py 131 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/triton_gpu/model_repo_paraformer_large_online/streaming_paraformer/config.pbtxt 122 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
runtime/triton_gpu/README_ONLINE.md
New file
@@ -0,0 +1,64 @@
### Steps:
1. Prepare model repo files
* git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx.git
* Convert lfr_cmvn_pe.onnx model. For example: python export_lfr_cmvn_pe_onnx.py
* If you export to onnx, you should have several model files in `${MODEL_DIR}`:
```
├── README.md
└── model_repo_paraformer_large_online
    ├── cif_search
    │   ├── 1
    │   │   └── model.py
    │   └── config.pbtxt
    ├── decoder
    │   ├── 1
    │   │   └── decoder.onnx
    │   └── config.pbtxt
    ├── encoder
    │   ├── 1
    │   │   └── model.onnx
    │   └── config.pbtxt
    ├── feature_extractor
    │   ├── 1
    │   │   └── model.py
    │   ├── config.pbtxt
    │   └── config.yaml
    ├── lfr_cmvn_pe
    │   ├── 1
    │   │   └── lfr_cmvn_pe.onnx
    │   ├── am.mvn
    │   ├── config.pbtxt
    │   └── export_lfr_cmvn_pe_onnx.py
    └── streaming_paraformer
        ├── 1
        └── config.pbtxt
```
2. Follow below instructions to launch triton server
```sh
# using docker image Dockerfile/Dockerfile.server
docker build . -f Dockerfile/Dockerfile.server -t triton-paraformer:23.01
docker run -it --rm --name "paraformer_triton_server" --gpus all -v <path_host/model_repo_paraformer_large_online>:/workspace/ --shm-size 1g --net host triton-paraformer:23.01
# launch the service
cd /workspace
tritonserver --model-repository model_repo_paraformer_large_online \
             --pinned-memory-pool-byte-size=512000000 \
             --cuda-memory-pool-byte-size=0:1024000000
```
### Performance benchmark with a single A10
* FP32, onnx, [paraformer large online](https://modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx/summary). Our chunk size is 10 * 960 / 16000 = 0.6 s, so latency below 0.6 s is the figure that matters for a realtime application.
| Concurrency | Throughput | Latency_p50 (ms) | Latency_p90 (ms) | Latency_p95 (ms) | Latency_p99 (ms) |
|-------------|------------|------------------|------------------|------------------|------------------|
| 20          | 309.252    | 56.913          | 76.267          | 85.598          | 138.462          |
| 40          | 391.058    | 97.911           | 145.509          | 150.545          | 185.399          |
| 60          | 426.269    | 138.244          | 185.855          | 201.016          | 236.528          |
| 80          | 431.781    | 170.991          | 227.983          | 252.453          | 412.273          |
| 100         | 473.351    | 206.205          | 262.612          | 288.964          | 463.337          |
runtime/triton_gpu/model_repo_paraformer_large_online/cif_search/1/model.py
New file
@@ -0,0 +1,268 @@
# Created on 2024-01-01
# Author: GuAn Zhu
import triton_python_backend_utils as pb_utils
import numpy as np
from torch.utils.dlpack import from_dlpack
import json
import yaml
import asyncio
from collections import OrderedDict
class LimitedDict(OrderedDict):
    """An OrderedDict bounded to ``max_length`` entries (FIFO eviction).

    Used as a per-correlation-id state cache so abandoned streaming
    sequences cannot grow memory without bound.
    """
    def __init__(self, max_length):
        """
        Parameters
        ----------
        max_length : int
            Maximum number of entries kept; inserting beyond this evicts
            the oldest (first-inserted) entry.
        """
        super().__init__()
        self.max_length = max_length
    def __setitem__(self, key, value):
        # Evict only when a *new* key would exceed the bound. The previous
        # version popped the oldest entry even when overwriting an existing
        # key at capacity, needlessly dropping another stream's state.
        if key not in self and len(self) >= self.max_length:
            self.popitem(last=False)
        super().__setitem__(key, value)
class CIFSearch:
    """Stateful CIF (Continuous Integrate-and-Fire) search for one streaming
    session: integrates per-frame firing weights (alphas) over encoder
    frames and emits one acoustic embedding per fired token.

    Ported from FunASR:
    https://github.com/alibaba-damo-academy/FunASR/blob/main/runtime/python/onnxruntime/funasr_onnx
    /paraformer_online_bin.py """
    def __init__(self):
        # Carry-over state between chunks: the partially integrated frame
        # ("cif_hidden"), its accumulated weight ("cif_alphas"), and a flag
        # set by the caller when the final chunk of the stream arrives.
        self.cache = {"cif_hidden": np.zeros((1, 1, 512)).astype(np.float32),
                      "cif_alphas": np.zeros((1, 1)).astype(np.float32), "last_chunk": False}
        # [left context, current, right context] sizes in encoder frames;
        # only the "current" span contributes alphas (the rest is masked).
        self.chunk_size = [5, 10, 5]
        # Synthetic tail weight appended on the last chunk so a pending
        # partial token can still reach the firing threshold.
        self.tail_threshold = 0.45
        # A token fires each time the integrated alpha reaches this value.
        self.cif_threshold = 1.0
    def infer(self, hidden: np.ndarray, alphas: np.ndarray):
        """Run one CIF step over a chunk.

        Parameters
        ----------
        hidden : np.ndarray
            Encoder output, shape (batch, time, hidden_size).
        alphas : np.ndarray
            Per-frame firing weights, shape (batch, time). Mutated in place
            (context frames are zeroed).

        Returns
        -------
        tuple of np.ndarray
            (acoustic_embeds, token_lengths): fired token embeddings padded
            to the longest token count in the batch, shape
            (batch, max_tokens, hidden_size), and per-item token counts,
            shape (batch,), dtype int32.
        """
        batch_size, len_time, hidden_size = hidden.shape
        token_length = []
        list_fires = []
        list_frames = []
        cache_alphas = []
        cache_hiddens = []
        # Mask left/right context frames so only the "current" part of the
        # chunk contributes to token firing.
        alphas[:, :self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]):] = 0.0
        # Prepend the carry-over partial frame/weight from the previous chunk.
        if self.cache is not None and "cif_alphas" in self.cache and "cif_hidden" in self.cache:
            hidden = np.concatenate((self.cache["cif_hidden"], hidden), axis=1)
            alphas = np.concatenate((self.cache["cif_alphas"], alphas), axis=1)
        # On the final chunk, append a zero frame with tail_threshold weight
        # so a pending partial token can fire.
        if self.cache is not None and "last_chunk" in self.cache and self.cache["last_chunk"]:
            tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
            tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
            tail_alphas = np.tile(tail_alphas, (batch_size, 1))
            hidden = np.concatenate((hidden, tail_hidden), axis=1)
            alphas = np.concatenate((alphas, tail_alphas), axis=1)
        len_time = alphas.shape[1]
        for b in range(batch_size):
            integrate = 0.0
            frames = np.zeros(hidden_size).astype(np.float32)
            list_frame = []
            list_fire = []
            for t in range(len_time):
                alpha = alphas[b][t]
                if alpha + integrate < self.cif_threshold:
                    # Below threshold: keep integrating this token's frame.
                    integrate += alpha
                    list_fire.append(integrate)
                    frames += alpha * hidden[b][t]
                else:
                    # Threshold crossed: emit the token frame; the leftover
                    # weight of this encoder frame seeds the next token.
                    frames += (self.cif_threshold - integrate) * hidden[b][t]
                    list_frame.append(frames)
                    integrate += alpha
                    list_fire.append(integrate)
                    integrate -= self.cif_threshold
                    frames = integrate * hidden[b][t]
            # Whatever remains un-fired becomes the carry-over for the next
            # chunk (re-normalized by its accumulated weight when nonzero).
            cache_alphas.append(integrate)
            if integrate > 0.0:
                cache_hiddens.append(frames / integrate)
            else:
                cache_hiddens.append(frames)
            token_length.append(len(list_frame))
            list_fires.append(list_fire)
            list_frames.append(list_frame)
        max_token_len = max(token_length)
        # Zero-pad every batch item to the longest token count.
        list_ls = []
        for b in range(batch_size):
            pad_frames = np.zeros((max_token_len - token_length[b], hidden_size)).astype(np.float32)
            if token_length[b] == 0:
                list_ls.append(pad_frames)
            else:
                list_ls.append(np.concatenate((list_frames[b], pad_frames), axis=0))
        # NOTE(review): stacking per-item scalars/vectors and re-adding a
        # leading axis yields shapes (1, batch) / (1, batch, hidden); the
        # concatenation above only round-trips for batch_size == 1 — confirm
        # whether batched sequences are ever passed to one CIFSearch instance.
        self.cache["cif_alphas"] = np.stack(cache_alphas, axis=0)
        self.cache["cif_alphas"] = np.expand_dims(self.cache["cif_alphas"], axis=0)
        self.cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
        self.cache["cif_hidden"] = np.expand_dims(self.cache["cif_hidden"], axis=0)
        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(np.int32)
class TritonPythonModel:
    """Triton Python-backend model for the cif_search stage.

    For each sequence (identified by the CORRID control tensor) it keeps a
    CIFSearch state, converts encoder outputs into fired token embeddings,
    forwards them to the 'decoder' model via async BLS calls, and maps the
    returned token ids to text. Your Python model must use the class name
    "TritonPythonModel".
    """
    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.
        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        self.model_config = model_config = json.loads(args['model_config'])
        self.max_batch_size = max(model_config["max_batch_size"], 1)
        # # Get OUTPUT0 configuration
        output0_config = pb_utils.get_output_config_by_name(
            model_config, "transcripts")
        # # Convert Triton types to numpy types
        self.out0_dtype = pb_utils.triton_string_to_numpy(
            output0_config['data_type'])
        self.init_vocab(self.model_config['parameters'])
        # Per-correlation-id CIF state and a "first chunk pending" marker;
        # bounded (FIFO eviction) so abandoned sequences cannot leak memory.
        self.cif_search_cache = LimitedDict(1024)
        self.start = LimitedDict(1024)
    def init_vocab(self, parameters):
        """Scan the config 'parameters' section and load the vocabulary
        referenced by the 'vocabulary' key (a YAML config path)."""
        for li in parameters.items():
            key, value = li
            value = value["string_value"]
            if key == "vocabulary":
                self.vocab_dict = self.load_vocab(value)
    def load_vocab(self, vocab_file):
        """Return the 'token_list' entry from the given YAML file."""
        with open(str(vocab_file), 'rb') as f:
            config = yaml.load(f, Loader=yaml.Loader)
        return config['token_list']
    async def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference is requested
        for this model.
        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest
        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        # Every Python backend must iterate through list of requests and create
        # an instance of pb_utils.InferenceResponse class for each of them. You
        # should avoid storing any of the input Tensors in the class attributes
        # as they will be overridden in subsequent inference requests. You can
        # make a copy of the underlying NumPy array and store it if it is
        # required.
        batch_end = []
        responses = []
        batch_corrid = []
        qualified_corrid = []          # corrids that actually fired tokens and were sent to the decoder
        batch_result = {}              # corrid -> transcript text for this chunk
        inference_response_awaits = []
        for request in requests:
            # Encoder outputs arrive as DLPack tensors; move to host numpy.
            hidden = pb_utils.get_input_tensor_by_name(request, "enc")
            hidden = from_dlpack(hidden.to_dlpack()).cpu().numpy()
            alphas = pb_utils.get_input_tensor_by_name(request, "alphas")
            alphas = from_dlpack(alphas.to_dlpack()).cpu().numpy()
            hidden_len = pb_utils.get_input_tensor_by_name(request, "enc_len")
            hidden_len = from_dlpack(hidden_len.to_dlpack()).cpu().numpy()
            # Sequence-batcher control tensors (see config.pbtxt).
            in_start = pb_utils.get_input_tensor_by_name(request, "START")
            start = in_start.as_numpy()[0][0]
            in_corrid = pb_utils.get_input_tensor_by_name(request, "CORRID")
            corrid = in_corrid.as_numpy()[0][0]
            in_end = pb_utils.get_input_tensor_by_name(request, "END")
            end = in_end.as_numpy()[0][0]
            batch_end.append(end)
            batch_corrid.append(corrid)
            if start:
                # New sequence: fresh CIF state, mark that the decoder has
                # not yet seen a START flag for this corrid.
                self.cif_search_cache[corrid] = CIFSearch()
                self.start[corrid] = 1
            if end:
                # Tell CIF to append its tail weight so pending tokens fire.
                self.cif_search_cache[corrid].cache["last_chunk"] = True
            acoustic, acoustic_len = self.cif_search_cache[corrid].infer(hidden, alphas)
            batch_result[corrid] = ''
            if acoustic.shape[1] == 0:
                # No token fired in this chunk: answer with an empty
                # transcript (already set above) and skip the decoder call.
                continue
            else:
                qualified_corrid.append(corrid)
                input_tensor0 = pb_utils.Tensor("enc", hidden)
                # NOTE(review): hidden_len is already a numpy array here, so
                # np.array([hidden_len]) adds an extra axis — confirm the
                # decoder's expected enc_len shape.
                input_tensor1 = pb_utils.Tensor("enc_len", np.array([hidden_len], dtype=np.int32))
                input_tensor2 = pb_utils.Tensor("acoustic_embeds", acoustic)
                input_tensor3 = pb_utils.Tensor("acoustic_embeds_len", np.array([acoustic_len], dtype=np.int32))
                input_tensors = [input_tensor0, input_tensor1, input_tensor2, input_tensor3]
                # Sequence flag bitmask forwarded to the decoder:
                # 1 = start, 2 = end, 3 = start and end, 0 = middle chunk.
                # The start marker is cleared only once a chunk is actually
                # sent, so an empty first chunk still lets a later chunk
                # carry the start flag.
                if self.start[corrid] and end:
                    flag = 3
                elif end:
                    flag = 2
                elif self.start[corrid]:
                    flag = 1
                    self.start[corrid] = 0
                else:
                    flag = 0
                inference_request = pb_utils.InferenceRequest(
                    model_name='decoder',
                    requested_output_names=['sample_ids'],
                    inputs=input_tensors,
                    request_id='',
                    correlation_id=corrid,
                    flags=flag
                )
                # Fan out all decoder calls, then await them together below.
                inference_response_awaits.append(inference_request.async_exec())
        inference_responses = await asyncio.gather(*inference_response_awaits)
        for index_corrid, inference_response in zip(qualified_corrid, inference_responses):
            if inference_response.has_error():
                raise pb_utils.TritonModelException(inference_response.error().message())
            else:
                sample_ids = pb_utils.get_output_tensor_by_name(inference_response, 'sample_ids')
                token_ids = from_dlpack(sample_ids.to_dlpack()).cpu().numpy()[0]
                # Change integer-ids to tokens
                tokens = [self.vocab_dict[token_id] for token_id in token_ids]
                batch_result[index_corrid] = "".join(tokens)
        # Emit one response per request, in request order; drop per-sequence
        # state once its END chunk has been answered.
        for i, index_corrid in enumerate(batch_corrid):
            sent = np.array([batch_result[index_corrid]])
            out0 = pb_utils.Tensor("transcripts", sent.astype(self.out0_dtype))
            inference_response = pb_utils.InferenceResponse(output_tensors=[out0])
            responses.append(inference_response)
            if batch_end[i]:
                del self.cif_search_cache[index_corrid]
                del self.start[index_corrid]
        return responses
    def finalize(self):
        """`finalize` is called only once when the model is being unloaded.
        Implementing `finalize` function is optional. This function allows
        the model to perform any necessary clean ups before exit.
        """
        print('Cleaning up...')
runtime/triton_gpu/model_repo_paraformer_large_online/cif_search/config.pbtxt
New file
@@ -0,0 +1,111 @@
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Created on 2024-01-01
# Author: GuAn Zhu
name: "cif_search"
backend: "python"
max_batch_size: 128
sequence_batching{
    max_sequence_idle_microseconds: 15000000
    oldest {
      max_candidate_sequences: 1024
      preferred_batch_size: [32, 64, 128]
    }
    control_input [
        {
            name: "START",
            control [
                {
                    kind: CONTROL_SEQUENCE_START
                    fp32_false_true: [0, 1]
                }
            ]
        },
        {
            name: "READY"
            control [
                {
                    kind: CONTROL_SEQUENCE_READY
                    fp32_false_true: [0, 1]
                }
            ]
        },
        {
            name: "CORRID",
            control [
                {
                    kind: CONTROL_SEQUENCE_CORRID
                    data_type: TYPE_UINT64
                }
            ]
        },
        {
            name: "END",
            control [
                {
                    kind: CONTROL_SEQUENCE_END
                    fp32_false_true: [0, 1]
                }
            ]
        }
    ]
}
parameters [
  {
    key: "vocabulary",
    value: { string_value: "model_repo_paraformer_large_online/feature_extractor/config.yaml"}
  },
  { key: "FORCE_CPU_ONLY_INPUT_TENSORS"
    value: {string_value:"no"}
  }
]
input [
  {
    name: "enc"
    data_type: TYPE_FP32
    dims: [-1, 512]
  },
  {
    name: "enc_len"
    data_type: TYPE_INT32
    dims: [1]
    reshape: { shape: [ ] }
  },
  {
    name: "alphas"
    data_type: TYPE_FP32
    dims: [-1]
  }
]
output [
  {
    name: "transcripts"
    data_type: TYPE_STRING
    dims: [1]
  }
]
instance_group [
    {
      count: 6
      kind: KIND_CPU
    }
  ]
runtime/triton_gpu/model_repo_paraformer_large_online/decoder/config.pbtxt
New file
@@ -0,0 +1,274 @@
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Created on 2024-01-01
# Author: GuAn Zhu
name: "decoder"
backend: "onnxruntime"
default_model_filename: "decoder.onnx"
max_batch_size: 128
sequence_batching{
    max_sequence_idle_microseconds: 15000000
    oldest {
      max_candidate_sequences: 1024
      preferred_batch_size: [16, 32, 64]
    }
    control_input [
    ]
    state [
    {
      input_name: "in_cache_0"
      output_name: "out_cache_0"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_1"
      output_name: "out_cache_1"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_2"
      output_name: "out_cache_2"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_3"
      output_name: "out_cache_3"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_4"
      output_name: "out_cache_4"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_5"
      output_name: "out_cache_5"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_6"
      output_name: "out_cache_6"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_7"
      output_name: "out_cache_7"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_8"
      output_name: "out_cache_8"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_9"
      output_name: "out_cache_9"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_10"
      output_name: "out_cache_10"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_11"
      output_name: "out_cache_11"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_12"
      output_name: "out_cache_12"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_13"
      output_name: "out_cache_13"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_14"
      output_name: "out_cache_14"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "in_cache_15"
      output_name: "out_cache_15"
      data_type: TYPE_FP32
      dims: [ 512, 10 ]
      initial_state: {
       data_type: TYPE_FP32
       dims: [ 512, 10]
       zero_data: true
       name: "initial state"
      }
    }
  ]
}
input [
  {
    name: "enc"
    data_type: TYPE_FP32
    dims: [-1, 512]
  },
  {
    name: "enc_len"
    data_type: TYPE_INT32
    dims: [1]
    reshape: { shape: [ ] }
  },
  {
    name: "acoustic_embeds"
    data_type: TYPE_FP32
    dims: [-1, 512]
  },
  {
    name: "acoustic_embeds_len"
    data_type: TYPE_INT32
    dims: [1]
    reshape: { shape: [ ] }
  }
]
output [
  {
    name: "logits"
    data_type: TYPE_FP32
    dims: [-1, 8404]
  },
  {
    name: "sample_ids"
    data_type: TYPE_INT64
    dims: [-1]
  }
]
instance_group [
    {
      count: 1
      kind: KIND_GPU
    }
]
runtime/triton_gpu/model_repo_paraformer_large_online/encoder/config.pbtxt
New file
@@ -0,0 +1,77 @@
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Created on 2024-01-01
# Author: GuAn Zhu
name: "encoder"
backend: "onnxruntime"
default_model_filename: "model.onnx"
max_batch_size: 128
sequence_batching{
    max_sequence_idle_microseconds: 15000000
    oldest {
      max_candidate_sequences: 1024
      preferred_batch_size: [32, 64, 128]
      max_queue_delay_microseconds: 300
    }
    control_input [
    ]
    state [
  ]
}
input [
  {
    name: "speech"
    data_type: TYPE_FP32
    dims: [-1, 560]
  },
  {
    name: "speech_lengths"
    data_type: TYPE_INT32
    dims: [1]
    reshape: { shape: [ ] }
  }
]
output [
  {
    name: "enc"
    data_type: TYPE_FP32
    dims: [-1, 512]
  },
  {
    name: "enc_len"
    data_type: TYPE_INT32
    dims: [1]
    reshape: { shape: [ ] }
  },
  {
    name: "alphas"
    data_type: TYPE_FP32
    dims: [-1]
  }
]
instance_group [
    {
      count: 1
      kind: KIND_GPU
    }
]
runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/1/model.py
New file
@@ -0,0 +1,221 @@
# Created on 2024-01-01
# Author: GuAn Zhu
# Modified from NVIDIA(https://github.com/wenet-e2e/wenet/blob/main/runtime/gpu/
# model_repo_stateful/feature_extractor/1/model.py)
import triton_python_backend_utils as pb_utils
from torch.utils.dlpack import from_dlpack
import torch
import kaldifeat
from typing import List
import json
import numpy as np
import yaml
from collections import OrderedDict
class LimitedDict(OrderedDict):
    """An OrderedDict bounded to ``max_length`` entries (FIFO eviction).

    Used as a per-correlation-id state cache so abandoned streaming
    sequences cannot grow memory without bound.
    """
    def __init__(self, max_length):
        """
        Parameters
        ----------
        max_length : int
            Maximum number of entries kept; inserting beyond this evicts
            the oldest (first-inserted) entry.
        """
        super().__init__()
        self.max_length = max_length
    def __setitem__(self, key, value):
        # Evict only when a *new* key would exceed the bound. The previous
        # version popped the oldest entry even when overwriting an existing
        # key at capacity, needlessly dropping another stream's state.
        if key not in self and len(self) >= self.max_length:
            self.popitem(last=False)
        super().__setitem__(key, value)
class Fbank(torch.nn.Module):
    """Thin ``torch.nn.Module`` wrapper around a kaldifeat Fbank extractor."""
    def __init__(self, opts):
        super().__init__()
        # The underlying extractor is configured entirely via `opts`
        # (a kaldifeat.FbankOptions instance).
        self.fbank = kaldifeat.Fbank(opts)
    def forward(self, waves: List[torch.Tensor]):
        # Delegates directly: a list of 1-D waveforms in, per-utterance
        # feature matrices out.
        return self.fbank(waves)
class Feat(object):
    """Per-sequence streaming buffer for raw audio samples and fbank frames.

    Keeps a rolling waveform tail (the last ``offset`` samples) so window
    overlap at chunk boundaries stays consistent, and a FIFO of feature
    frames that is consumed ``frame_stride`` rows at a time.
    """
    def __init__(self, seqid, offset_ms, sample_rate, frame_stride, device='cpu'):
        self.seqid = seqid
        self.sample_rate = sample_rate
        self.device = device
        # Rolling waveform buffer, starts empty.
        self.wav = torch.tensor([], device=device)
        # Samples retained from the end of each chunk for window overlap.
        self.offset = int(offset_ms / 1000 * sample_rate)
        # Accumulated fbank frames; None until the first chunk arrives.
        self.frames = None
        self.frame_stride = int(frame_stride)
        # LFR window size; drives the left-edge padding in add_frames().
        self.lfr_m = 7
    def add_wavs(self, wav: torch.tensor):
        """Append a chunk of samples (moved to this buffer's device)."""
        self.wav = torch.cat((self.wav, wav.to(self.device)), axis=0)
    def get_seg_wav(self):
        """Return all buffered samples, keeping only the overlap tail."""
        segment = self.wav[:]
        self.wav = self.wav[-self.offset:]
        return segment
    def add_frames(self, frames: torch.tensor):
        """Append frames (seq_len x feat_sz); on the very first call the
        first frame is repeated (lfr_m - 1) // 2 times as left padding."""
        if self.frames is None:
            left_pad = frames[0, :].repeat((self.lfr_m - 1) // 2, 1)
            self.frames = torch.cat((left_pad, frames), axis=0)
        else:
            self.frames = torch.cat([self.frames, frames], axis=0)
    def get_frames(self, num_frames: int):
        """Return the first ``num_frames`` rows and advance by ``frame_stride``."""
        window = self.frames[0:num_frames]
        self.frames = self.frames[self.frame_stride:]
        return window
class TritonPythonModel:
    """Triton Python-backend model for the feature_extractor stage.

    For each sequence (CORRID control tensor) it buffers incoming audio in a
    Feat object, computes kaldifeat fbank features, and emits fixed-size
    frame windows for the downstream lfr_cmvn_pe/encoder models. Your Python
    model must use the class name "TritonPythonModel".
    """
    def initialize(self, args):
        """`initialize` is called only once when the model is being loaded.
        Implementing `initialize` function is optional. This function allows
        the model to initialize any state associated with this model.
        Parameters
        ----------
        args : dict
          Both keys and values are strings. The dictionary keys and values are:
          * model_config: A JSON string containing the model configuration
          * model_instance_kind: A string containing model instance kind
          * model_instance_device_id: A string containing model instance device ID
          * model_repository: Model repository path
          * model_version: Model version
          * model_name: Model name
        """
        self.model_config = model_config = json.loads(args['model_config'])
        self.max_batch_size = max(model_config["max_batch_size"], 1)
        # Pick the compute device from the first instance_group entry
        # (e.g. "KIND_GPU" -> cuda).
        if "GPU" in model_config["instance_group"][0]["kind"]:
            self.device = "cuda"
        else:
            self.device = "cpu"
        # Get OUTPUT0 configuration
        output0_config = pb_utils.get_output_config_by_name(
            model_config, "speech")
        # Convert Triton types to numpy types
        self.output0_dtype = pb_utils.triton_string_to_numpy(
            output0_config['data_type'])
        if self.output0_dtype == np.float32:
            self.dtype = torch.float32
        else:
            self.dtype = torch.float16
        # Output dims are [..., decoding_window, feature_size].
        self.feature_size = output0_config['dims'][-1]
        self.decoding_window = output0_config['dims'][-2]
        # Read the YAML frontend config referenced by the 'config_path'
        # parameter in config.pbtxt.
        params = self.model_config['parameters']
        for li in params.items():
            key, value = li
            value = value["string_value"]
            if key == "config_path":
                with open(str(value), 'rb') as f:
                    config = yaml.load(f, Loader=yaml.Loader)
        # Build fbank options from the frontend config; dither disabled for
        # deterministic features.
        opts = kaldifeat.FbankOptions()
        opts.frame_opts.dither = 0.0
        opts.frame_opts.window_type = config['frontend_conf']['window']
        opts.mel_opts.num_bins = int(config['frontend_conf']['n_mels'])
        opts.frame_opts.frame_shift_ms = float(config['frontend_conf']['frame_shift'])
        opts.frame_opts.frame_length_ms = float(config['frontend_conf']['frame_length'])
        opts.frame_opts.samp_freq = int(config['frontend_conf']['fs'])
        opts.device = torch.device(self.device)
        self.opts = opts
        self.feature_extractor = Fbank(self.opts)
        # Bounded per-corrid Feat buffers (FIFO eviction of stale streams).
        self.seq_feat = LimitedDict(1024)
        chunk_size_s = float(params["chunk_size_s"]["string_value"])
        sample_rate = opts.frame_opts.samp_freq
        frame_shift_ms = opts.frame_opts.frame_shift_ms
        frame_length_ms = opts.frame_opts.frame_length_ms
        self.chunk_size = int(chunk_size_s * sample_rate)
        # Frames produced per chunk; float from //, Feat casts it to int.
        self.frame_stride = (chunk_size_s * 1000) // frame_shift_ms
        self.offset_ms = self.get_offset(frame_length_ms, frame_shift_ms)
        self.sample_rate = sample_rate
    def get_offset(self, frame_length_ms, frame_shift_ms):
        """Return the largest multiple of frame_shift_ms strictly below
        frame_length_ms — the audio tail (ms) to keep for window overlap."""
        offset_ms = 0
        while offset_ms + frame_shift_ms < frame_length_ms:
            offset_ms += frame_shift_ms
        return offset_ms
    def execute(self, requests):
        """`execute` must be implemented in every Python model. `execute`
        function receives a list of pb_utils.InferenceRequest as the only
        argument. This function is called when an inference is requested
        for this model.
        Parameters
        ----------
        requests : list
          A list of pb_utils.InferenceRequest
        Returns
        -------
        list
          A list of pb_utils.InferenceResponse. The length of this list must
          be the same as `requests`
        """
        total_waves = []
        responses = []
        batch_seqid = []
        end_seqid = {}
        for request in requests:
            input0 = pb_utils.get_input_tensor_by_name(request, "wav")
            wav = from_dlpack(input0.to_dlpack())[0]
            # input1 = pb_utils.get_input_tensor_by_name(request, "wav_lens")
            # wav_len = from_dlpack(input1.to_dlpack())[0]
            wav_len = len(wav)
            # Zero-pad short (typically final) chunks to a full chunk_size.
            if wav_len < self.chunk_size:
                temp = torch.zeros(self.chunk_size, dtype=torch.float32,
                                   device=self.device)
                temp[0:wav_len] = wav[:]
                wav = temp
            # Sequence-batcher control tensors (see config.pbtxt).
            in_start = pb_utils.get_input_tensor_by_name(request, "START")
            start = in_start.as_numpy()[0][0]
            in_ready = pb_utils.get_input_tensor_by_name(request, "READY")
            ready = in_ready.as_numpy()[0][0]
            in_corrid = pb_utils.get_input_tensor_by_name(request, "CORRID")
            corrid = in_corrid.as_numpy()[0][0]
            in_end = pb_utils.get_input_tensor_by_name(request, "END")
            end = in_end.as_numpy()[0][0]
            if start:
                # New sequence: allocate a fresh streaming buffer.
                self.seq_feat[corrid] = Feat(corrid, self.offset_ms,
                                             self.sample_rate,
                                             self.frame_stride,
                                             self.device)
            if ready:
                self.seq_feat[corrid].add_wavs(wav)
            batch_seqid.append(corrid)
            if end:
                end_seqid[corrid] = 1
            # Scale by 32768 to 16-bit sample range (presumably the input is
            # normalized to [-1, 1] — confirm against the client).
            wav = self.seq_feat[corrid].get_seg_wav() * 32768
            total_waves.append(wav)
        # One batched kaldifeat call for every sequence in this batch.
        features = self.feature_extractor(total_waves)
        for corrid, frames in zip(batch_seqid, features):
            self.seq_feat[corrid].add_frames(frames)
            # Emit a fixed decoding window; Feat advances by frame_stride.
            speech = self.seq_feat[corrid].get_frames(self.decoding_window)
            out_tensor0 = pb_utils.Tensor("speech", torch.unsqueeze(speech, 0).to("cpu").numpy())
            output_tensors = [out_tensor0]
            response = pb_utils.InferenceResponse(output_tensors=output_tensors)
            responses.append(response)
            # Drop per-sequence state once its END chunk is answered.
            if corrid in end_seqid:
                del self.seq_feat[corrid]
        return responses
    def finalize(self):
        """Called once at model unload; no state needs explicit cleanup."""
        print("Remove feature extractor!")
runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/config.pbtxt
New file
@@ -0,0 +1,109 @@
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Created on 2024-01-01
# Author: GuAn Zhu
name: "feature_extractor"
backend: "python"
max_batch_size: 128
parameters [
  {
    key: "chunk_size_s",
    value: { string_value: "0.6"}
  },
  {
    key: "config_path"
    value: { string_value: "model_repo_paraformer_large_online/feature_extractor/config.yaml"}
  }
]
sequence_batching{
    max_sequence_idle_microseconds: 15000000
    oldest {
      max_candidate_sequences: 1024
      preferred_batch_size: [32, 64, 128]
      max_queue_delay_microseconds: 300
    }
    control_input [
        {
            name: "START",
            control [
                {
                    kind: CONTROL_SEQUENCE_START
                    fp32_false_true: [0, 1]
                }
            ]
        },
        {
            name: "READY"
            control [
                {
                    kind: CONTROL_SEQUENCE_READY
                    fp32_false_true: [0, 1]
                }
            ]
        },
        {
            name: "CORRID",
            control [
                {
                    kind: CONTROL_SEQUENCE_CORRID
                    data_type: TYPE_UINT64
                }
            ]
        },
        {
            name: "END",
            control [
                {
                    kind: CONTROL_SEQUENCE_END
                    fp32_false_true: [0, 1]
                }
            ]
        }
    ]
}
input [
  {
    name: "wav"
    data_type: TYPE_FP32
    dims: [-1]
  },
  {
    name: "wav_lens"
    data_type: TYPE_INT32
    dims: [1]
  }
]
output [
  {
    name: "speech"
    data_type: TYPE_FP32
    dims: [61, 80]  # 61 fbank frames per streaming chunk, 80 mel bins each
  }
]
instance_group [
    {
      count: 1
      kind: KIND_GPU
    }
]
runtime/triton_gpu/model_repo_paraformer_large_online/feature_extractor/config.yaml
New file
Diff too large
runtime/triton_gpu/model_repo_paraformer_large_online/lfr_cmvn_pe/am.mvn
New file
@@ -0,0 +1,8 @@
<Nnet>
<Splice> 560 560
[ 0 ]
<AddShift> 560 560
<LearnRateCoef> 0 [ -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 
-13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 
-8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 -8.311879 -8.600912 -9.615928 -10.43595 -11.21292 -11.88333 -12.36243 -12.63706 -12.8818 -12.83066 -12.89103 -12.95666 -13.19763 -13.40598 -13.49113 -13.5546 -13.55639 -13.51915 -13.68284 -13.53289 -13.42107 -13.65519 -13.50713 -13.75251 -13.76715 -13.87408 -13.73109 -13.70412 -13.56073 -13.53488 -13.54895 -13.56228 -13.59408 -13.62047 -13.64198 -13.66109 -13.62669 -13.58297 -13.57387 -13.4739 -13.53063 -13.48348 -13.61047 -13.64716 -13.71546 -13.79184 -13.90614 -14.03098 -14.18205 -14.35881 -14.48419 -14.60172 -14.70591 -14.83362 -14.92122 -15.00622 -15.05122 -15.03119 -14.99028 -14.92302 -14.86927 -14.82691 -14.7972 -14.76909 -14.71356 -14.61277 -14.51696 -14.42252 -14.36405 -14.30451 -14.23161 -14.19851 -14.16633 -14.15649 -14.10504 -13.99518 -13.79562 -13.3996 -12.7767 -11.71208 ]
<Rescale> 560 560
<LearnRateCoef> 0 [ 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 
0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 
0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 0.155775 0.154484 0.1527379 0.1518718 0.1506028 0.1489256 0.147067 0.1447061 0.1436307 0.1443568 0.1451849 0.1455157 0.1452821 0.1445717 0.1439195 0.1435867 0.1436018 0.1438781 0.1442086 0.1448844 0.1454756 0.145663 0.146268 0.1467386 0.1472724 0.147664 0.1480913 0.1483739 0.1488841 0.1493636 0.1497088 0.1500379 0.1502916 0.1505389 0.1506787 0.1507102 0.1505992 0.1505445 0.1505938 0.1508133 0.1509569 0.1512396 0.1514625 0.1516195 0.1516156 0.1515561 0.1514966 0.1513976 0.1512612 0.151076 0.1510596 0.1510431 0.151077 0.1511168 0.1511917 0.151023 0.1508045 0.1505885 0.1503493 0.1502373 0.1501726 0.1500762 0.1500065 0.1499782 0.150057 0.1502658 0.150469 0.1505335 0.1505505 0.1505328 0.1504275 0.1502438 0.1499674 0.1497118 0.1494661 0.1493102 0.1493681 0.1495501 0.1499738 0.1509654 ]
</Nnet>
runtime/triton_gpu/model_repo_paraformer_large_online/lfr_cmvn_pe/config.pbtxt
New file
@@ -0,0 +1,85 @@
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Created on 2024-01-01
# Author: GuAn Zhu
name: "lfr_cmvn_pe"
backend: "onnxruntime"
default_model_filename: "lfr_cmvn_pe.onnx"
max_batch_size: 128
sequence_batching{
    max_sequence_idle_microseconds: 15000000
    oldest {
      max_candidate_sequences: 1024
      preferred_batch_size: [32, 64, 128]
      max_queue_delay_microseconds: 300
    }
    control_input [
    ]
    state [
    {
      input_name: "cache"
      output_name: "r_cache"
      data_type: TYPE_FP32
      dims: [10, 560]
      initial_state: {
       data_type: TYPE_FP32
       dims: [10, 560]
       zero_data: true
       name: "initial state"
      }
    },
    {
      input_name: "offset"
      output_name: "r_offset"
      data_type: TYPE_INT32
      dims: [1]
      initial_state: {
       data_type: TYPE_INT32
       dims: [1]
       zero_data: true
       name: "initial state"
      }
    }
  ]
}
input [
  {
    name: "chunk_xs"
    data_type: TYPE_FP32
    dims: [61, 80]
  }
]
output [
  {
    name: "chunk_xs_out"
    data_type: TYPE_FP32
    dims: [-1, 560]
  },
  {
    name: "chunk_xs_out_len"
    data_type: TYPE_INT32
    dims: [-1]
  }
]
instance_group [
    {
      count: 1
      kind: KIND_GPU
    }
]
runtime/triton_gpu/model_repo_paraformer_large_online/lfr_cmvn_pe/export_lfr_cmvn_pe_onnx.py
New file
@@ -0,0 +1,131 @@
# Created on 2024-01-01
# Author: GuAn Zhu
import torch
import numpy as np
import math
import torch.nn.functional as F
class LFR_CMVN_PE(torch.nn.Module):
    """ONNX-exportable front-end fusing LFR stacking, CMVN and positional encoding.

    Pipeline applied to each streaming chunk:
      1. LFR: stack every ``m`` consecutive fbank frames with hop ``n``.
      2. CMVN: normalize as ``(x + mean) * istd``.
      3. Scale by ``sqrt(encoder_output_size)`` and add sinusoidal position
         embeddings (sin half and cos half concatenated along the feature
         axis — not interleaved).
      4. Prepend the cached output of the previous chunk.
    """

    def __init__(self,
                 mean: torch.Tensor,
                 istd: torch.Tensor,
                 m: int = 7,
                 n: int = 6,
                 max_len: int = 5000,
                 encoder_input_size: int = 560,
                 encoder_output_size: int = 512):
        """
        Args:
            mean: additive CMVN shift, broadcastable to the LFR output
                  ``(..., m * fbank_dim)``.
            istd: multiplicative CMVN scale (inverse std), same shape as mean.
            m: number of frames stacked per LFR output frame.
            n: hop between successive LFR windows.
            max_len: maximum number of positions in the PE table.
            encoder_input_size: LFR feature dim (``m * fbank_dim``).
            encoder_output_size: encoder d_model; its sqrt scales the features.
        """
        super().__init__()
        # LFR
        self.m = m
        self.n = n
        self.subsample = (m - 1) // 2
        # CMVN: buffers so they are baked into the exported graph as constants.
        assert mean.shape == istd.shape
        self.register_buffer("mean", mean)
        self.register_buffer("istd", istd)
        # PE
        self.encoder_input_size = encoder_input_size
        self.encoder_output_size = encoder_output_size
        self.max_len = max_len
        position = torch.arange(0, self.max_len,
                                dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange((self.encoder_input_size / 2), dtype=torch.float32) *
            -(math.log(10000.0) / (self.encoder_input_size / 2 - 1)))
        # (max_len, encoder_input_size): sin half then cos half, concatenated.
        # NOTE: the original wrote this through a no-op slice
        # `pe[:, 0::1] = ...` into a pre-allocated zeros tensor; a direct
        # assignment produces the identical table and is clearer.
        self.pe = torch.cat((torch.sin(position * div_term),
                             torch.cos(position * div_term)), dim=1)

    def forward(self, x, cache, offset):
        """
        Args:
            x (torch.Tensor): (batch, frames, fbank_dim) raw fbank chunk.
            cache (torch.Tensor): (batch, cache_len, encoder_input_size)
                features carried over from the previous chunk.
            offset (torch.Tensor): (batch, 1) int32 running stream position.
        Returns:
            r_x: (batch, cache_len + T, encoder_input_size) cache + new frames.
            r_x_len: (batch, 1) int32 time length of ``r_x``.
            r_cache: (batch, T, encoder_input_size) cache for the next chunk.
            r_offset: (batch, 1) int32 updated stream position.
        """
        B, _, D = x.size()
        # LFR: (B, L, D) -> (B, T, D, m) -> (B, T, m, D) -> (B, T, m * D)
        x = x.unfold(1, self.m, step=self.n).transpose(2, 3)
        x = x.view(B, -1, D * self.m)
        # CMVN followed by d_model scaling.
        x = (x + self.mean) * self.istd
        x = x * (self.encoder_output_size ** 0.5)
        # Positions are 1-based and shifted by the running stream offset.
        index = offset + torch.arange(1, x.size(1) + 1).to(dtype=torch.int32)
        pos_emb = F.embedding(index, self.pe)  # B x T x d_model
        r_cache = x + pos_emb
        r_x = torch.cat((cache, r_cache), dim=1)
        r_offset = offset + x.size(1)
        r_x_len = torch.ones((B, 1), dtype=torch.int32) * r_x.size(1)
        return r_x, r_x_len, r_cache, r_offset
def load_cmvn(cmvn_file):
    """Parse a Kaldi-style ``am.mvn`` file into CMVN tensors.

    The file contains an ``<AddShift>`` section (negated means) and a
    ``<Rescale>`` section (inverse standard deviations); each header is
    followed by a ``<LearnRateCoef> 0 [ v1 v2 ... ]`` line holding the values.

    Args:
        cmvn_file: path to the ``am.mvn`` text file.
    Returns:
        (means, scales): two 1-D float32 torch tensors of equal length.
    """
    with open(cmvn_file, 'r', encoding='utf-8') as f:
        lines = f.readlines()
    means_list = []
    scales_list = []
    for i, line in enumerate(lines):
        tokens = line.split()
        if not tokens:
            # Tolerate blank lines (the original crashed on tokens[0] here).
            continue
        if tokens[0] == '<AddShift>':
            value_tokens = lines[i + 1].split()
            if value_tokens and value_tokens[0] == '<LearnRateCoef>':
                # Values sit between the '[' (index 2) and the trailing ']'.
                means_list = value_tokens[3:-1]
        elif tokens[0] == '<Rescale>':
            value_tokens = lines[i + 1].split()
            if value_tokens and value_tokens[0] == '<LearnRateCoef>':
                scales_list = value_tokens[3:-1]
    means = torch.from_numpy(np.array(means_list).astype(np.float32))
    scales = torch.from_numpy(np.array(scales_list).astype(np.float32))
    return means, scales
if __name__ == "__main__":
    import os

    # Load CMVN stats (negated means, inverse stds) from the Kaldi-style file.
    means, istds = load_cmvn("am.mvn")
    # Tile the (560,) stats to (10, 560): 10 is the number of LFR frames
    # produced per 61-frame chunk ((61 - 7) // 6 + 1), so the broadcast
    # shapes are fixed at export time.
    means = torch.tile(means, (10, 1))
    istds = torch.tile(istds, (10, 1))
    model = LFR_CMVN_PE(means, istds)
    model.eval()

    # Mark the batch dimension (dim 0) dynamic on every input and output.
    all_names = ['chunk_xs', 'cache', 'offset',
                 'chunk_xs_out', 'chunk_xs_out_len', 'r_cache', 'r_offset']
    dynamic_axes = {name: {0: 'B'} for name in all_names}

    # Dummy inputs: a 61-frame fbank chunk, a 10-frame feature cache and the
    # running position offset (batch of 4 is arbitrary — dim 0 is dynamic).
    input_data1 = torch.randn(4, 61, 80).to(torch.float32)
    input_data2 = torch.randn(4, 10, 560).to(torch.float32)
    input_data3 = torch.randn(4, 1).to(torch.int32)

    onnx_path = "./1/lfr_cmvn_pe.onnx"
    # Ensure the Triton model-version directory ("1") exists; the original
    # export failed with FileNotFoundError when it was missing.
    os.makedirs(os.path.dirname(onnx_path), exist_ok=True)
    torch.onnx.export(model,
                      (input_data1, input_data2, input_data3),
                      onnx_path,
                      export_params=True,
                      opset_version=11,
                      do_constant_folding=True,
                      input_names=['chunk_xs', 'cache', 'offset'],
                      output_names=['chunk_xs_out', 'chunk_xs_out_len',
                                    'r_cache', 'r_offset'],
                      dynamic_axes=dynamic_axes,
                      verbose=False)
    print("export to onnx model succeed!")
runtime/triton_gpu/model_repo_paraformer_large_online/streaming_paraformer/config.pbtxt
New file
@@ -0,0 +1,122 @@
# Copyright (c) 2021, NVIDIA CORPORATION.  All rights reserved.
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Created on 2024-01-01
# Author: GuAn Zhu
name: "streaming_paraformer"
platform: "ensemble"
max_batch_size: 128 #MAX_BATCH
input [
  {
    name: "WAV"
    data_type: TYPE_FP32
    dims: [-1]
  },
  {
    name: "WAV_LENS"
    data_type: TYPE_INT32
    dims: [1]
  }
]
output [
  {
    name: "TRANSCRIPTS"
    data_type: TYPE_STRING
    dims: [1]
  }
]
ensemble_scheduling {
 step [
   {
    model_name: "feature_extractor"
    model_version: -1
    input_map {
      key: "wav"
      value: "WAV"
    }
    input_map {
      key: "wav_lens"
      value: "WAV_LENS"
    }
    output_map {
      key: "speech"
      value: "SPEECH"
    }
   },
   {
      model_name: "lfr_cmvn_pe"
      model_version: -1
      input_map {
          key: "chunk_xs"
          value: "SPEECH"
      }
      output_map {
          key: "chunk_xs_out"
          value: "CHUNK_XS_OUT"
      }
      output_map {
          key: "chunk_xs_out_len"
          value: "CHUNK_XS_OUT_LEN"
      }
   },
   {
    model_name: "encoder"
    model_version: -1
    input_map {
      key: "speech"
      value: "CHUNK_XS_OUT"
    }
    input_map {
      key: "speech_lengths"
      value: "CHUNK_XS_OUT_LEN"
    }
    output_map {
      key: "enc"
      value: "ENC"
    }
    output_map {
      key: "enc_len"
      value: "ENC_LEN"
    }
    output_map {
      key: "alphas"
      value: "ALPHAS"
    }
  },
  {
    model_name: "cif_search"
    model_version: -1
    input_map {
      key: "enc"
      value: "ENC"
    }
    input_map {
      key: "enc_len"
      value: "ENC_LEN"
    }
    input_map {
      key: "alphas"
      value: "ALPHAS"
    }
    output_map {
      key: "transcripts"
      value: "TRANSCRIPTS"
      }
   }
 ]
}