From 3547adb4fb8b8284afefb8413382592fcdfa0302 Mon Sep 17 00:00:00 2001
From: Daniel <znsoft@163.com>
Date: 星期二, 07 三月 2023 12:18:06 +0800
Subject: [PATCH] Merge branch 'alibaba-damo-academy:main' into main

---
 funasr/version.txt                                                                           |    2 
 /dev/null                                                                                    |    0 
 funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/config.pbtxt |    4 
 funasr/runtime/triton_gpu/Dockerfile/Dockerfile.server                                       |    4 
 funasr/runtime/triton_gpu/README.md                                                          |   55 ++++-
 funasr/runtime/triton_gpu/client/decode_manifest_triton.py                                   |  541 +++++++++++++++++++++++++++++++++++++++++++++++++
 docs/images/damo.png                                                                         |    0 
 funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/1/model.py             |    7 
 funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/config.pbtxt           |    2 
 README.md                                                                                    |    4 
 funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py   |   20 +
 11 files changed, 610 insertions(+), 29 deletions(-)

diff --git a/README.md b/README.md
index b0454fb..aca9b8d 100644
--- a/README.md
+++ b/README.md
@@ -77,8 +77,8 @@
 
 ## Contributors
 
-| <div align="left"><img src="docs/images/DeepScience.png" width="250"/> |
-|:---:|
+| <div align="left"><img src="docs/images/damo.png" width="180"/> | <img src="docs/images/DeepScience.png" width="200"/> </div> |
+|:---------------------------------------------------------------:|:-----------------------------------------------------------:|
 
 ## Acknowledge
 
diff --git a/docs/images/damo.png b/docs/images/damo.png
new file mode 100644
index 0000000..f4f7a89
--- /dev/null
+++ b/docs/images/damo.png
Binary files differ
diff --git a/funasr/runtime/triton_gpu/Dockerfile/Dockerfile.server b/funasr/runtime/triton_gpu/Dockerfile/Dockerfile.server
index 459195c..d03610c 100644
--- a/funasr/runtime/triton_gpu/Dockerfile/Dockerfile.server
+++ b/funasr/runtime/triton_gpu/Dockerfile/Dockerfile.server
@@ -10,8 +10,10 @@
     cmake \
     libsndfile1
 
+# -i https://pypi.tuna.tsinghua.edu.cn/simple
+RUN pip3 install torch torchaudio 
 RUN pip3 install kaldifeat pyyaml
 
 # Dependency for client
-RUN pip3 install soundfile grpcio-tools tritonclient pyyaml
+RUN pip3 install soundfile grpcio-tools tritonclient
 WORKDIR /workspace
diff --git a/funasr/runtime/triton_gpu/README.md b/funasr/runtime/triton_gpu/README.md
index daceb4e..48e889c 100644
--- a/funasr/runtime/triton_gpu/README.md
+++ b/funasr/runtime/triton_gpu/README.md
@@ -1,16 +1,21 @@
 ## Inference with Triton 
 
 ### Steps:
-1. Refer here to [get model.onnx](https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/export/README.md)
-
-2. Follow below instructions to using triton
+1. Prepare model repo files
 ```sh
-# using docker image Dockerfile/Dockerfile.server
-docker build . -f Dockerfile/Dockerfile.server -t triton-paraformer:23.01 
-docker run -it --rm --name "paraformer_triton_server" --gpus all -v <path_host/funasr/runtime/>:/workspace --shm-size 1g --net host triton-paraformer:23.01 
-# inside the docker container, prepare previous exported model.onnx
-mv <path_model.onnx> /workspace/triton_gpu/model_repo_paraformer_large_offline/encoder/1/
+git-lfs install
+git clone https://www.modelscope.cn/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch.git
 
+pretrained_model_dir=$(pwd)/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch
+
+cp $pretrained_model_dir/am.mvn ./model_repo_paraformer_large_offline/feature_extractor/
+cp $pretrained_model_dir/config.yaml ./model_repo_paraformer_large_offline/feature_extractor/
+
+# Refer here to get model.onnx (https://github.com/alibaba-damo-academy/FunASR/blob/main/funasr/export/README.md)
+cp <exported_onnx_dir>/model.onnx ./model_repo_paraformer_large_offline/encoder/1/
+```
+Log of directory tree:
+```sh
 model_repo_paraformer_large_offline/
 |-- encoder
 |   |-- 1
@@ -20,6 +25,7 @@
 |   |-- 1
 |   |   `-- model.py
 |   |-- config.pbtxt
+|   |-- am.mvn
 |   `-- config.yaml
 |-- infer_pipeline
 |   |-- 1
@@ -27,13 +33,19 @@
 `-- scoring
     |-- 1
     |   `-- model.py
-    |-- config.pbtxt
-    `-- token_list.pkl
+    `-- config.pbtxt
 
 8 directories, 9 files
+```
+
+2. Follow below instructions to launch triton server
+```sh
+# using docker image Dockerfile/Dockerfile.server
+docker build . -f Dockerfile/Dockerfile.server -t triton-paraformer:23.01 
+docker run -it --rm --name "paraformer_triton_server" --gpus all -v <path_host/model_repo_paraformer_large_offline>:/workspace/ --shm-size 1g --net host triton-paraformer:23.01 
 
 # launch the service 
-tritonserver --model-repository ./model_repo_paraformer_large_offline \
+tritonserver --model-repository /workspace/model_repo_paraformer_large_offline \
              --pinned-memory-pool-byte-size=512000000 \
              --cuda-memory-pool-byte-size=0:1024000000
 
@@ -43,6 +55,27 @@
 
 Benchmark [speech_paraformer](https://www.modelscope.cn/models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/summary) based on Aishell1 test set with a single V100, the total audio duration is 36108.919 seconds.
 
+```sh
+# For client container:
+docker run -it --rm --name "client_test" --net host --gpus all -v <path_host/triton_gpu/client>:/workpace/ soar97/triton-k2:22.12.1 # noqa
+# For aishell manifests:
+apt-get install git-lfs
+git-lfs install
+git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
+sudo mkdir -p /root/fangjun/open-source/icefall-aishell/egs/aishell/ASR/download/aishell
+tar xf ./aishell-test-dev-manifests/data_aishell.tar.gz -C /root/fangjun/open-source/icefall-aishell/egs/aishell/ASR/download/aishell/ # noqa
+
+serveraddr=localhost
+manifest_path=/workspace/aishell-test-dev-manifests/data/fbank/aishell_cuts_test.jsonl.gz
+num_task=60
+python3 client/decode_manifest_triton.py \
+    --server-addr $serveraddr \
+    --compute-cer \
+    --model-name infer_pipeline \
+    --num-tasks $num_task \
+    --manifest-filename $manifest_path
+```
+
 (Note: The service has been fully warm up.)
 |concurrent-tasks | processing time(s) | RTF |
 |----------|--------------------|------------|
diff --git a/funasr/runtime/triton_gpu/client/decode_manifest_triton.py b/funasr/runtime/triton_gpu/client/decode_manifest_triton.py
new file mode 100644
index 0000000..3a8d57f
--- /dev/null
+++ b/funasr/runtime/triton_gpu/client/decode_manifest_triton.py
@@ -0,0 +1,541 @@
+#!/usr/bin/env python3
+# Copyright      2022  Xiaomi Corp.        (authors: Fangjun Kuang)
+#                2023  Nvidia              (authors: Yuekai Zhang)
+# See LICENSE for clarification regarding multiple authors
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""
+This script loads a manifest in lhotse format and sends it to the server
+for decoding, in parallel.
+
+Usage:
+# For offline wenet server
+./decode_manifest_triton.py \
+    --server-addr localhost \
+    --compute-cer \
+    --model-name attention_rescoring \
+    --num-tasks 300 \
+    --manifest-filename ./aishell-test-dev-manifests/data/fbank/aishell_cuts_test.jsonl.gz # noqa
+
+# For streaming wenet server
+./decode_manifest_triton.py \
+    --server-addr localhost \
+    --streaming \
+    --compute-cer \
+    --context 7 \
+    --model-name streaming_wenet \
+    --num-tasks 300 \
+    --manifest-filename ./aishell-test-dev-manifests/data/fbank/aishell_cuts_test.jsonl.gz # noqa
+
+# For simulate streaming mode wenet server
+./decode_manifest_triton.py \
+    --server-addr localhost \
+    --simulate-streaming \
+    --compute-cer \
+    --context 7 \
+    --model-name streaming_wenet \
+    --num-tasks 300 \
+    --manifest-filename ./aishell-test-dev-manifests/data/fbank/aishell_cuts_test.jsonl.gz # noqa
+
+# For test container:
+docker run -it --rm --name "wenet_client_test" --net host --gpus all soar97/triton-k2:22.12.1 # noqa
+
+# For aishell manifests:
+apt-get install git-lfs
+git-lfs install
+git clone https://huggingface.co/csukuangfj/aishell-test-dev-manifests
+sudo mkdir -p /root/fangjun/open-source/icefall-aishell/egs/aishell/ASR/download/aishell
+tar xf ./aishell-test-dev-manifests/data_aishell.tar.gz -C /root/fangjun/open-source/icefall-aishell/egs/aishell/ASR/download/aishell/ # noqa
+
+"""
+
+import argparse
+import asyncio
+import math
+import time
+import types
+from pathlib import Path
+import json
+import numpy as np
+import tritonclient
+import tritonclient.grpc.aio as grpcclient
+from lhotse import CutSet, load_manifest
+from tritonclient.utils import np_to_triton_dtype
+
+from icefall.utils import store_transcripts, write_error_stats
+
+DEFAULT_MANIFEST_FILENAME = "/mnt/samsung-t7/yuekai/aishell-test-dev-manifests/data/fbank/aishell_cuts_test.jsonl.gz"  # noqa
+
+
+def get_args():
+    parser = argparse.ArgumentParser(
+        formatter_class=argparse.ArgumentDefaultsHelpFormatter
+    )
+
+    parser.add_argument(
+        "--server-addr",
+        type=str,
+        default="localhost",
+        help="Address of the server",
+    )
+
+    parser.add_argument(
+        "--server-port",
+        type=int,
+        default=8001,
+        help="Port of the server",
+    )
+
+    parser.add_argument(
+        "--manifest-filename",
+        type=str,
+        default=DEFAULT_MANIFEST_FILENAME,
+        help="Path to the manifest for decoding",
+    )
+
+    parser.add_argument(
+        "--model-name",
+        type=str,
+        default="transducer",
+        help="triton model_repo module name to request",
+    )
+
+    parser.add_argument(
+        "--num-tasks",
+        type=int,
+        default=50,
+        help="Number of tasks to use for sending",
+    )
+
+    parser.add_argument(
+        "--log-interval",
+        type=int,
+        default=5,
+        help="Controls how frequently we print the log.",
+    )
+
+    parser.add_argument(
+        "--compute-cer",
+        action="store_true",
+        default=False,
+        help="""True to compute CER, e.g., for Chinese.
+        False to compute WER, e.g., for English words.
+        """,
+    )
+
+    parser.add_argument(
+        "--streaming",
+        action="store_true",
+        default=False,
+        help="""True for streaming ASR.
+        """,
+    )
+
+    parser.add_argument(
+        "--simulate-streaming",
+        action="store_true",
+        default=False,
+        help="""True for strictly simulate streaming ASR.
+        Threads will sleep to simulate the real speaking scene.
+        """,
+    )
+
+    parser.add_argument(
+        "--chunk_size",
+        type=int,
+        required=False,
+        default=16,
+        help="chunk size default is 16",
+    )
+
+    parser.add_argument(
+        "--context",
+        type=int,
+        required=False,
+        default=-1,
+        help="subsampling context for wenet",
+    )
+
+    parser.add_argument(
+        "--encoder_right_context",
+        type=int,
+        required=False,
+        default=2,
+        help="encoder right context",
+    )
+
+    parser.add_argument(
+        "--subsampling",
+        type=int,
+        required=False,
+        default=4,
+        help="subsampling rate",
+    )
+
+    parser.add_argument(
+        "--stats_file",
+        type=str,
+        required=False,
+        default="./stats.json",
+        help="output of stats anaylasis",
+    )
+
+    return parser.parse_args()
+
+
+async def send(
+    cuts: CutSet,
+    name: str,
+    triton_client: tritonclient.grpc.aio.InferenceServerClient,
+    protocol_client: types.ModuleType,
+    log_interval: int,
+    compute_cer: bool,
+    model_name: str,
+):
+    total_duration = 0.0
+    results = []
+
+    for i, c in enumerate(cuts):
+        if i % log_interval == 0:
+            print(f"{name}: {i}/{len(cuts)}")
+
+        waveform = c.load_audio().reshape(-1).astype(np.float32)
+        sample_rate = 16000
+
+        # padding to nearset 10 seconds
+        samples = np.zeros(
+            (
+                1,
+                10 * sample_rate * (int(len(waveform) / sample_rate // 10) + 1),
+            ),
+            dtype=np.float32,
+        )
+        samples[0, : len(waveform)] = waveform
+
+        lengths = np.array([[len(waveform)]], dtype=np.int32)
+
+        inputs = [
+            protocol_client.InferInput(
+                "WAV", samples.shape, np_to_triton_dtype(samples.dtype)
+            ),
+            protocol_client.InferInput(
+                "WAV_LENS", lengths.shape, np_to_triton_dtype(lengths.dtype)
+            ),
+        ]
+        inputs[0].set_data_from_numpy(samples)
+        inputs[1].set_data_from_numpy(lengths)
+        outputs = [protocol_client.InferRequestedOutput("TRANSCRIPTS")]
+        sequence_id = 10086 + i
+
+        response = await triton_client.infer(
+            model_name, inputs, request_id=str(sequence_id), outputs=outputs
+        )
+
+        decoding_results = response.as_numpy("TRANSCRIPTS")[0]
+        if type(decoding_results) == np.ndarray:
+            decoding_results = b" ".join(decoding_results).decode("utf-8")
+        else:
+            # For wenet
+            decoding_results = decoding_results.decode("utf-8")
+
+        total_duration += c.duration
+
+        if compute_cer:
+            ref = c.supervisions[0].text.split()
+            hyp = decoding_results.split()
+            ref = list("".join(ref))
+            hyp = list("".join(hyp))
+            results.append((c.id, ref, hyp))
+        else:
+            results.append(
+                (
+                    c.id,
+                    c.supervisions[0].text.split(),
+                    decoding_results.split(),
+                )
+            )  # noqa
+
+    return total_duration, results
+
+
+async def send_streaming(
+    cuts: CutSet,
+    name: str,
+    triton_client: tritonclient.grpc.aio.InferenceServerClient,
+    protocol_client: types.ModuleType,
+    log_interval: int,
+    compute_cer: bool,
+    model_name: str,
+    first_chunk_in_secs: float,
+    other_chunk_in_secs: float,
+    task_index: int,
+    simulate_mode: bool = False,
+):
+    total_duration = 0.0
+    results = []
+    latency_data = []
+
+    for i, c in enumerate(cuts):
+        if i % log_interval == 0:
+            print(f"{name}: {i}/{len(cuts)}")
+
+        waveform = c.load_audio().reshape(-1).astype(np.float32)
+        sample_rate = 16000
+
+        wav_segs = []
+
+        j = 0
+        while j < len(waveform):
+            if j == 0:
+                stride = int(first_chunk_in_secs * sample_rate)
+                wav_segs.append(waveform[j : j + stride])
+            else:
+                stride = int(other_chunk_in_secs * sample_rate)
+                wav_segs.append(waveform[j : j + stride])
+            j += len(wav_segs[-1])
+
+        sequence_id = task_index + 10086
+
+        for idx, seg in enumerate(wav_segs):
+            chunk_len = len(seg)
+
+            if simulate_mode:
+                await asyncio.sleep(chunk_len / sample_rate)
+
+            chunk_start = time.time()
+            if idx == 0:
+                chunk_samples = int(first_chunk_in_secs * sample_rate)
+                expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
+            else:
+                chunk_samples = int(other_chunk_in_secs * sample_rate)
+                expect_input = np.zeros((1, chunk_samples), dtype=np.float32)
+
+            expect_input[0][0:chunk_len] = seg
+            input0_data = expect_input
+            input1_data = np.array([[chunk_len]], dtype=np.int32)
+
+            inputs = [
+                protocol_client.InferInput(
+                    "WAV",
+                    input0_data.shape,
+                    np_to_triton_dtype(input0_data.dtype),
+                ),
+                protocol_client.InferInput(
+                    "WAV_LENS",
+                    input1_data.shape,
+                    np_to_triton_dtype(input1_data.dtype),
+                ),
+            ]
+
+            inputs[0].set_data_from_numpy(input0_data)
+            inputs[1].set_data_from_numpy(input1_data)
+
+            outputs = [protocol_client.InferRequestedOutput("TRANSCRIPTS")]
+            end = False
+            if idx == len(wav_segs) - 1:
+                end = True
+
+            response = await triton_client.infer(
+                model_name,
+                inputs,
+                outputs=outputs,
+                sequence_id=sequence_id,
+                sequence_start=idx == 0,
+                sequence_end=end,
+            )
+            idx += 1
+
+            decoding_results = response.as_numpy("TRANSCRIPTS")
+            if type(decoding_results) == np.ndarray:
+                decoding_results = b" ".join(decoding_results).decode("utf-8")
+            else:
+                # For wenet
+                decoding_results = response.as_numpy("TRANSCRIPTS")[0].decode(
+                    "utf-8"
+                )
+            chunk_end = time.time() - chunk_start
+            latency_data.append((chunk_end, chunk_len / sample_rate))
+
+        total_duration += c.duration
+
+        if compute_cer:
+            ref = c.supervisions[0].text.split()
+            hyp = decoding_results.split()
+            ref = list("".join(ref))
+            hyp = list("".join(hyp))
+            results.append((c.id, ref, hyp))
+        else:
+            results.append(
+                (
+                    c.id,
+                    c.supervisions[0].text.split(),
+                    decoding_results.split(),
+                )
+            )  # noqa
+
+    return total_duration, results, latency_data
+
+
+async def main():
+    args = get_args()
+    filename = args.manifest_filename
+    server_addr = args.server_addr
+    server_port = args.server_port
+    url = f"{server_addr}:{server_port}"
+    num_tasks = args.num_tasks
+    log_interval = args.log_interval
+    compute_cer = args.compute_cer
+
+    cuts = load_manifest(filename)
+    cuts_list = cuts.split(num_tasks)
+    tasks = []
+
+    triton_client = grpcclient.InferenceServerClient(url=url, verbose=False)
+    protocol_client = grpcclient
+
+    if args.streaming or args.simulate_streaming:
+        frame_shift_ms = 10
+        frame_length_ms = 25
+        add_frames = math.ceil(
+            (frame_length_ms - frame_shift_ms) / frame_shift_ms
+        )
+        # decode_window_length: input sequence length of streaming encoder
+        if args.context > 0:
+            # decode window length calculation for wenet
+            decode_window_length = (
+                args.chunk_size - 1
+            ) * args.subsampling + args.context
+        else:
+            # decode window length calculation for icefall
+            decode_window_length = (
+                args.chunk_size + 2 + args.encoder_right_context
+            ) * args.subsampling + 3
+
+        first_chunk_ms = (decode_window_length + add_frames) * frame_shift_ms
+
+    start_time = time.time()
+    for i in range(num_tasks):
+        if args.streaming:
+            assert not args.simulate_streaming
+            task = asyncio.create_task(
+                send_streaming(
+                    cuts=cuts_list[i],
+                    name=f"task-{i}",
+                    triton_client=triton_client,
+                    protocol_client=protocol_client,
+                    log_interval=log_interval,
+                    compute_cer=compute_cer,
+                    model_name=args.model_name,
+                    first_chunk_in_secs=first_chunk_ms / 1000,
+                    other_chunk_in_secs=args.chunk_size
+                    * args.subsampling
+                    * frame_shift_ms
+                    / 1000,
+                    task_index=i,
+                )
+            )
+        elif args.simulate_streaming:
+            task = asyncio.create_task(
+                send_streaming(
+                    cuts=cuts_list[i],
+                    name=f"task-{i}",
+                    triton_client=triton_client,
+                    protocol_client=protocol_client,
+                    log_interval=log_interval,
+                    compute_cer=compute_cer,
+                    model_name=args.model_name,
+                    first_chunk_in_secs=first_chunk_ms / 1000,
+                    other_chunk_in_secs=args.chunk_size
+                    * args.subsampling
+                    * frame_shift_ms
+                    / 1000,
+                    task_index=i,
+                    simulate_mode=True,
+                )
+            )
+        else:
+            task = asyncio.create_task(
+                send(
+                    cuts=cuts_list[i],
+                    name=f"task-{i}",
+                    triton_client=triton_client,
+                    protocol_client=protocol_client,
+                    log_interval=log_interval,
+                    compute_cer=compute_cer,
+                    model_name=args.model_name,
+                )
+            )
+        tasks.append(task)
+
+    ans_list = await asyncio.gather(*tasks)
+
+    end_time = time.time()
+    elapsed = end_time - start_time
+
+    results = []
+    total_duration = 0.0
+    latency_data = []
+    for ans in ans_list:
+        total_duration += ans[0]
+        results += ans[1]
+        if args.streaming or args.simulate_streaming:
+            latency_data += ans[2]
+
+    rtf = elapsed / total_duration
+
+    s = f"RTF: {rtf:.4f}\n"
+    s += f"total_duration: {total_duration:.3f} seconds\n"
+    s += f"({total_duration/3600:.2f} hours)\n"
+    s += (
+        f"processing time: {elapsed:.3f} seconds "
+        f"({elapsed/3600:.2f} hours)\n"
+    )
+
+    if args.streaming or args.simulate_streaming:
+        latency_list = [
+            chunk_end for (chunk_end, chunk_duration) in latency_data
+        ]
+        latency_ms = sum(latency_list) / float(len(latency_list)) * 1000.0
+        latency_variance = np.var(latency_list, dtype=np.float64) * 1000.0
+        s += f"latency_variance: {latency_variance:.2f}\n"
+        s += f"latency_50_percentile: {np.percentile(latency_list, 50) * 1000.0:.2f}\n"
+        s += f"latency_90_percentile: {np.percentile(latency_list, 90) * 1000.0:.2f}\n"
+        s += f"latency_99_percentile: {np.percentile(latency_list, 99) * 1000.0:.2f}\n"
+        s += f"average_latency_ms: {latency_ms:.2f}\n"
+
+    print(s)
+
+    with open("rtf.txt", "w") as f:
+        f.write(s)
+
+    name = Path(filename).stem.split(".")[0]
+    results = sorted(results)
+    store_transcripts(filename=f"recogs-{name}.txt", texts=results)
+
+    with open(f"errs-{name}.txt", "w") as f:
+        write_error_stats(f, "test-set", results, enable_log=True)
+
+    with open(f"errs-{name}.txt", "r") as f:
+        print(f.readline())  # WER
+        print(f.readline())  # Detailed errors
+
+    if args.stats_file:
+        stats = await triton_client.get_inference_statistics(
+            model_name="", as_json=True
+        )
+        with open(args.stats_file, "w") as f:
+            json.dump(stats, f)
+
+
+if __name__ == "__main__":
+    asyncio.run(main())
diff --git a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
index 6464964..2f84bb8 100644
--- a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
+++ b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/1/model.py
@@ -105,8 +105,8 @@
             frame_shift: int = 10,
             filter_length_min: int = -1,
             filter_length_max: float = -1,
-            lfr_m: int = 1,
-            lfr_n: int = 1,
+            lfr_m: int = 7,
+            lfr_n: int = 6,
             dither: float = 1.0
     ) -> None:
         # check_argument_types()
@@ -229,22 +229,24 @@
             if key == "config_path":
                 with open(str(value), 'rb') as f:
                     config = yaml.load(f, Loader=yaml.Loader)
+            if key == "cmvn_path":
+                cmvn_path = str(value)
 
         opts = kaldifeat.FbankOptions()
         opts.frame_opts.dither = 1.0 # TODO: 0.0 or 1.0
-        opts.frame_opts.window_type = config['WavFrontend']['frontend_conf']['window']
-        opts.mel_opts.num_bins = int(config['WavFrontend']['frontend_conf']['n_mels'])
-        opts.frame_opts.frame_shift_ms = float(config['WavFrontend']['frontend_conf']['frame_shift'])
-        opts.frame_opts.frame_length_ms = float(config['WavFrontend']['frontend_conf']['frame_length'])
-        opts.frame_opts.samp_freq = int(config['WavFrontend']['frontend_conf']['fs'])
+        opts.frame_opts.window_type = config['frontend_conf']['window']
+        opts.mel_opts.num_bins = int(config['frontend_conf']['n_mels'])
+        opts.frame_opts.frame_shift_ms = float(config['frontend_conf']['frame_shift'])
+        opts.frame_opts.frame_length_ms = float(config['frontend_conf']['frame_length'])
+        opts.frame_opts.samp_freq = int(config['frontend_conf']['fs'])
         opts.device = torch.device(self.device)
         self.opts = opts
         self.feature_extractor = Fbank(self.opts)
         self.feature_size = opts.mel_opts.num_bins
 
         self.frontend = WavFrontend(
-            cmvn_file=config['WavFrontend']['cmvn_file'],
-            **config['WavFrontend']['frontend_conf'])
+            cmvn_file=cmvn_path,
+            **config['frontend_conf'])
 
     def extract_feat(self,
                      waveform_list: List[np.ndarray]
diff --git a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/config.pbtxt b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/config.pbtxt
index 8b53183..44bfcd4 100644
--- a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/config.pbtxt
+++ b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/config.pbtxt
@@ -34,6 +34,10 @@
     value: { string_value: "16000"}
   },
   {
+    key: "cmvn_path"
+    value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/am.mvn"}
+  },
+  {
     key: "config_path"
     value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/config.yaml"}
   }
diff --git a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/config.yaml b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/config.yaml
deleted file mode 100644
index a4a66c3..0000000
--- a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/feature_extractor/config.yaml
+++ /dev/null
@@ -1,30 +0,0 @@
-TokenIDConverter:
-  token_path: resources/models/token_list.pkl
-  unk_symbol: <unk>
-
-CharTokenizer:
-  symbol_value:
-  space_symbol: <space>
-  remove_non_linguistic_symbols: false
-
-WavFrontend:
-  cmvn_file: /raid/dgxsa/yuekaiz/pull_requests/FunASR/funasr/runtime/python/onnxruntime/resources/models/am.mvn
-  frontend_conf:
-    fs: 16000
-    window: hamming
-    n_mels: 80
-    frame_length: 25
-    frame_shift: 10
-    lfr_m: 7
-    lfr_n: 6
-    filter_length_max: -.inf
-
-Model:
-  model_path: resources/models/model.onnx
-  use_cuda: false
-  CUDAExecutionProvider:
-      device_id: 0
-      arena_extend_strategy: kNextPowerOfTwo
-      cudnn_conv_algo_search: EXHAUSTIVE
-      do_copy_in_default_stream: true
-  batch_size: 3
\ No newline at end of file
diff --git a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/1/model.py b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/1/model.py
index dfbaa52..ef6278d 100644
--- a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/1/model.py
+++ b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/1/model.py
@@ -21,8 +21,7 @@
 
 import json
 import os
-
-import pickle
+import yaml
 
 class TritonPythonModel:
     """Your Python model must use the same class name. Every Python model
@@ -76,8 +75,8 @@
         load lang_char.txt
         """
         with open(str(vocab_file), 'rb') as f:
-            token_list = pickle.load(f)
-        return token_list
+            config = yaml.load(f, Loader=yaml.Loader)
+        return config['token_list']
 
     def execute(self, requests):
         """`execute` must be implemented in every Python model. `execute`
diff --git a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/config.pbtxt b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/config.pbtxt
index 6b43fe4..a63d1c5 100644
--- a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/config.pbtxt
+++ b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/config.pbtxt
@@ -23,7 +23,7 @@
   },
   {
     key: "vocabulary",
-    value: { string_value: "./model_repo_paraformer_large_offline/scoring/token_list.pkl"}
+    value: { string_value: "./model_repo_paraformer_large_offline/feature_extractor/config.yaml"}
   },
   {
     key: "lm_path"
diff --git a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/token_list.pkl b/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/token_list.pkl
deleted file mode 100644
index f1a2ce7..0000000
--- a/funasr/runtime/triton_gpu/model_repo_paraformer_large_offline/scoring/token_list.pkl
+++ /dev/null
Binary files differ
diff --git a/funasr/version.txt b/funasr/version.txt
index ee1372d..7179039 100644
--- a/funasr/version.txt
+++ b/funasr/version.txt
@@ -1 +1 @@
-0.2.2
+0.2.3

--
Gitblit v1.9.1