Yabin Li
2023-08-21 e0fa63765bfb4a36bde7047c2a6066ca5a80e90f
Dev hw (#878)

* merge from hw (#872)

* hotwords

* Contextual Paraformer onnx export

* update

* update

* quant inference

* add clas hotword support

* update websocket-server

* update websocket-server

* add catch for hotword

* update websocket-server

* update paraformer

* update websocket-server

* add wait for funasr-wss-client

* fix core by adding clean_thread

* fix wav_name

* update funasr-wss-client

* update websocket-server

* Update SDK_tutorial_online_zh.md

---------

Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>

* Update websocket_protocol_zh.md

* Update websocket_protocol.md

* Update SDK_tutorial_zh.md

* Update SDK_tutorial.md

* Update SDK_advanced_guide_online_zh.md

* Update SDK_advanced_guide_online.md

* Update SDK_advanced_guide_offline_zh.md

* Update SDK_advanced_guide_offline_zh.md

* Update SDK_advanced_guide_offline.md

* Update SDK_advanced_guide_offline.md

* Update docker_offline_cpu_zh_lists

* update docs

* update

---------

Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
46个文件已修改
7个文件已添加
2086 ■■■■■ 已修改文件
funasr/export/export_model.py 32 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/__init__.py 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/decoder/contextual_decoder.py 191 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/models/e2e_asr_contextual_paraformer.py 174 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/docs/SDK_advanced_guide_offline.md 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/docs/SDK_advanced_guide_offline_zh.md 17 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/docs/SDK_advanced_guide_online.md 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/docs/SDK_tutorial.md 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/docs/SDK_tutorial_zh.md 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/docs/docker_offline_cpu_zh_lists 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/docs/websocket_protocol.md 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/docs/websocket_protocol_zh.md 8 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp 37 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp 31 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-online-asr.cpp 4 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-online-rtf.cpp 6 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/com-define.h 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/funasrruntime.h 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/include/model.h 5 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/CMakeLists.txt 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/encode_converter.cpp 575 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/encode_converter.h 109 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/funasrruntime.cpp 36 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/offline-stream.cpp 20 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer-online.cpp 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer-online.h 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer.cpp 183 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/paraformer.h 15 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/precomp.h 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/seg_dict.cpp 53 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/seg_dict.h 26 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/util.cpp 29 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/util.h 6 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/vocab.cpp 10 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/onnxruntime/src/vocab.h 3 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py 11 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py 2 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py 148 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py 47 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/CMakeLists.txt 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/funasr-wss-client.cpp 51 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/funasr-wss-server-2pass.cpp 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/funasr-wss-server.cpp 14 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/readme.md 1 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/websocket-server-2pass.cpp 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/websocket-server.cpp 143 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/websocket/websocket-server.h 13 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/export/export_model.py
@@ -1,14 +1,11 @@
import json
from typing import Union, Dict
from pathlib import Path
import os
import logging
import torch
from funasr.export.models import get_model
import numpy as np
import random
import logging
import numpy as np
from pathlib import Path
from typing import Union, Dict, List
from funasr.export.models import get_model
from funasr.utils.types import str2bool, str2triple_str
# torch_version = float(".".join(torch.__version__.split(".")[:2]))
# assert torch_version > 1.9
@@ -55,20 +52,25 @@
        # export encoder1
        self.export_config["model_name"] = "model"
        models = get_model(
        model = get_model(
            model,
            self.export_config,
        )
        if not isinstance(models, tuple):
            models = (models,)
        for i, model in enumerate(models):
        if isinstance(model, List):
            for m in model:
                m.eval()
                if self.onnx:
                    self._export_onnx(m, verbose, export_dir)
                else:
                    self._export_torchscripts(m, verbose, export_dir)
                print("output dir: {}".format(export_dir))
        else:
            model.eval()
            # self._export_onnx(model, verbose, export_dir)
            if self.onnx:
                self._export_onnx(model, verbose, export_dir)
            else:
                self._export_torchscripts(model, verbose, export_dir)
            print("output dir: {}".format(export_dir))
@@ -233,7 +235,7 @@
        # model_script = torch.jit.script(model)
        model_script = model #torch.jit.trace(model)
        model_path = os.path.join(path, f'{model.model_name}.onnx')
        if not os.path.exists(model_path):
        # if not os.path.exists(model_path):
            torch.onnx.export(
                model_script,
                dummy_input,
funasr/export/models/__init__.py
@@ -12,9 +12,17 @@
from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_encoder_predictor as ParaformerOnline_encoder_predictor_export
from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_decoder as ParaformerOnline_decoder_export
from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_backbone as ContextualParaformer_backbone_export
from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_embedder as ContextualParaformer_embedder_export
from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer
def get_model(model, export_config=None):
    if isinstance(model, BiCifParaformer):
    if isinstance(model, NeatContextualParaformer):
        backbone = ContextualParaformer_backbone_export(model, **export_config)
        embedder = ContextualParaformer_embedder_export(model, **export_config)
        return [embedder, backbone]
    elif isinstance(model, BiCifParaformer):
        return BiCifParaformer_export(model, **export_config)
    elif isinstance(model, ParaformerOnline):
        return (ParaformerOnline_encoder_predictor_export(model, model_name="model"),
funasr/export/models/decoder/contextual_decoder.py
New file
@@ -0,0 +1,191 @@
import os
import torch
import torch.nn as nn
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.modules.attention import MultiHeadedAttentionSANMDecoder
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export
from funasr.modules.attention import MultiHeadedAttentionCrossAtt
from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM
from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export
from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export
class ContextualSANMDecoder(nn.Module):
    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,):
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANM_export(d)
        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANM_export(d)
        for i, d in enumerate(self.model.decoders3):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            self.model.decoders3[i] = DecoderLayerSANM_export(d)
        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name
        # bias decoder
        if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
            self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.bias_decoder.src_attn)
        self.bias_decoder = self.model.bias_decoder
        # last decoder
        if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
            self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.last_decoder.src_attn)
        if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
            self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoder_export(self.model.last_decoder.self_attn)
        if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
            self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANM_export(self.model.last_decoder.feed_forward)
        self.last_decoder = self.model.last_decoder
        self.bias_output = self.model.bias_output
        self.dropout = self.model.dropout
    def prepare_mask(self, mask):
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        mask_4d_bhlt = mask_4d_bhlt * -10000.0
        return mask_3d_btd, mask_4d_bhlt
    def forward(
        self,
        hs_pad: torch.Tensor,
        hlens: torch.Tensor,
        ys_in_pad: torch.Tensor,
        ys_in_lens: torch.Tensor,
        bias_embed: torch.Tensor,
    ):
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )
        _, _, x_self_attn, x_src_attn = self.last_decoder(
            x, tgt_mask, memory, memory_mask
        )
        # contextual paraformer related
        contextual_length = torch.Tensor([bias_embed.shape[1]]).int().repeat(hs_pad.shape[0])
        # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
        contextual_mask = self.make_pad_mask(contextual_length)
        contextual_mask, _ = self.prepare_mask(contextual_mask)
        # import pdb; pdb.set_trace()
        contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask)
        if self.bias_output is not None:
            x = torch.cat([x_src_attn, cx], dim=2)
            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + self.dropout(x)
        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)
        return x, ys_in_lens
    def get_dummy_inputs(self, enc_size):
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)
    def is_optimizable(self):
        return True
    def get_input_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
               + ['cache_%d' % i for i in range(cache_num)]
    def get_output_names(self):
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        return ['y'] \
               + ['out_cache_%d' % i for i in range(cache_num)]
    def get_dynamic_axes(self):
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = len(self.model.decoders) + len(self.model.decoders2)
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret
    def get_model_config(self, path):
        return {
            "dec_type": "XformerDecoder",
            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
            "n_layers": len(self.model.decoders) + len(self.model.decoders2),
            "odim": self.model.decoders[0].size
        }
funasr/export/models/e2e_asr_contextual_paraformer.py
New file
@@ -0,0 +1,174 @@
from audioop import bias
import logging
import torch
import torch.nn as nn
import numpy as np
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
from funasr.models.predictor.cif import CifPredictorV2
from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export
from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
from funasr.export.models.decoder.contextual_decoder import ContextualSANMDecoder as ContextualSANMDecoder_export
from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder
class ContextualParaformer_backbone(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
    def __init__(
            self,
            model,
            max_seq_len=512,
            feats_dim=560,
            model_name='model',
            **kwargs,
    ):
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]
        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        elif isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        if isinstance(model.predictor, CifPredictorV2):
            self.predictor = CifPredictorV2_export(model.predictor)
        # decoder
        if isinstance(model.decoder, ContextualParaformerDecoder):
            self.decoder = ContextualSANMDecoder_export(model.decoder, onnx=onnx)
        elif isinstance(model.decoder, ParaformerSANMDecoder):
            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
        elif isinstance(model.decoder, ParaformerDecoderSAN):
            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
        self.feats_dim = feats_dim
        self.model_name = model_name
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            bias_embed: torch.Tensor,
    ):
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)
        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.floor().type(torch.int32)
        # bias_embed = bias_embed. squeeze(0).repeat([enc.shape[0], 1, 1])
        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, bias_embed)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)
        return decoder_out, pre_token_length
    def get_dummy_inputs(self):
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        bias_embed = torch.randn(2, 1, 512)
        return (speech, speech_lengths, bias_embed)
    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
        import numpy as np
        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        return (speech, speech_lengths)
    def get_input_names(self):
        return ['speech', 'speech_lengths', 'bias_embed']
    def get_output_names(self):
        return ['logits', 'token_num']
    def get_dynamic_axes(self):
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'bias_embed': {
                0: 'batch_size',
                1: 'num_hotwords'
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
class ContextualParaformer_embedder(nn.Module):
    def __init__(self,
                 model,
                 max_seq_len=512,
                 feats_dim=560,
                 model_name='model',
                 **kwargs,):
        super().__init__()
        self.embedding = model.bias_embed
        model.bias_encoder.batch_first = False
        self.bias_encoder = model.bias_encoder
        # self.bias_encoder.batch_first = False
        self.feats_dim = feats_dim
        self.model_name = "{}_eb".format(model_name)
    def forward(self, hotword):
        hotword = self.embedding(hotword).transpose(0, 1) # batch second
        hw_embed, (_, _) = self.bias_encoder(hotword)
        return hw_embed
    def get_dummy_inputs(self):
        hotword = torch.tensor([
                                [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
                                [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
                                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                                [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
                                [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
                                [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
                               ],
                                dtype=torch.int32)
        # hotword_length = torch.tensor([10, 2, 1], dtype=torch.int32)
        return (hotword)
    def get_input_names(self):
        return ['hotword']
    def get_output_names(self):
        return ['hw_embed']
    def get_dynamic_axes(self):
        return {
            'hotword': {
                0: 'num_hotwords',
            },
            'hw_embed': {
                0: 'num_hotwords',
            },
        }
funasr/runtime/docs/SDK_advanced_guide_offline.md
@@ -59,6 +59,11 @@
  --vad-dir damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
  --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx  \
  --punc-dir damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx > log.out 2>&1 &
# If you want to close ssl,please add:--certfile 0
# If you want to deploy the timestamp or hotword model, please set --model-dir to the corresponding model:
# speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx(timestamp)
# damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404-onnx(hotword)
```
More details about the script run_server.sh:
@@ -92,8 +97,8 @@
--port: Port number that the server listens on. Default is 10095.
--decoder-thread-num: Number of inference threads that the server starts. Default is 8.
--io-thread-num: Number of IO threads that the server starts. Default is 1.
--certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt.
--keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key.
--certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. If you want to close ssl,set ""
--keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. If you want to close ssl,set ""
```
The FunASR-wss-server also supports loading models from a local path (see Preparing Model Resources for detailed instructions on preparing local model resources). Here is an example:
@@ -183,6 +188,7 @@
--output_dir: the path to the recognition result output.
--ssl: whether to use SSL encryption. The default is to use SSL.
--mode: offline mode.
--hotword If am is hotword model, setting hotword: *.txt(one hotword perline) or hotwords seperate by space (could be: 阿里巴巴 达摩院)
```
### c++-client
@@ -199,6 +205,7 @@
--output_dir: the path to the recognition result output.
--ssl: whether to use SSL encryption. The default is to use SSL.
--mode: offline mode.
--hotword If am is hotword model, setting hotword: *.txt(one hotword perline) or hotwords seperate by space (could be: 阿里巴巴 达摩院)
```
### Custom client
@@ -207,7 +214,7 @@
```text
# First communication
{"mode": "offline", "wav_name": wav_name, "is_speaking": True}
{"mode": "offline", "wav_name": wav_name, "is_speaking": True, "hotwords": "hotword1|hotword2"}
# Send wav data
Bytes data
# Send end flag
funasr/runtime/docs/SDK_advanced_guide_offline_zh.md
@@ -28,6 +28,10 @@
  --punc-dir damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx > log.out 2>&1 &
# 如果您想关闭ssl,增加参数:--certfile 0
# 如果您想使用时间戳或者热词模型进行部署,请设置--model-dir为对应模型:
# damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx(时间戳)
# 或者 damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404-onnx(热词)
```
服务端详细参数介绍可参考[服务端参数介绍](#服务端参数介绍)
### 客户端测试与使用
@@ -92,7 +96,9 @@
--port 10095 部署端口号
--mode offline表示离线文件转写
--audio_in 需要进行转写的音频文件,支持文件路径,文件列表wav.scp
--output_dir 识别结果保存路径
--thread_num 设置并发发送线程数,默认为1
--ssl 设置是否开启ssl证书校验,默认1开启,设置为0关闭
--hotword 如果模型为热词模型,可以设置热词: *.txt(每行一个热词) 或者空格分隔的热词字符串 (could be: 阿里巴巴 达摩院)
```
### cpp-client
@@ -107,6 +113,7 @@
--server-ip 为FunASR runtime-SDK服务部署机器ip,默认为本机ip(127.0.0.1),如果client与服务不在同一台服务器,需要改为部署机器ip
--port 10095 部署端口号
--wav-path 需要进行转写的音频文件,支持文件路径
--hotword 如果模型为热词模型,可以设置热词: *.txt(每行一个热词) 或者空格分隔的热词字符串 (could be: 阿里巴巴 达摩院)
```
### Html网页版
@@ -152,8 +159,8 @@
--port  服务端监听的端口号,默认为 10095
--decoder-thread-num  服务端启动的推理线程数,默认为 8
--io-thread-num  服务端启动的IO线程数,默认为 1
--certfile  ssl的证书文件,默认为:../../../ssl_key/server.crt
--keyfile   ssl的密钥文件,默认为:../../../ssl_key/server.key
--certfile  ssl的证书文件,默认为:../../../ssl_key/server.crt,如果需要关闭ssl,参数设置为”“
--keyfile   ssl的密钥文件,默认为:../../../ssl_key/server.key,如果需要关闭ssl,参数设置为”“
```
funasr-wss-server同时也支持从本地路径加载模型(本地模型资源准备详见[模型资源准备](#模型资源准备))示例如下:
@@ -180,8 +187,8 @@
--port  服务端监听的端口号,默认为 10095
--decoder-thread-num  服务端启动的推理线程数,默认为 8
--io-thread-num  服务端启动的IO线程数,默认为 1
--certfile ssl的证书文件,默认为:../../../ssl_key/server.crt
--keyfile  ssl的密钥文件,默认为:../../../ssl_key/server.key
--certfile ssl的证书文件,默认为:../../../ssl_key/server.crt,如果需要关闭ssl,参数设置为”“
--keyfile  ssl的密钥文件,默认为:../../../ssl_key/server.key,如果需要关闭ssl,参数设置为”“
```
## 模型资源准备
funasr/runtime/docs/SDK_advanced_guide_online.md
@@ -25,6 +25,8 @@
  --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx  \
  --online-model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx  \
  --punc-dir damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx > log.out 2>&1 &
# If you want to close ssl,please add:--certfile 0
```
For a more detailed description of server parameters, please refer to [Server Introduction]()
### Client Testing and Usage
funasr/runtime/docs/SDK_tutorial.md
@@ -33,6 +33,7 @@
```shell
sudo bash funasr-runtime-deploy-offline-cpu-zh.sh install --workspace /root/funasr-runtime-resources
```
Note: If you need to deploy the timestamp model or hotword model, select the corresponding model in step 2 of the installation and deployment process, where 1 is the paraformer-large model, 2 is the paraformer-large timestamp model, and 3 is the paraformer-large hotword model.
### Client Testing and Usage
@@ -69,6 +70,7 @@
--audio_in is the audio file that needs to be transcribed, supporting file paths and file list wav.scp
--thread_num sets the number of concurrent sending threads, default is 1
--ssl sets whether to enable SSL certificate verification, default is 1 to enable, and 0 to disable
--hotword If am is hotword model, setting hotword: *.txt(one hotword perline) or hotwords seperate by space (could be: 阿里巴巴 达摩院)
```
### cpp-client
@@ -85,6 +87,7 @@
--wav-path specifies the audio file to be transcribed, and supports file paths.
--thread_num sets the number of concurrent send threads, with a default value of 1.
--ssl sets whether to enable SSL certificate verification, with a default value of 1 for enabling and 0 for disabling.
--hotword If am is hotword model, setting hotword: *.txt(one hotword perline) or hotwords seperate by space (could be: 阿里巴巴 达摩院)
```
### html-client
funasr/runtime/docs/SDK_tutorial_zh.md
@@ -34,6 +34,7 @@
```shell
sudo bash funasr-runtime-deploy-offline-cpu-zh.sh install --workspace ./funasr-runtime-resources
```
注:如果需要部署时间戳模型或者热词模型,在安装部署步骤2时选择对应模型,其中1为paraformer-large模型,2为paraformer-large 时间戳模型,3为paraformer-large 热词模型
### 客户端测试与使用
@@ -71,6 +72,7 @@
--audio_in 需要进行转写的音频文件,支持文件路径,文件列表wav.scp
--thread_num 设置并发发送线程数,默认为1
--ssl 设置是否开启ssl证书校验,默认1开启,设置为0关闭
--hotword 如果模型为热词模型,可以设置热词: *.txt(每行一个热词) 或者空格分隔的热词字符串 (could be: 阿里巴巴 达摩院)
```
### cpp-client
@@ -87,6 +89,7 @@
--wav-path 需要进行转写的音频文件,支持文件路径
--thread_num 设置并发发送线程数,默认为1
--ssl 设置是否开启ssl证书校验,默认1开启,设置为0关闭
--hotword 如果模型为热词模型,可以设置热词: *.txt(每行一个热词) 或者空格分隔的热词字符串 (could be: 阿里巴巴 达摩院)
```
### html-client
funasr/runtime/docs/docker_offline_cpu_zh_lists
@@ -1,4 +1,5 @@
DOCKER:
  funasr-runtime-sdk-cpu-0.2.0
  funasr-runtime-sdk-cpu-0.1.0
DEFAULT_ASR_MODEL:
  damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx
funasr/runtime/docs/websocket_protocol.md
@@ -10,7 +10,7 @@
#### Initial Communication
The message (which needs to be serialized in JSON) is:
```text
{"mode": "offline", "wav_name": "wav_name", "is_speaking": True,"wav_format":"pcm"}
{"mode": "offline", "wav_name": "wav_name","wav_format":"pcm","is_speaking": True,"wav_format":"pcm","hotwords":"阿里巴巴 达摩院 阿里云"}
```
Parameter explanation:
```text
@@ -19,6 +19,7 @@
`wav_format`: the audio and video file extension, such as pcm, mp3, mp4, etc.
`is_speaking`: False indicates the end of a sentence, such as a VAD segmentation point or the end of a WAV file
`audio_fs`: when the input audio is in PCM format, the audio sampling rate parameter needs to be added
`hotwords`:If AM is the hotword model, hotword data needs to be sent to the server in string format, with " " used as a separator between hotwords. For example:"阿里巴巴 达摩院 阿里云"
```
#### Sending Audio Data
@@ -34,7 +35,7 @@
#### Sending Recognition Results
The message (serialized in JSON) is:
```text
{"mode": "offline", "wav_name": "wav_name", "text": "asr ouputs", "is_final": True}
{"mode": "offline", "wav_name": "wav_name", "text": "asr ouputs", "is_final": True, "timestamp":"[[100,200], [200,500]]"}
```
Parameter explanation:
```text
@@ -42,6 +43,7 @@
`wav_name`: the name of the audio file to be transcribed
`text`: the text output of speech recognition
`is_final`: indicating the end of recognition
`timestamp`:If AM is a timestamp model, it will return this field, indicating the timestamp, in the format of "[[100,200], [200,500]]"
```
## Real-time Speech Recognition
@@ -56,7 +58,7 @@
#### Initial Communication
The message (which needs to be serialized in JSON) is:
```text
{"mode": "2pass", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5]
{"mode": "2pass", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5]}
```
Parameter explanation:
```text
funasr/runtime/docs/websocket_protocol_zh.md
@@ -10,7 +10,7 @@
#### 首次通信
message为(需要用json序列化):
```text
{"mode": "offline", "wav_name": "wav_name", "is_speaking": True,"wav_format":"pcm"}
{"mode": "offline", "wav_name": "wav_name","wav_format":"pcm","is_speaking": True,"wav_format":"pcm","hotwords":"阿里巴巴 达摩院 阿里云"}
```
参数介绍:
```text
@@ -19,6 +19,7 @@
`wav_format`:表示音视频文件后缀名,可选pcm、mp3、mp4等
`is_speaking`:False 表示断句尾点,例如,vad切割点,或者一条wav结束
`audio_fs`:当输入音频为pcm数据是,需要加上音频采样率参数
`hotwords`:如果AM为热词模型,需要向服务端发送热词数据,格式为字符串,热词之间用" "分隔,例如 "阿里巴巴 达摩院 阿里云"
```
#### 发送音频数据
@@ -34,7 +35,7 @@
#### 发送识别结果
message为(采用json序列化)
```text
{"mode": "offline", "wav_name": "wav_name", "text": "asr ouputs", "is_final": True}
{"mode": "offline", "wav_name": "wav_name", "text": "asr ouputs", "is_final": True,"timestamp":"[[100,200], [200,500]]"}
```
参数介绍:
```text
@@ -42,6 +43,7 @@
`wav_name`:表示需要推理音频文件名
`text`:表示语音识别输出文本
`is_final`:表示识别结束
`timestamp`:如果AM为时间戳模型,会返回此字段,表示时间戳,格式为 "[[100,200], [200,500]]"(ms)
```
## 实时语音识别
@@ -56,7 +58,7 @@
#### 首次通信
message为(需要用json序列化):
```text
{"mode": "2pass", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5]
{"mode": "2pass", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5]}
```
参数介绍:
```text
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp
@@ -194,12 +194,12 @@
    TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
    TCLAP::ValueArg<std::string>    offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains model.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
    TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
    TCLAP::ValueArg<std::int32_t>   onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
    TCLAP::ValueArg<std::int32_t>   thread_num_("", THREAD_NUM, "multi-thread num for rtf", false, 1, "int32_t");
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp
@@ -43,12 +43,12 @@
    TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
    TCLAP::ValueArg<std::string>    offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains model.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
    TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
    TCLAP::ValueArg<std::int32_t>   onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-punc.cpp
@@ -35,7 +35,7 @@
    TCLAP::CmdLine cmd("funasr-onnx-offline-punc", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string> txt_path("", TXT_PATH, "txt file path, one sentence per line", true, "", "string");
    cmd.add(model_dir);
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -29,17 +29,18 @@
std::mutex mtx;
void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wav_ids,
            float* total_length, long* total_time, int core_id) {
            float* total_length, long* total_time, int core_id, string hotwords) {
    
    struct timeval start, end;
    long seconds = 0;
    float n_total_length = 0.0f;
    long n_total_time = 0;
    std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(asr_handle, hotwords);
    
    // warm up
    for (size_t i = 0; i < 1; i++)
    {
        FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, 16000);
        FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000);
        if(result){
            FunASRFreeResult(result);
        }
@@ -53,7 +54,7 @@
        }
        gettimeofday(&start, NULL);
        FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, 16000);
        FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000);
        gettimeofday(&end, NULL);
        seconds = (end.tv_sec - start.tv_sec);
@@ -107,14 +108,15 @@
    TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
    TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
    TCLAP::ValueArg<std::string> hotword("", HOTWORD, "*.txt(one hotword perline) or hotwords seperate by | (could be: 阿里巴巴 达摩院)", false, "", "string");
    cmd.add(model_dir);
    cmd.add(quantize);
@@ -124,6 +126,7 @@
    cmd.add(punc_quant);
    cmd.add(wav_path);
    cmd.add(thread_num);
    cmd.add(hotword);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
@@ -149,6 +152,26 @@
    long seconds = (end.tv_sec - start.tv_sec);
    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
    // read hotwords
    std::string hotword_ = hotword.getValue();
    std::string hotwords_;
    if(is_target_file(hotword_, "txt")){
        ifstream in(hotword_);
        if (!in.is_open()) {
            LOG(ERROR) << "Failed to open file: " << model_path.at(HOTWORD) ;
            return 0;
        }
        string line;
        while(getline(in, line))
        {
            hotwords_ +=line+HOTWORD_SEP;
        }
        in.close();
    }else{
        hotwords_ = hotword_;
    }
    // read wav_path
    vector<string> wav_list;
@@ -188,7 +211,7 @@
    int rtf_threds = thread_num.getValue();
    for (int i = 0; i < rtf_threds; i++)
    {
        threads.emplace_back(thread(runReg, asr_handle, wav_list, wav_ids, &total_length, &total_time, i));
        threads.emplace_back(thread(runReg, asr_handle, wav_list, wav_ids, &total_length, &total_time, i, hotwords_));
    }
    for (auto& thread : threads)
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-vad.cpp
@@ -65,7 +65,7 @@
    TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -44,13 +44,14 @@
    TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
    TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
    TCLAP::ValueArg<std::string> hotword("", HOTWORD, "*.txt(one hotword perline) or hotwords seperate by space (could be: 阿里巴巴 达摩院)", false, "", "string");
    cmd.add(model_dir);
    cmd.add(quantize);
@@ -59,6 +60,7 @@
    cmd.add(punc_dir);
    cmd.add(punc_quant);
    cmd.add(wav_path);
    cmd.add(hotword);
    cmd.parse(argc, argv);
    std::map<std::string, std::string> model_path;
@@ -85,6 +87,26 @@
    long seconds = (end.tv_sec - start.tv_sec);
    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
    // read hotwords
    std::string hotword_ = hotword.getValue();
    std::string hotwords_;
    if(is_target_file(hotword_, "txt")){
        ifstream in(hotword_);
        if (!in.is_open()) {
            LOG(ERROR) << "Failed to open file: " << model_path.at(HOTWORD) ;
            return 0;
        }
        string line;
        while(getline(in, line))
        {
            hotwords_ +=line+HOTWORD_SEP;
        }
        in.close();
    }else{
        hotwords_ = hotword_;
    }
    // read wav_path
    vector<string> wav_list;
@@ -115,11 +137,12 @@
    
    float snippet_time = 0.0f;
    long taking_micros = 0;
    std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(asr_hanlde, hotwords_);
    for (int i = 0; i < wav_list.size(); i++) {
        auto& wav_file = wav_list[i];
        auto& wav_id = wav_ids[i];
        gettimeofday(&start, NULL);
        FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, 16000);
        FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, hotwords_embedding, 16000);
        gettimeofday(&end, NULL);
        seconds = (end.tv_sec - start.tv_sec);
        taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
funasr/runtime/onnxruntime/bin/funasr-onnx-online-asr.cpp
@@ -45,8 +45,8 @@
    FLAGS_logtostderr = true;
    TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the asr online model path, which contains model.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
funasr/runtime/onnxruntime/bin/funasr-onnx-online-punc.cpp
@@ -55,7 +55,7 @@
    TCLAP::CmdLine cmd("funasr-onnx-online-punc", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string> txt_path("", TXT_PATH, "txt file path, one sentence per line", true, "", "string");
    cmd.add(model_dir);
funasr/runtime/onnxruntime/bin/funasr-onnx-online-rtf.cpp
@@ -179,11 +179,11 @@
    TCLAP::CmdLine cmd("funasr-onnx-online-rtf", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
    TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
@@ -72,7 +72,7 @@
    TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0");
    TCLAP::ValueArg<std::string>    model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
    TCLAP::ValueArg<std::string>    quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string");
    TCLAP::ValueArg<std::string>    wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
funasr/runtime/onnxruntime/include/com-define.h
@@ -27,6 +27,7 @@
#define TXT_PATH "txt-path"
#define THREAD_NUM "thread-num"
#define PORT_ID "port-id"
#define HOTWORD_SEP " "
// #define VAD_MODEL_PATH "vad-model"
// #define VAD_CMVN_PATH "vad-cmvn"
@@ -38,12 +39,16 @@
// #define PUNC_CONFIG_PATH "punc-config"
#define MODEL_NAME "model.onnx"
// hotword embedding compile model
#define MODEL_EB_NAME "model_eb.onnx"
#define QUANT_MODEL_NAME "model_quant.onnx"
#define VAD_CMVN_NAME "vad.mvn"
#define VAD_CONFIG_NAME "vad.yaml"
#define AM_CMVN_NAME "am.mvn"
#define AM_CONFIG_NAME "config.yaml"
#define PUNC_CONFIG_NAME "punc.yaml"
#define MODEL_SEG_DICT "seg_dict"
#define HOTWORD "hotword"
#define ENCODER_NAME "model.onnx"
#define QUANT_ENCODER_NAME "model_quant.onnx"
funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -98,9 +98,10 @@
//OfflineStream
_FUNASRAPI FUNASR_HANDLE      FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
// buffer
_FUNASRAPI FUNASR_RESULT    FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000, std::string wav_format="pcm");
_FUNASRAPI FUNASR_RESULT    FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, int sampling_rate=16000, std::string wav_format="pcm");
// file, support wav & pcm
_FUNASRAPI FUNASR_RESULT    FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
_FUNASRAPI FUNASR_RESULT    FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, int sampling_rate=16000);
_FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords);
_FUNASRAPI void                FunOfflineUninit(FUNASR_HANDLE handle);
//2passStream
funasr/runtime/onnxruntime/include/model.h
@@ -13,8 +13,11 @@
    virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
    virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
    virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
    virtual std::string Forward(float *din, int len, bool input_finished){return "";};
    virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}){return "";};
    virtual std::string Rescoring() = 0;
    virtual void InitHwCompiler(const std::string &hw_model, int thread_num){};
    virtual void InitSegDict(const std::string &seg_dict_model){};
    virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){};
};
Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
funasr/runtime/onnxruntime/src/CMakeLists.txt
@@ -2,6 +2,8 @@
file(GLOB files1 "*.cpp")
set(files ${files1})
message("files: "${files})
add_library(funasr SHARED ${files})
if(WIN32)
@@ -20,5 +22,6 @@
    include_directories(${FFMPEG_DIR}/include)
endif()
#message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
include_directories(${CMAKE_SOURCE_DIR}/include)
target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS})
funasr/runtime/onnxruntime/src/encode_converter.cpp
New file
@@ -0,0 +1,575 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#include "encode_converter.h"
#include <assert.h>
namespace funasr {
using namespace std;
U16CHAR_T UTF16[8];
U8CHAR_T UTF8[8];
size_t MyUtf8ToUtf16(const U8CHAR_T* pu8, size_t ilen, U16CHAR_T* pu16);
size_t MyUtf16ToUtf8(const U16CHAR_T* pu16, U8CHAR_T* pu8);
void EncodeConverter::SwapEndian(U16CHAR_T* pbuf, size_t len)
{
  for (size_t i = 0; i < len; i++) {
    pbuf[i] = ((pbuf[i] >> 8) | (pbuf[i] << 8));
  }
}
size_t MyUtf16ToUtf8(const U16CHAR_T* pu16, U8CHAR_T* pu8)
{
  size_t n = 0;
  if (pu16[0] <= 0x007F)
  {
    pu8[0] = (pu16[0] & 0x7F);
    n = 1;
  }
  else if (pu16[0] >= 0x0080 &&  pu16[0] <= 0x07FF)
  {
    pu8[1] = (0x80 | (pu16[0] & 0x003F));
    pu8[0] = (0xC0 | ((pu16[0] >> 6) & 0x001F));
    n = 2;
  }
  else if (pu16[0] >= 0x0800)
  {
    pu8[2] = (0x80 | (pu16[0] & 0x003F));
    pu8[1] = (0x80 | ((pu16[0] >> 6) & 0x003F));
    pu8[0] = (0xE0 | ((pu16[0] >> 12) & 0x000F));
    n = 3;
  }
  return n;
}
#define is2ByteUtf16(u16) ( (u16) >= 0x0080 && (u16) <= 0x07FF )
#define is3ByteUtf16(u16) ( (u16) >= 0x0800 )
size_t EncodeConverter::Utf16ToUtf8(const U16CHAR_T* pu16, U8CHAR_T* pu8)
{
  size_t n = 0;
  if (pu16[0] <= 0x007F)
  {
    pu8[0] = (pu16[0] & 0x7F);
    n = 1;
  }
  else if (pu16[0] >= 0x0080 &&  pu16[0] <= 0x07FF)
  {
    pu8[1] = (0x80 | (pu16[0] & 0x003F));
    pu8[0] = (0xC0 | ((pu16[0] >> 6) & 0x001F));
    n = 2;
  }
  else if (pu16[0] >= 0x0800)
  {
    pu8[2] = (0x80 | (pu16[0] & 0x003F));
    pu8[1] = (0x80 | ((pu16[0] >> 6) & 0x003F));
    pu8[0] = (0xE0 | ((pu16[0] >> 12) & 0x000F));
    n = 3;
  }
  return n;
}
size_t EncodeConverter::Utf16ToUtf8(const U16CHAR_T* pu16, size_t ilen,
    U8CHAR_T* pu8, size_t olen)
{
  size_t offset = 0;
  size_t sz = 0;
  /*
  for (size_t i = 0; i < ilen && offset < static_cast<int>(olen) - 3; i++) {
    sz = utf16ToUtf8(pu16 + i, pu8 + offset);
    offset += sz;
  }
  */
  for (size_t i = 0; i < ilen && static_cast<int>(offset) < static_cast<int>(olen); i++) {
    sz = Utf16ToUtf8(pu16 + i, pu8 + offset);
    if (static_cast<int>(offset + static_cast<int>(sz)) <= static_cast<int>(olen))
        offset += sz;
  }
 // pu8[offset] = '\0';
  return offset;
}
u8string EncodeConverter::Utf16ToUtf8(const u16string& u16str)
{
  size_t buflen = u16str.length()*3 + 1;
  U8CHAR_T* pu8 = new U8CHAR_T[buflen];
  size_t len = Utf16ToUtf8(u16str.data(), u16str.length(),
    pu8, buflen);
  u8string u8str(pu8, len);
  delete [] pu8;
  return u8str;
}
size_t EncodeConverter::Utf8ToUtf16(const U8CHAR_T* pu8, U16CHAR_T* pu16)
{
  size_t n = 0;
  if ((pu8[0] & 0xF0) == 0xE0)
  {
    if ((pu8[1] & 0xC0) == 0x80 &&
        (pu8[2] & 0xC0) == 0x80)
    {
      pu16[0] = (((pu8[0] & 0x0F) << 4) | ((pu8[1] & 0x3C) >> 2));
      pu16[0] <<= 8;
      pu16[0] |= (((pu8[1] & 0x03) << 6) | (pu8[2] & 0x3F));
    }
    else
    {
      pu16[0] = defUniChar;
    }
    n = 3;
  }
  else if ((pu8[0] & 0xE0) == 0xC0)
  {
    if ((pu8[1] & 0xC0) == 0x80)
    {
      pu16[0] = ((pu8[0] & 0x1C) >> 2);
      pu16[0] <<= 8;
      pu16[0] |= (((pu8[0] & 0x03) << 6) | (pu8[1] & 0x3F));
    }
    else
    {
      pu16[0] = defUniChar;
    }
    n = 2;
  }
  else if ((pu8[0] & 0x80) == 0x00)
  {
    pu16[0] = pu8[0];
    n = 1;
  }
  return n;
}
size_t MyUtf8ToUtf16(const U8CHAR_T* pu8, size_t ilen, U16CHAR_T* pu16)
{
  size_t n = 0;
  if ((pu8[0] & 0xF0) == 0xE0 && ilen >= 3)
  {
    if ((pu8[1] & 0xC0) == 0x80 &&
        (pu8[2] & 0xC0) == 0x80)
    {
      pu16[0] = (((pu8[0] & 0x0F) << 4) | ((pu8[1] & 0x3C) >> 2));
      pu16[0] <<= 8;
      pu16[0] |= (((pu8[1] & 0x03) << 6) | (pu8[2] & 0x3F));
      n = 3;
    }
    else
    {
      pu16[0] = 0x0000;
      n = 1;
    }
  }
  else if ((pu8[0] & 0xE0) == 0xC0 && ilen >= 2)
  {
    if ((pu8[1] & 0xC0) == 0x80)
    {
      pu16[0] = ((pu8[0] & 0x1C) >> 2);
      pu16[0] <<= 8;
      pu16[0] |= (((pu8[0] & 0x03) << 6) | (pu8[1] & 0x3F));
      n = 2;
    }
    else
    {
      pu16[0] = 0x0000;
      n = 1;
    }
  }
  else if ((pu8[0] & 0x80) == 0x00)
  {
    pu16[0] = pu8[0];
    n = 1;
  }
  else
  {
      pu16[0] = 0x0000;
      n = 1;
  }
  return n;
}
size_t EncodeConverter::Utf8ToUtf16(const U8CHAR_T* pu8, size_t ilen, U16CHAR_T* pu16)
{
  size_t n = 0;
  if ((pu8[0] & 0xF0) == 0xE0 && ilen >= 3)
  {
    if ((pu8[1] & 0xC0) == 0x80 &&
        (pu8[2] & 0xC0) == 0x80)
    {
      pu16[0] = (((pu8[0] & 0x0F) << 4) | ((pu8[1] & 0x3C) >> 2));
      pu16[0] <<= 8;
      pu16[0] |= (((pu8[1] & 0x03) << 6) | (pu8[2] & 0x3F));
      n = 3;
      if( !is3ByteUtf16(pu16[0]) )
      {
          pu16[0] = 0x0000;
          n = 1;
      }
    }
    else
    {
      pu16[0] = 0x0000;
      n = 1;
    }
  }
  else if ((pu8[0] & 0xE0) == 0xC0 && ilen >= 2)
  {
    if ((pu8[1] & 0xC0) == 0x80)
    {
      pu16[0] = ((pu8[0] & 0x1C) >> 2);
      pu16[0] <<= 8;
      pu16[0] |= (((pu8[0] & 0x03) << 6) | (pu8[1] & 0x3F));
      n = 2;
      if( !is2ByteUtf16(pu16[0]) )
      {
          pu16[0] = 0x0000;
          n = 1;
      }
    }
    else
    {
      pu16[0] = 0x0000;
      n = 1;
    }
  }
  else if ((pu8[0] & 0x80) == 0x00)
  {
    pu16[0] = pu8[0];
    n = 1;
  }
  else
  {
      pu16[0] = 0x0000;
      n = 1;
  }
  return n;
  /*
  size_t n = 0;
  if ((pu8[0] & 0xF0) == 0xE0)
  {
    if (ilen >= 3 && (pu8[1] & 0xC0) == 0x80 &&
        (pu8[2] & 0xC0) == 0x80)
    {
      pu16[0] = (((pu8[0] & 0x0F) << 4) | ((pu8[1] & 0x3C) >> 2));
      pu16[0] <<= 8;
      pu16[0] |= (((pu8[1] & 0x03) << 6) | (pu8[2] & 0x3F));
    }
    else
    {
      pu16[0] = defUniChar;
    }
    n = 3;
  }
  else if ((pu8[0] & 0xE0) == 0xC0)
  {
    if( ilen >= 2 && (pu8[1] & 0xC0) == 0x80)
    {
      pu16[0] = ((pu8[0] & 0x1C) >> 2);
      pu16[0] <<= 8;
      pu16[0] |= (((pu8[0] & 0x03) << 6) | (pu8[1] & 0x3F));
    }
    else
    {
      pu16[0] = defUniChar;
    }
    n = 2;
  }
  else if ((pu8[0] & 0x80) == 0x00)
  {
    pu16[0] = pu8[0];
    n = 1;
  }
  else
  {
      pu16[0] = defUniChar;
      n = 1;
      for (size_t i = 1; i < ilen; i++)
      {
          if ((pu8[i] & 0xF0) == 0xE0 || (pu8[i] & 0xE0) == 0xC0 || (pu8[i] & 0x80) == 0x00)
              break;
          n++;
      }
  }
  return n;
  */
}
size_t EncodeConverter::Utf8ToUtf16(const U8CHAR_T* pu8, size_t ilen,
    U16CHAR_T* pu16, size_t olen)
{
  int offset = 0;
  size_t sz = 0;
  for (size_t i = 0; i < ilen && offset < static_cast<int>(olen); offset ++)
  {
    sz = Utf8ToUtf16(pu8 + i, ilen - i, pu16 + offset);
    i += sz;
    if (sz == 0) {
      // failed
      // assert(sz != 0);
      break;
    }
  }
//  pu16[offset] = '\0';
  return offset;
}
u16string EncodeConverter::Utf8ToUtf16(const u8string& u8str)
{
  U16CHAR_T* p16 = new U16CHAR_T[u8str.length() + 1];
  size_t len = Utf8ToUtf16(u8str.data(), u8str.length(),
      p16, u8str.length() + 1);
  u16string u16str(p16, len);
  delete[] p16;
  return u16str;
}
bool EncodeConverter::IsUTF8(const U8CHAR_T* pu8, size_t ilen)
{
  size_t i;
  size_t n = 0;
  for (i = 0; i < ilen; i += n)
  {
    if ((pu8[i] & 0xF0) == 0xE0 &&
        (pu8[i + 1] & 0xC0) == 0x80 &&
        (pu8[i + 2] & 0xC0) == 0x80)
    {
      n = 3;
    }
    else if ((pu8[i] & 0xE0) == 0xC0 &&
        (pu8[i + 1] & 0xC0) == 0x80)
    {
      n = 2;
    }
    else if ((pu8[i] & 0x80) == 0x00)
    {
      n = 1;
    }
    else
    {
      break;
    }
  }
  return i == ilen;
}
bool EncodeConverter::IsUTF8(const u8string& u8str)
{
  return IsUTF8(u8str.data(), u8str.length());
}
size_t EncodeConverter::GetUTF8Len(const U8CHAR_T* pu8, size_t ilen)
{
  size_t i;
  size_t n = 0;
  size_t rlen = 0;
  for (i = 0; i < ilen; i += n, rlen ++)
  {
    if ((pu8[i] & 0xF0) == 0xE0 &&
        (pu8[i + 1] & 0xC0) == 0x80 &&
        (pu8[i + 2] & 0xC0) == 0x80)
    {
      n = 3;
    }
    else if ((pu8[i] & 0xE0) == 0xC0 &&
        (pu8[i + 1] & 0xC0) == 0x80)
    {
      n = 2;
    }
    else if ((pu8[i] & 0x80) == 0x00)
    {
      n = 1;
    }
    else
    {
      break;
    }
  }
  if (i == ilen)
    return 0;
  else
    return rlen;
}
size_t EncodeConverter::GetUTF8Len(const u8string& u8str)
{
  return GetUTF8Len(u8str.data(), u8str.length());
}
size_t EncodeConverter::Utf16ToUtf8Len(const U16CHAR_T* pu16, size_t ilen)
{
  int offset = 0;
  for (size_t i = 0; i < ilen ; i++) {
      if (pu16[i] <= 0x007F)
      {
        offset += 1;
      }
      else if (pu16[i] >= 0x0080 &&  pu16[i] <= 0x07FF)
      {
        offset += 2;
      }
      else if (pu16[i] >= 0x0800)
      {
        offset += 3;
      }
  }
  return offset;
}
uint16_t EncodeConverter::ToUni(const char* sc, int &len)
{
    uint16_t wide[2];
    len = (int)Utf8ToUtf16((const U8CHAR_T*)sc, wide);
    return wide[0];
}
bool EncodeConverter::IsAllChineseCharactor(const U8CHAR_T* pu8, size_t ilen) {
    if (pu8 == NULL || ilen <= 0) {
        return false;
    }
    U16CHAR_T* p16 = new U16CHAR_T[ilen + 1];
    size_t len = Utf8ToUtf16(pu8, ilen, p16, ilen + 1);
    for (size_t i = 0; i < len; i++) {
        if (p16[i] < 0x4e00 || p16[i] > 0x9fff) {
            delete[] p16;
            return false;
        }
    }
    delete[] p16;
    return true;
}
bool EncodeConverter::HasAlpha(const U8CHAR_T* pu8, size_t ilen) {
  if (pu8 == NULL || ilen <= 0) {
    return false;
  }
  for (size_t i = 0; i < ilen; i++) {
    if (pu8[i]> 0 && isalpha(pu8[i])){
      return true;
    }
  }
  return false;
}
bool EncodeConverter::IsAllAlpha(const U8CHAR_T* pu8, size_t ilen) {
  if (pu8 == NULL || ilen <= 0) {
    return false;
  }
  for (size_t i = 0; i < ilen; i++) {
    if (!(pu8[i]> 0 && isalpha(pu8[i]))){
      return false;
    }
  }
  return true;
}
bool EncodeConverter::IsAllAlphaAndPunct(const U8CHAR_T* pu8, size_t ilen) {
  if (pu8 == NULL || ilen <= 0) {
    return false;
  }
  bool flag1 = HasAlpha(pu8, ilen);
  if (flag1 == false) {
    return false;
  }
  for (size_t i = 0; i < ilen; i++) {
    if (!(pu8[i]> 0 && (isalpha(pu8[i]) || (ispunct(pu8[i]))))){
      return false;
    }
  }
  return true;
}
bool EncodeConverter::IsAllAlphaAndDigit(const U8CHAR_T* pu8, size_t ilen) {
  if (pu8 == NULL || ilen <= 0) {
    return false;
  }
  bool flag1 = HasAlpha(pu8, ilen);
  if (flag1 == false) {
    return false;
  }
  for (size_t i = 0; i < ilen; i++) {
    if (!(pu8[i]> 0 && (isalnum(pu8[i]) || isalpha(pu8[i]) || pu8[i] == '\''))){
      return false;
    }
  }
  return true;
}
bool EncodeConverter::IsAllAlphaAndDigitAndBlank(const U8CHAR_T* pu8, size_t ilen) {
  if (pu8 == NULL || ilen <= 0) {
    return false;
  }
  for (size_t i = 0; i < ilen; i++) {
    if (!(pu8[i]> 0 && (isalnum(pu8[i]) || isalpha(pu8[i]) || isblank(pu8[i]) || pu8[i] == '\''))){
      return false;
    }
  }
  return true;
}
bool EncodeConverter::NeedAddTailBlank(std::string str) {
  U8CHAR_T *pu8 = (U8CHAR_T*)str.data();
  size_t ilen = str.size();
  if (pu8 == NULL || ilen <= 0) {
    return false;
  }
  if (IsAllAlpha(pu8, ilen) || IsAllAlphaAndPunct(pu8, ilen) || IsAllAlphaAndDigit(pu8, ilen)) {
    return true;
  } else {
    return false;
  }
}
std::vector<std::string> EncodeConverter::MergeEnglishWord(std::vector<std::string> &str_vec_input,
                                                           std::vector<int> &merge_mask) {
  std::vector<std::string> output;
  for (int i = 0; i < merge_mask.size(); i++) {
    if (merge_mask[i] == 1 && i > 0) {
      output[output.size() - 1] += str_vec_input[i];
    } else {
      output.push_back(str_vec_input[i]);
    }
  }
  str_vec_input.swap(output);
  return str_vec_input;
}
size_t EncodeConverter::Utf8ToCharset(const std::string &input, std::vector<std::string> &output) {
  std::string ch;
  for (size_t i = 0, len = 0; i != input.length(); i += len) {
    unsigned char byte = (unsigned)input[i];
    if (byte >= 0xFC) // lenght 6
      len = 6;
    else if (byte >= 0xF8)
      len = 5;
    else if (byte >= 0xF0)
      len = 4;
    else if (byte >= 0xE0)
      len = 3;
    else if (byte >= 0xC0)
      len = 2;
    else
      len = 1;
    ch = input.substr(i, len);
    output.push_back(ch);
  }
  return output.size();
}
}
funasr/runtime/onnxruntime/src/encode_converter.h
New file
@@ -0,0 +1,109 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#ifndef __WS__ENCODE_CONVERTER_H__
#define __WS__ENCODE_CONVERTER_H__
#include <string>
#include <stdint.h>
#include <vector>
#ifdef _MSC_VER
#include <windows.h>
#endif // _MSC_VER
namespace funasr {
    typedef unsigned char           U8CHAR_T;
    typedef unsigned short          U16CHAR_T;
    typedef std::basic_string<U8CHAR_T>  u8string;
    typedef std::basic_string<U16CHAR_T> u16string;
    class EncodeConverter {
    public:
        static const U16CHAR_T defUniChar = 0x25a1;  //WHITE SQUARE
    public:
        static void SwapEndian(U16CHAR_T* pbuf, size_t len);
        static size_t Utf16ToUtf8(const U16CHAR_T* pu16, U8CHAR_T* pu8);
        ///< @param pu16 UTF16 string
        ///< @param pu8 UTF8 string
        static size_t Utf16ToUtf8(const U16CHAR_T* pu16, size_t ilen,
                                  U8CHAR_T* pu8, size_t olen);
        static u8string Utf16ToUtf8(const u16string& u16str);
        static size_t Utf8ToUtf16(const U8CHAR_T* pu8, U16CHAR_T* pu16);
        static size_t Utf8ToUtf16(const U8CHAR_T* pu8, size_t ilen, U16CHAR_T* pu16);
        ///< @param pu8 UTF8 string
        ///< @param pu16 UTF16 string
        static size_t Utf8ToUtf16(const U8CHAR_T* pu8, size_t ilen,
                                  U16CHAR_T* pu16, size_t olen);
        static u16string Utf8ToUtf16(const u8string& u8str);
        ///< @param pu8 string
        ///< @return if string is encoded as UTF8 - true, otherwise false
        static bool IsUTF8(const U8CHAR_T* pu8, size_t ilen);
        ///< @param u8str string
        ///< @return if string is encoded as UTF8 - true, otherwise false
        static bool IsUTF8(const u8string& u8str);
        ///< @param UTF8 string
        ///< @return the word number of UTF8
        static size_t GetUTF8Len(const U8CHAR_T* pu8, size_t ilen);
        ///< @param UTF8 string
        ///< @return the word number of UTF8
        static size_t GetUTF8Len(const u8string& u8str);
        ///< @param pu16 UTF16 string
        ///< @param ilen UTF16 length
        ///< @return UTF8 string length
        static size_t Utf16ToUtf8Len(const U16CHAR_T* pu16, size_t ilen);
        static uint16_t ToUni(const char* sc, int &len);
        static bool IsChineseCharacter(U16CHAR_T &u16) {
            return (u16 >= 0x4e00 && u16 <= 0x9fff)  // common
                || (u16 >= 0x3400 && u16 <= 0x4dff); // rare, extension A
        }
        // whether the string is all Chinese
        static bool IsAllChineseCharactor(const U8CHAR_T* pu8, size_t ilen);
        static bool HasAlpha(const U8CHAR_T* pu8, size_t ilen);
        static bool NeedAddTailBlank(std::string str);
        static bool IsAllAlpha(const U8CHAR_T* pu8, size_t ilen);
        static bool IsAllAlphaAndPunct(const U8CHAR_T* pu8, size_t ilen);
        static bool IsAllAlphaAndDigit(const U8CHAR_T* pu8, size_t ilen);
        static bool IsAllAlphaAndDigitAndBlank(const U8CHAR_T* pu8, size_t ilen);
        static std::vector<std::string> MergeEnglishWord(std::vector<std::string> &str_vec_input,
                                                         std::vector<int> &merge_mask);
        static size_t Utf8ToCharset(const std::string &input, std::vector<std::string> &output);
#ifdef _MSC_VER
        // convert to the local ansi page
        static std::string UTF8ToLocaleAnsi(const std::string& strUTF8) {
            int len = MultiByteToWideChar(CP_UTF8, 0, strUTF8.c_str(), -1, NULL, 0);
            unsigned short*wszGBK = new unsigned short[len + 1];
            memset(wszGBK, 0, len * 2 + 2);
            MultiByteToWideChar(CP_UTF8, 0, (LPCCH)strUTF8.c_str(), -1, (LPWSTR)wszGBK, len);
            len = WideCharToMultiByte(CP_ACP, 0, (LPCWCH)wszGBK, -1, NULL, 0, NULL, NULL);
            char *szGBK = new char[len + 1];
            memset(szGBK, 0, len + 1);
            WideCharToMultiByte(CP_ACP, 0, (LPCWCH)wszGBK, -1, szGBK, len, NULL, NULL);
            std::string strTemp(szGBK);
            delete[]szGBK;
            delete[]wszGBK;
            return strTemp;
        }
#endif
    };
}
#endif //__WS_ENCODE_CONVERTER_H__
funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -1,4 +1,5 @@
#include "precomp.h"
#include <vector>
#ifdef __cplusplus 
extern "C" {
@@ -216,7 +217,7 @@
    }
    // APIs for Offline-stream Infer
    _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate, std::string wav_format)
    _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, int sampling_rate, std::string wav_format)
    {
        funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
        if (!offline_stream)
@@ -248,19 +249,22 @@
        int n_total = audio.GetQueueSize();
        float start_time = 0.0;
        while (audio.Fetch(buff, len, flag, start_time) > 0) {
            string msg = (offline_stream->asr_handle)->Forward(buff, len, true);
            string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb);
            std::vector<std::string> msg_vec = funasr::split(msg, '|');
            p_result->msg += msg_vec[0];
            //timestamp
            if(msg_vec.size() > 1){
                std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
                std::string cur_stamp = "";
                std::string cur_stamp = "[";
                for(int i=0; i<msg_stamp.size()-1; i+=2){
                    float begin = std::stof(msg_stamp[i])+start_time;
                    float end = std::stof(msg_stamp[i+1])+start_time;
                    cur_stamp += "["+std::to_string(begin)+","+std::to_string(end)+"],";
                    cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"]";
                    if(i != msg_stamp.size()-2){
                        cur_stamp +=",";
                }
                p_result->stamp += cur_stamp;
                }
                p_result->stamp += cur_stamp + "]";
            }
            n_step++;
            if (fn_callback)
@@ -274,7 +278,7 @@
        return p_result;
    }
    _FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
    _FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, int sampling_rate)
    {
        funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
        if (!offline_stream)
@@ -308,20 +312,24 @@
        int n_total = audio.GetQueueSize();
        float start_time = 0.0;
        while (audio.Fetch(buff, len, flag, start_time) > 0) {
            string msg = (offline_stream->asr_handle)->Forward(buff, len, true);
            string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb);
            std::vector<std::string> msg_vec = funasr::split(msg, '|');
            p_result->msg += msg_vec[0];
            //timestamp
            if(msg_vec.size() > 1){
                std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ',');
                std::string cur_stamp = "";
                std::string cur_stamp = "[";
                for(int i=0; i<msg_stamp.size()-1; i+=2){
                    float begin = std::stof(msg_stamp[i])+start_time;
                    float end = std::stof(msg_stamp[i+1])+start_time;
                    cur_stamp += "["+std::to_string(begin)+","+std::to_string(end)+"],";
                    cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"]";
                    if(i != msg_stamp.size()-2){
                        cur_stamp +=",";
                }
                p_result->stamp += cur_stamp;
            }
                p_result->stamp += cur_stamp + "]";
            }
            n_step++;
            if (fn_callback)
                fn_callback(n_step, n_total);
@@ -334,6 +342,14 @@
        return p_result;
    }
    _FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords) {
        funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
        std::vector<std::vector<float>> emb;
        if (!offline_stream)
            return emb;
        return (offline_stream->asr_handle)->CompileHotwordEmbedding(hotwords);
    }
    // APIs for 2pass-stream Infer
    _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished, int sampling_rate, std::string wav_format, ASR_TYPE mode)
    {
funasr/runtime/onnxruntime/src/offline-stream.cpp
@@ -33,18 +33,36 @@
        string am_model_path;
        string am_cmvn_path;
        string am_config_path;
        string hw_compile_model_path;
        string seg_dict_path;
    
        asr_handle = make_unique<Paraformer>();
        bool enable_hotword = false;
        hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME);
        seg_dict_path = PathAppend(model_path.at(MODEL_DIR), MODEL_SEG_DICT);
        if (access(hw_compile_model_path.c_str(), F_OK) == 0) { // if model_eb.onnx exist, hotword enabled
          enable_hotword = true;
          asr_handle->InitHwCompiler(hw_compile_model_path, thread_num);
          asr_handle->InitSegDict(seg_dict_path);
        }
        if (enable_hotword) {
        am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
        if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
            am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
        }
        } else {
          am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
          if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
            am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
          }
        }
        am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
        am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
        asr_handle = make_unique<Paraformer>();
        asr_handle->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
    }
    // PUNC model
    if(model_path.find(PUNC_DIR) != model_path.end()){
        string punc_model_path;
funasr/runtime/onnxruntime/src/paraformer-online.cpp
@@ -469,7 +469,7 @@
    return result;
}
string ParaformerOnline::Forward(float* din, int len, bool input_finished)
string ParaformerOnline::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb)
{
    std::vector<std::vector<float>> wav_feats;
    std::vector<float> waves(din, din+len);
funasr/runtime/onnxruntime/src/paraformer-online.h
@@ -101,7 +101,7 @@
        void AddOverlapChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
        
        string ForwardChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
        string Forward(float* din, int len, bool input_finished);
        string Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}});
        string Rescoring();
        // 2pass
        std::string online_res;
funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -4,13 +4,17 @@
*/
#include "precomp.h"
#include "paraformer.h"
#include "encode_converter.h"
#include <cstddef>
using namespace std;
namespace funasr {
Paraformer::Paraformer()
:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options_{}{
:use_hotword(false),
 env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options_{},
 hw_env_(ORT_LOGGING_LEVEL_ERROR, "paraformer_hw"),hw_session_options{} {
}
// offline
@@ -45,6 +49,10 @@
    m_strInputNames.push_back(strName.c_str());
    GetInputName(m_session_.get(), strName,1);
    m_strInputNames.push_back(strName);
    if (use_hotword) {
        GetInputName(m_session_.get(), strName, 2);
        m_strInputNames.push_back(strName);
    }
    
    size_t numOutputNodes = m_session_->GetOutputCount();
    for(int index=0; index<numOutputNodes; index++){
@@ -206,10 +214,47 @@
    }
}
void Paraformer::InitHwCompiler(const std::string &hw_model, int thread_num) {
    hw_session_options.SetIntraOpNumThreads(thread_num);
    hw_session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
    // DisableCpuMemArena can improve performance
    hw_session_options.DisableCpuMemArena();
    try {
        hw_m_session = std::make_unique<Ort::Session>(hw_env_, hw_model.c_str(), hw_session_options);
        LOG(INFO) << "Successfully load model from " << hw_model;
    } catch (std::exception const &e) {
        LOG(ERROR) << "Error when load hw compiler onnx model: " << e.what();
        exit(0);
    }
    string strName;
    GetInputName(hw_m_session.get(), strName);
    hw_m_strInputNames.push_back(strName.c_str());
    //GetInputName(hw_m_session.get(), strName,1);
    //hw_m_strInputNames.push_back(strName);
    GetOutputName(hw_m_session.get(), strName);
    hw_m_strOutputNames.push_back(strName);
    for (auto& item : hw_m_strInputNames)
        hw_m_szInputNames.push_back(item.c_str());
    for (auto& item : hw_m_strOutputNames)
        hw_m_szOutputNames.push_back(item.c_str());
    // if init hotword compiler is called, this is a hotword paraformer model
    use_hotword = true;
}
void Paraformer::InitSegDict(const std::string &seg_dict_model) {
    seg_dict = new SegDict(seg_dict_model.c_str());
}
Paraformer::~Paraformer()
{
    if(vocab)
        delete vocab;
    if(seg_dict)
        delete seg_dict;
}
void Paraformer::Reset()
@@ -228,6 +273,10 @@
    int32_t feature_dim = fbank_opts_.mel_opts.num_bins;
    vector<float> features(frames * feature_dim);
    float *p = features.data();
    //std::cout << "samples " << len << std::endl;
    //std::cout << "fbank frames " << frames << std::endl;
    //std::cout << "fbank dim " << feature_dim << std::endl;
    //std::cout << "feature size " << features.size() << std::endl;
    for (int32_t i = 0; i != frames; ++i) {
        const float *f = fbank_.GetFrame(i);
@@ -549,7 +598,7 @@
    }
  }
string Paraformer::Forward(float* din, int len, bool input_finished)
string Paraformer::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb)
{
    int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
@@ -559,6 +608,7 @@
    int32_t feat_dim = lfr_m*in_feat_dim;
    int32_t num_frames = wav_feats.size() / feat_dim;
    //std::cout << "feat in: " << num_frames << " " << feat_dim << std::endl;
#ifdef _WIN_X86
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
@@ -583,10 +633,36 @@
    input_onnx.emplace_back(std::move(onnx_feats));
    input_onnx.emplace_back(std::move(onnx_feats_len));
    std::vector<float> embedding;
    try{
        if (use_hotword) {
            if(hw_emb.size()<=0){
                LOG(ERROR) << "hw_emb is null";
                return "";
            }
            //PrintMat(hw_emb, "input_clas_emb");
            const int64_t hotword_shape[3] = {1, hw_emb.size(), hw_emb[0].size()};
            embedding.reserve(hw_emb.size() * hw_emb[0].size());
            for (auto item : hw_emb) {
                embedding.insert(embedding.end(), item.begin(), item.end());
            }
            //LOG(INFO) << "hotword shape " << hotword_shape[0] << " " << hotword_shape[1] << " " << hotword_shape[2] << " size " << embedding.size();
            Ort::Value onnx_hw_emb = Ort::Value::CreateTensor<float>(
                m_memoryInfo, embedding.data(), embedding.size(), hotword_shape, 3);
            input_onnx.emplace_back(std::move(onnx_hw_emb));
        }
    }catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
        return "";
    }
    string result;
    try {
        auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
        std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
        //LOG(INFO) << "paraformer out shape " << outputShape[0] << " " << outputShape[1] << " " << outputShape[2];
        int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
        float* floatData = outputTensor[0].GetTensorMutableData<float>();
@@ -610,6 +686,17 @@
        }else{
            result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
        }
//         int pos = 0;
//         std::vector<std::vector<float>> logits;
//         for (int j = 0; j < outputShape[1]; j++)
//         {
//             std::vector<float> vec_token;
//             vec_token.insert(vec_token.begin(), floatData + pos, floatData + pos + outputShape[2]);
//             logits.push_back(vec_token);
//             pos += outputShape[2];
//         }
//         //PrintMat(logits, "logits_out");
//         result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]);
    }
    catch (std::exception const &e)
    {
@@ -619,6 +706,96 @@
    return result;
}
std::vector<std::vector<float>> Paraformer::CompileHotwordEmbedding(std::string &hotwords) {
    int embedding_dim = encoder_size;
    std::vector<std::vector<float>> hw_emb;
    if (!use_hotword) {
        std::vector<float> vec(embedding_dim, 0);
        hw_emb.push_back(vec);
        return hw_emb;
    }
    int max_hotword_len = 10;
    std::vector<int32_t> hotword_matrix;
    std::vector<int32_t> lengths;
    int hotword_size = 1;
    if (!hotwords.empty()) {
      std::vector<std::string> hotword_array = split(hotwords, ' ');
      hotword_size = hotword_array.size() + 1;
      hotword_matrix.reserve(hotword_size * max_hotword_len);
      for (auto hotword : hotword_array) {
        std::vector<std::string> chars;
        if (EncodeConverter::IsAllChineseCharactor((const U8CHAR_T*)hotword.c_str(), hotword.size())) {
          KeepChineseCharacterAndSplit(hotword, chars);
        } else {
          // for english
          std::vector<std::string> words = split(hotword, ' ');
          for (auto word : words) {
            std::vector<string> tokens = seg_dict->GetTokensByWord(word);
            chars.insert(chars.end(), tokens.begin(), tokens.end());
          }
        }
        std::vector<int32_t> hw_vector(max_hotword_len, 0);
        int vector_len = std::min(max_hotword_len, (int)chars.size());
        for (int i=0; i<chars.size(); i++) {
          std::cout << chars[i] << " ";
          hw_vector[i] = vocab->GetIdByToken(chars[i]);
        }
        std::cout << std::endl;
        lengths.push_back(vector_len);
        hotword_matrix.insert(hotword_matrix.end(), hw_vector.begin(), hw_vector.end());
      }
    }
    std::vector<int32_t> blank_vec(max_hotword_len, 0);
    blank_vec[0] = 1;
    hotword_matrix.insert(hotword_matrix.end(), blank_vec.begin(), blank_vec.end());
    lengths.push_back(1);
#ifdef _WIN_X86
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
#else
        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
#endif
    const int64_t input_shape_[2] = {hotword_size, max_hotword_len};
    Ort::Value onnx_hotword = Ort::Value::CreateTensor<int32_t>(m_memoryInfo,
        (int32_t*)hotword_matrix.data(),
        hotword_size * max_hotword_len,
        input_shape_,
        2);
    LOG(INFO) << "clas shape " << hotword_size << " " << max_hotword_len << std::endl;
    std::vector<Ort::Value> input_onnx;
    input_onnx.emplace_back(std::move(onnx_hotword));
    std::vector<std::vector<float>> result;
    try {
        auto outputTensor = hw_m_session->Run(Ort::RunOptions{nullptr}, hw_m_szInputNames.data(), input_onnx.data(), input_onnx.size(), hw_m_szOutputNames.data(), hw_m_szOutputNames.size());
        std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
        int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
        float* floatData = outputTensor[0].GetTensorMutableData<float>(); // shape [max_hotword_len, hotword_size, dim]
        // get embedding by real hotword length
        assert(outputShape[0] == max_hotword_len);
        assert(outputShape[1] == hotword_size);
        embedding_dim = outputShape[2];
        for (int j = 0; j < hotword_size; j++)
        {
            int start_pos = hotword_size * (lengths[j] - 1) * embedding_dim + j * embedding_dim;
            std::vector<float> embedding;
            embedding.insert(embedding.begin(), floatData + start_pos, floatData + start_pos + embedding_dim);
            result.push_back(embedding);
        }
    }
    catch (std::exception const &e)
    {
        LOG(ERROR)<<e.what();
    }
    //PrintMat(result, "clas_embedding_output");
    return result;
}
string Paraformer::Rescoring()
{
    LOG(ERROR)<<"Not Imp!!!!!!";
funasr/runtime/onnxruntime/src/paraformer.h
@@ -16,6 +16,7 @@
    */
    private:
        Vocab* vocab = nullptr;
        SegDict* seg_dict = nullptr;
        //const float scale = 22.6274169979695;
        const float scale = 1.0;
@@ -23,6 +24,14 @@
        void LoadCmvn(const char *filename);
        vector<float> ApplyLfr(const vector<float> &in);
        void ApplyCmvn(vector<float> *v);
        std::shared_ptr<Ort::Session> hw_m_session = nullptr;
        Ort::Env hw_env_;
        Ort::SessionOptions hw_session_options;
        vector<string> hw_m_strInputNames, hw_m_strOutputNames;
        vector<const char*> hw_m_szInputNames;
        vector<const char*> hw_m_szOutputNames;
        bool use_hotword;
    public:
        Paraformer();
@@ -32,13 +41,17 @@
        void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
        // 2pass
        void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
        void InitHwCompiler(const std::string &hw_model, int thread_num);
        void InitSegDict(const std::string &seg_dict_model);
        std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords);
        void Reset();
        vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
        string Forward(float* din, int len, bool input_finished=true);
        string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}});
        string GreedySearch( float* in, int n_len, int64_t token_nums, bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0});
        void TimestampOnnx(std::vector<float> &us_alphas, vector<float> us_cif_peak, vector<string>& char_list, std::string &res_str, 
                           vector<vector<float>> &timestamp_list, float begin_time = 0.0, float total_offset = -1.5);
        string PostProcess(std::vector<string> &raw_char, std::vector<std::vector<float>> &timestamp_list);
        string Rescoring();
        knf::FbankOptions fbank_opts_;
funasr/runtime/onnxruntime/src/precomp.h
@@ -38,11 +38,13 @@
#include "ct-transformer-online.h"
#include "e2e-vad.h"
#include "fsmn-vad.h"
#include "encode_converter.h"
#include "vocab.h"
#include "audio.h"
#include "fsmn-vad-online.h"
#include "tensor.h"
#include "util.h"
#include "seg_dict.h"
#include "resample.h"
#include "paraformer.h"
#include "paraformer-online.h"
funasr/runtime/onnxruntime/src/seg_dict.cpp
New file
@@ -0,0 +1,53 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#include "precomp.h"
//#include "util.h"
//#include "seg_dict.h"
#include <glog/logging.h>
#include <fstream>
#include <iostream>
#include <list>
#include <sstream>
#include <string>
using namespace std;
namespace funasr {
SegDict::SegDict(const char *filename)
{
    ifstream in(filename);
    if (!in) {
      LOG(ERROR) << filename << " open failed !!";
      return;
    }
    string textline;
    while (getline(in, textline)) {
      std::vector<string> line_item = split(textline, '\t');
      //std::cout << textline << std::endl;
      if (line_item.size() > 1) {
        std::string word = line_item[0];
        std::string segs = line_item[1];
        std::vector<string> segs_vec = split(segs, ' ');
        seg_dict[word] = segs_vec;
      }
    }
    LOG(INFO) << "load seg dict successfully";
}
std::vector<std::string> SegDict::GetTokensByWord(const std::string &word) {
  if (seg_dict.count(word))
    return seg_dict[word];
  else {
    std::vector<string> vec;
    return vec;
  }
}
SegDict::~SegDict()
{
}
} // namespace funasr
funasr/runtime/onnxruntime/src/seg_dict.h
New file
@@ -0,0 +1,26 @@
/**
 * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
 * MIT License  (https://opensource.org/licenses/MIT)
*/
#ifndef SEG_DICT_H
#define SEG_DICT_H
#include <stdint.h>
#include <string>
#include <vector>
#include <map>
using namespace std;
namespace funasr {
class SegDict {
  private:
    std::map<string, std::vector<string>> seg_dict;
  public:
    SegDict(const char *filename);
    ~SegDict();
    std::vector<std::string> GetTokensByWord(const std::string &word);
};
} // namespace funasr
#endif
funasr/runtime/onnxruntime/src/util.cpp
@@ -189,6 +189,25 @@
    return (extension == target);
}
void KeepChineseCharacterAndSplit(const std::string &input_str,
                                  std::vector<std::string> &chinese_characters) {
  chinese_characters.resize(0);
  std::vector<U16CHAR_T> u16_buf;
  u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1));
  U16CHAR_T* pu16 = u16_buf.data();
  U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data();
  size_t ilen = input_str.size();
  size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1);
  for (size_t i = 0; i < len; i++) {
    if (EncodeConverter::IsChineseCharacter(pu16[i])) {
      U8CHAR_T u8buf[4];
      size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf);
      u8buf[n] = '\0';
      chinese_characters.push_back((const char*)u8buf);
    }
  }
}
std::vector<std::string> split(const std::string &s, char delim) {
  std::vector<std::string> elems;
  std::stringstream ss(s);
@@ -199,4 +218,14 @@
  return elems;
}
template<typename T>
void PrintMat(const std::vector<std::vector<T>> &mat, const std::string &name) {
  std::cout << name << ":" << std::endl;
  for (auto item : mat) {
    for (auto item_ : item) {
      std::cout << item_ << " ";
    }
    std::cout << std::endl;
  }
}
} // namespace funasr
funasr/runtime/onnxruntime/src/util.h
@@ -27,6 +27,12 @@
string PathAppend(const string &p1, const string &p2);
bool is_target_file(const std::string& filename, const std::string target);
void KeepChineseCharacterAndSplit(const std::string &input_str,
                                  std::vector<std::string> &chinese_characters);
std::vector<std::string> split(const std::string &s, char delim);
template<typename T>
void PrintMat(const std::vector<std::vector<T>> &mat, const std::string &name);
} // namespace funasr
#endif
funasr/runtime/onnxruntime/src/vocab.cpp
@@ -29,11 +29,21 @@
        exit(-1);
    }
    YAML::Node myList = config["token_list"];
    int i = 0;
    for (YAML::const_iterator it = myList.begin(); it != myList.end(); ++it) {
        vocab.push_back(it->as<string>());
        token_id[it->as<string>()] = i;
        i ++;
    }
}
int Vocab::GetIdByToken(const std::string &token) {
    if (token_id.count(token)) {
        return token_id[token];
    }
    return 0;
}
void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds)
{
    for (auto it = in.begin(); it != in.end(); it++) {
funasr/runtime/onnxruntime/src/vocab.h
@@ -5,12 +5,14 @@
#include <stdint.h>
#include <string>
#include <vector>
#include <map>
using namespace std;
namespace funasr {
class Vocab {
  private:
    vector<string> vocab;
    std::map<string, int> token_id;
    bool IsEnglish(string ch);
    void LoadVocabFromYaml(const char* filename);
@@ -21,6 +23,7 @@
    bool IsChinese(string ch);
    void Vector2String(vector<int> in, std::vector<std::string> &preds);
    string Vector2StringV2(vector<int> in);
    int GetIdByToken(const std::string &token);
};
} // namespace funasr
funasr/runtime/python/onnxruntime/demo_contextual_paraformer.py
New file
@@ -0,0 +1,11 @@
from funasr_onnx import ContextualParaformer
from pathlib import Path
model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
model = ContextualParaformer(model_dir, batch_size=1)
wav_path = ['{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())]
hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
result = model(wav_path, hotwords)
print(result)
funasr/runtime/python/onnxruntime/funasr_onnx/__init__.py
@@ -1,5 +1,5 @@
# -*- encoding: utf-8 -*-
from .paraformer_bin import Paraformer
from .paraformer_bin import Paraformer, ContextualParaformer
from .vad_bin import Fsmn_vad
from .vad_bin import Fsmn_vad_online
from .punc_bin import CT_Transformer
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_bin.py
@@ -7,6 +7,7 @@
from typing import List, Union, Tuple
import copy
import torch
import librosa
import numpy as np
@@ -16,6 +17,7 @@
from .utils.postprocess_utils import sentence_postprocess
from .utils.frontend import WavFrontend
from .utils.timestamp_utils import time_stamp_lfr6_onnx
from .utils.utils import pad_list, make_pad_mask
logging = get_logger()
@@ -210,3 +212,149 @@
        # texts = sentence_postprocess(token)
        return token
class ContextualParaformer(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        if not Path(model_dir).exists():
            from modelscope.hub.snapshot_download import snapshot_download
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except:
                raise "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
        if quantize:
            model_bb_file = os.path.join(model_dir, 'model_quant.onnx')
            model_eb_file = os.path.join(model_dir, 'model_eb_quant.onnx')
        else:
            model_bb_file = os.path.join(model_dir, 'model.onnx')
            model_eb_file = os.path.join(model_dir, 'model_eb.onnx')
        token_list_file = os.path.join(model_dir, 'tokens.txt')
        self.vocab = {}
        with open(Path(token_list_file), 'r') as fin:
            for i, line in enumerate(fin.readlines()):
                self.vocab[line.strip()] = i
        #if quantize:
        #    model_file = os.path.join(model_dir, 'model_quant.onnx')
        #if not os.path.exists(model_file):
        #    logging.error(".onnx model not exist, please export first.")
        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)
        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer_bb = OrtInferSession(model_bb_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.ort_infer_eb = OrtInferSession(model_eb_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.plot_timestamp_to = plot_timestamp_to
        if "predictor_bias" in config['model_conf'].keys():
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0
    def __call__(self,
                 wav_content: Union[str, np.ndarray, List[str]],
                 hotwords: str,
                 **kwargs) -> List:
        # make hotword list
        hotwords, hotwords_length = self.proc_hotword(hotwords)
        # import pdb; pdb.set_trace()
        [bias_embed] = self.eb_infer(hotwords, hotwords_length)
        # index from bias_embed
        bias_embed = bias_embed.transpose(1, 0, 2)
        _ind = np.arange(0, len(hotwords)).tolist()
        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            bias_embed = np.expand_dims(bias_embed, axis=0)
            bias_embed = np.repeat(bias_embed, feats.shape[0], axis=0)
            try:
                outputs = self.bb_infer(feats, feats_len, bias_embed)
                am_scores, valid_token_lens = outputs[0], outputs[1]
            except ONNXRuntimeError:
                #logging.warning(traceback.format_exc())
                logging.warning("input wav is silence or noise")
                preds = ['']
            else:
                preds = self.decode(am_scores, valid_token_lens)
                for pred in preds:
                    pred = sentence_postprocess(pred)
                    asr_res.append({'preds': pred})
        return asr_res
    def proc_hotword(self, hotwords):
        hotwords = hotwords.split(" ")
        hotwords_length = [len(i) - 1 for i in hotwords]
        hotwords_length.append(0)
        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)
        # hotwords.append('<s>')
        def word_map(word):
            return torch.tensor([self.vocab[i] for i in word])
        hotword_int = [word_map(i) for i in hotwords]
        # import pdb; pdb.set_trace()
        hotword_int.append(torch.tensor([1]))
        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
        return hotwords, hotwords_length
    def bb_infer(self, feats: np.ndarray,
              feats_len: np.ndarray, bias_embed) -> Tuple[np.ndarray, np.ndarray]:
        outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
        return outputs
    def eb_infer(self, hotwords, hotwords_length):
        outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
        return outputs
    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]
    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)
        # pad with mask tokens to ensure compatibility with sos/eos tokens
        # asr_model.sos:1  asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)
        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()
        # remove blank symbol id, which is assumed to be 0
        token_int = list(filter(lambda x: x not in (0, 2), token_int))
        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num-self.pred_bias]
        # texts = sentence_postprocess(token)
        return token
funasr/runtime/python/onnxruntime/funasr_onnx/utils/utils.py
@@ -7,6 +7,7 @@
from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union
import re
import torch
import numpy as np
import yaml
try:
@@ -22,6 +23,52 @@
logger_initialized = {}
def pad_list(xs, pad_value, max_len=None):
    n_batch = len(xs)
    if max_len is None:
        max_len = max(x.size(0) for x in xs)
    pad = xs[0].new(n_batch, max_len, *xs[0].size()[1:]).fill_(pad_value)
    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]
    return pad
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))
    if not isinstance(lengths, list):
        lengths = lengths.tolist()
    bs = int(len(lengths))
    if maxlen is None:
        if xs is None:
            maxlen = int(max(lengths))
        else:
            maxlen = xs.size(length_dim)
    else:
        assert xs is None
        assert maxlen >= int(max(lengths))
    seq_range = torch.arange(0, maxlen, dtype=torch.int64)
    seq_range_expand = seq_range.unsqueeze(0).expand(bs, maxlen)
    seq_length_expand = seq_range_expand.new(lengths).unsqueeze(-1)
    mask = seq_range_expand >= seq_length_expand
    if xs is not None:
        assert xs.size(0) == bs, (xs.size(0), bs)
        if length_dim < 0:
            length_dim = xs.dim() + length_dim
        # ind = (:, None, ..., None, :, , None, ..., None)
        ind = tuple(
            slice(None) if i in (0, length_dim) else None for i in range(xs.dim())
        )
        mask = mask[ind].expand_as(xs).to(xs.device)
    return mask
class TokenIDConverter():
    def __init__(self, token_list: Union[List, str],
                 ):
funasr/runtime/websocket/CMakeLists.txt
@@ -57,6 +57,7 @@
# install openssl first apt-get install libssl-dev
find_package(OpenSSL REQUIRED)
#message("CXX_FLAGS "${CMAKE_CXX_FLAGS})
add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp")
add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp")
add_executable(funasr-wss-client "funasr-wss-client.cpp")
funasr/runtime/websocket/funasr-wss-client.cpp
@@ -32,9 +32,9 @@
 */
void WaitABit() {
    #ifdef WIN32
        Sleep(1000);
        Sleep(500);
    #else
        sleep(1);
        usleep(500);
    #endif
}
std::atomic<int> wav_index(0);
@@ -108,8 +108,10 @@
            case websocketpp::frame::opcode::text:
                total_num=total_num+1;
                LOG(INFO)<< "Thread: " << this_thread::get_id() <<",on_message = " << payload;
                LOG(INFO) << "total_num=" << total_num << " wav_index=" <<wav_index;
                if((total_num+1)==wav_index)
                {
                    LOG(INFO) << "close client";
                    websocketpp::lib::error_code ec;
                    m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
                    if (ec){
@@ -120,7 +122,7 @@
    }
    // This method will block until the connection is complete  
    void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids) {
    void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, std::string hotwords) {
        // Create a new connection to the given URI
        websocketpp::lib::error_code ec;
        typename websocketpp::client<T>::connection_ptr con =
@@ -141,12 +143,16 @@
        // Create a thread to run the ASIO io_service event loop
        websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
                                            &m_client);
        bool send_hotword = true;
        while(true){
            int i = wav_index.fetch_add(1);
            if (i >= wav_list.size()) {
                break;
            }
            send_wav_data(wav_list[i], wav_ids[i]);
            send_wav_data(wav_list[i], wav_ids[i], hotwords, send_hotword);
            if(send_hotword){
                send_hotword = false;
            }
        }
        WaitABit(); 
@@ -181,7 +187,7 @@
        m_done = true;
    }
    // send wav to server
    void send_wav_data(string wav_path, string wav_id) {
    void send_wav_data(string wav_path, string wav_id, string hotwords, bool send_hotword) {
        uint64_t count = 0;
        std::stringstream val;
@@ -237,6 +243,10 @@
        jsonbegin["wav_name"] = wav_id;
        jsonbegin["wav_format"] = wav_format;
        jsonbegin["is_speaking"] = true;
        if(send_hotword){
            LOG(INFO) << "hotwords: "<< hotwords;
            jsonbegin["hotwords"] = hotwords;
        }
        m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
                      ec);
@@ -311,7 +321,7 @@
        jsonresult["is_speaking"] = false;
        m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
                      ec);
        // WaitABit();
        std::this_thread::sleep_for(std::chrono::milliseconds(100));
    }
    websocketpp::client<T> m_client;
@@ -340,12 +350,14 @@
    TCLAP::ValueArg<int> is_ssl_(
        "", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection", 
        false, 1, "int");
    TCLAP::ValueArg<std::string> hotword_("", HOTWORD, "*.txt(one hotword perline) or hotwords seperate by space (could be: 阿里巴巴 达摩院)", false, "", "string");
    cmd.add(server_ip_);
    cmd.add(port_);
    cmd.add(wav_path_);
    cmd.add(thread_num_);
    cmd.add(is_ssl_);
    cmd.add(hotword_);
    cmd.parse(argc, argv);
    std::string server_ip = server_ip_.getValue();
@@ -361,6 +373,27 @@
    } else {
        uri = "ws://" + server_ip + ":" + port;
    }
    // read hotwords
    std::string hotword = hotword_.getValue();
    std::string hotwords_;
    if(IsTargetFile(hotword, "txt")){
        ifstream in(hotword);
        if (!in.is_open()) {
            LOG(ERROR) << "Failed to open file: " <<  hotword;
            return 0;
        }
        string line;
        while(getline(in, line))
        {
            hotwords_ +=line+HOTWORD_SEP;
        }
        in.close();
    }else{
        hotwords_ = hotword;
    }
    // read wav_path
    std::vector<string> wav_list;
@@ -388,17 +421,17 @@
    }
    
    for (size_t i = 0; i < threads_num; i++) {
        client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl]() {
        client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl, hotwords_]() {
          if (is_ssl == 1) {
            WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);
            c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));
            c.run(uri, wav_list, wav_ids);
            c.run(uri, wav_list, wav_ids, hotwords_);
          } else {
            WebsocketClient<websocketpp::config::asio_client> c(is_ssl);
            c.run(uri, wav_list, wav_ids);
            c.run(uri, wav_list, wav_ids, hotwords_);
          }
        });
    }
funasr/runtime/websocket/funasr-wss-server-2pass.cpp
@@ -368,6 +368,7 @@
    server server_;  // server for websocket
    wss_server wss_server_;
    if (is_ssl) {
      LOG(INFO)<< "SSL is opened!";
      wss_server_.init_asio(&io_server);  // init asio
      wss_server_.set_reuse_addr(
          true);  // reuse address as we create multiple threads
@@ -380,6 +381,7 @@
      websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
    } else {
      LOG(INFO)<< "SSL is closed!";
      server_.init_asio(&io_server);  // init asio
      server_.set_reuse_addr(
          true);  // reuse address as we create multiple threads
funasr/runtime/websocket/funasr-wss-server.cpp
@@ -180,6 +180,18 @@
                python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir ./ " + " --model_revision " + model_path["model-revision"];
                down_asr_path  = s_asr_path;
            }else{
                size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404");
                if (found != std::string::npos) {
                    model_path["model-revision"]="v1.2.4";
                }else{
                    found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404");
                    if (found != std::string::npos) {
                        model_path["model-revision"]="v1.0.3";
                        model_path[QUANTIZE]=false;
                        s_asr_quant = false;
                    }
                }
                // modelscope
                LOG(INFO) << "Download model: " <<  s_asr_path << " from modelscope: ";
                python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["model-revision"];
@@ -278,6 +290,7 @@
    server server_;  // server for websocket
    wss_server wss_server_;
    if (is_ssl) {
      LOG(INFO)<< "SSL is opened!";
      wss_server_.init_asio(&io_server);  // init asio
      wss_server_.set_reuse_addr(
          true);  // reuse address as we create multiple threads
@@ -290,6 +303,7 @@
      websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
    } else {
      LOG(INFO)<< "SSL is closed!";
      server_.init_asio(&io_server);  // init asio
      server_.set_reuse_addr(
          true);  // reuse address as we create multiple threads
funasr/runtime/websocket/readme.md
@@ -169,7 +169,6 @@
```
API-reference:
```text
--server-ip: The IP address of the machine where FunASR runtime-SDK service is deployed. The default value is the IP address of the local machine (127.0.0.1). If the client and service are not on the same server, it needs to be changed to the IP address of the deployment machine.
--port: The port number of the deployed service is 10095.
funasr/runtime/websocket/websocket-server-2pass.cpp
funasr/runtime/websocket/websocket-server.cpp
@@ -56,25 +56,37 @@
// feed buffer to asr engine for decoder
void WebSocketServer::do_decoder(const std::vector<char>& buffer,
                                 websocketpp::connection_hdl& hdl,
                                 const nlohmann::json& msg) {
                                 websocketpp::lib::mutex& thread_lock,
                                 std::vector<std::vector<float>> &hotwords_embedding,
                                 std::string wav_name,
                                 std::string wav_format) {
  scoped_lock guard(thread_lock);
  try {
    int num_samples = buffer.size();  // the size of the buf
    if (!buffer.empty()) {
      // feed data to asr engine
    if (!buffer.empty() && hotwords_embedding.size() >0 ) {
      std::string asr_result;
      std::string stamp_res;
      try{
      FUNASR_RESULT Result = FunOfflineInferBuffer(
          asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, 16000, msg["wav_format"]);
            asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, hotwords_embedding, 16000, wav_format);
      std::string asr_result =
          ((FUNASR_RECOG_RESULT*)Result)->msg;  // get decode result
        asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;  // get decode result
        stamp_res = ((FUNASR_RECOG_RESULT*)Result)->stamp;
      FunASRFreeResult(Result);
      }catch (std::exception const& e) {
        LOG(ERROR) << e.what();
        return;
      }
      websocketpp::lib::error_code ec;
      nlohmann::json jsonresult;        // result json
      jsonresult["text"] = asr_result;  // put result in 'text'
      jsonresult["mode"] = "offline";
      jsonresult["wav_name"] = msg["wav_name"];
      if(stamp_res != ""){
        jsonresult["timestamp"] = stamp_res;
      }
      jsonresult["wav_name"] = wav_name;
      // send the json to client
      if (is_ssl) {
@@ -86,11 +98,6 @@
      }
      LOG(INFO) << "buffer.size=" << buffer.size() << ",result json=" << jsonresult.dump();
      if (!isonline) {
        //  close the client if it is not online asr
        // server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
        // fout.close();
      }
    }
  } catch (std::exception const& e) {
@@ -100,12 +107,11 @@
void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
  scoped_lock guard(m_lock);     // for threads safty
  check_and_clean_connection();  // remove closed connection
  std::shared_ptr<FUNASR_MESSAGE> data_msg =
      std::make_shared<FUNASR_MESSAGE>();  // put a new data vector for new
                                           // connection
  data_msg->samples = std::make_shared<std::vector<char>>();
  data_msg->thread_lock = std::make_shared<websocketpp::lib::mutex>();
  data_msg->msg = nlohmann::json::parse("{}");
  data_msg->msg["wav_format"] = "pcm";
  data_map.emplace(hdl, data_msg);
@@ -114,18 +120,50 @@
void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
  scoped_lock guard(m_lock);
  data_map.erase(hdl);  // remove data vector when  connection is closed
  std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
  auto it_data = data_map.find(hdl);
  if (it_data != data_map.end()) {
    data_msg = it_data->second;
  } else {
    return;
  }
  unique_lock guard_decoder(*(data_msg->thread_lock));
  data_msg->msg["is_eof"]=true;
  guard_decoder.unlock();
  // data_map.erase(hdl);  // remove data vector when  connection is closed
  LOG(INFO) << "on_close, active connections: " << data_map.size();
}
// remove closed connection
void remove_hdl(
    websocketpp::connection_hdl hdl,
    std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
             std::owner_less<websocketpp::connection_hdl>>& data_map) {
  std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
  auto it_data = data_map.find(hdl);
  if (it_data != data_map.end()) {
    data_msg = it_data->second;
  } else {
    return;
  }
  unique_lock guard_decoder(*(data_msg->thread_lock));
  if (data_msg->msg["is_eof"]==true) {
      data_map.erase(hdl);
    LOG(INFO) << "remove one connection";
  }
  guard_decoder.unlock();
}
void WebSocketServer::check_and_clean_connection() {
  while(true){
    std::this_thread::sleep_for(std::chrono::milliseconds(5000));
  std::vector<websocketpp::connection_hdl> to_remove;  // remove list
  auto iter = data_map.begin();
  while (iter != data_map.end()) {  // loop to find closed connection
    websocketpp::connection_hdl hdl = iter->first;
      try{
    if (is_ssl) {
      wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
      if (con->get_state() != 1) {  // session::state::open ==1
@@ -137,14 +175,33 @@
        to_remove.push_back(hdl);
      }
    }
      }
      catch (std::exception const &e)
      {
        // if connection is close, we set is_eof = true
        std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
        auto it_data = data_map.find(hdl);
        if (it_data != data_map.end()) {
          data_msg = it_data->second;
        } else {
            continue;
        }
        unique_lock guard_decoder(*(data_msg->thread_lock));
        data_msg->msg["is_eof"]=true;
        guard_decoder.unlock();
        to_remove.push_back(hdl);
        LOG(INFO)<<"connection is closed: "<<e.what();
      }
    iter++;
  }
  for (auto hdl : to_remove) {
    data_map.erase(hdl);
    LOG(INFO)<< "remove one connection ";
      remove_hdl(hdl, data_map);
      //LOG(INFO) << "remove one connection ";
  }
}
}
void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
                                 message_ptr msg) {
  unique_lock lock(m_lock);
@@ -157,6 +214,7 @@
    msg_data = it_data->second;
  }
  std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
  std::shared_ptr<websocketpp::lib::mutex> thread_lock_p = msg_data->thread_lock;
  lock.unlock();
  if (sample_data_p == nullptr) {
@@ -165,7 +223,7 @@
  }
  const std::string& payload = msg->get_payload();  // get msg type
  unique_lock guard_decoder(*(thread_lock_p)); // mutex for one connection
  switch (msg->get_opcode()) {
    case websocketpp::frame::opcode::text: {
      nlohmann::json jsonresult = nlohmann::json::parse(payload);
@@ -175,24 +233,42 @@
      if (jsonresult["wav_format"] != nullptr) {
        msg_data->msg["wav_format"] = jsonresult["wav_format"];
      }
      if(msg_data->hotwords_embedding == NULL){
        if (jsonresult["hotwords"] != nullptr) {
          msg_data->msg["hotwords"] = jsonresult["hotwords"];
          if (!msg_data->msg["hotwords"].empty()) {
            std::string hw = msg_data->msg["hotwords"];
            LOG(INFO)<<"hotwords: " << hw;
            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
            msg_data->hotwords_embedding =
                std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
          }
        }else{
            std::string hw = "";
            LOG(INFO)<<"hotwords: " << hw;
            std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
            msg_data->hotwords_embedding =
                std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
        }
      }
      if (jsonresult["is_speaking"] == false ||
          jsonresult["is_finished"] == true) {
        LOG(INFO) << "client done";
        if (isonline) {
          // do_close(ws);
        } else {
          // add padding to the end of the wav data
          // std::vector<short> padding(static_cast<short>(0.3 * 16000));
          // sample_data_p->insert(sample_data_p->end(), padding.data(),
          //                       padding.data() + padding.size());
          // for offline, send all receive data to decoder engine
        std::vector<std::vector<float>> hotwords_embedding_(*(msg_data->hotwords_embedding));
          asio::post(io_decoder_,
                     std::bind(&WebSocketServer::do_decoder, this,
                               std::move(*(sample_data_p.get())),
                               std::move(hdl), std::move(msg_data->msg)));
        }
                              std::move(hdl),
                              std::ref(*thread_lock_p),
                              std::move(hotwords_embedding_),
                              msg_data->msg["wav_name"],
                              msg_data->msg["wav_format"]));
      }
      break;
    }
@@ -200,19 +276,15 @@
      // recived binary data
      const auto* pcm_data = static_cast<const char*>(payload.data());
      int32_t num_samples = payload.size();
      //LOG(INFO) << "recv binary num_samples " << num_samples;
      if (isonline) {
        // if online TODO(zhaoming) still not done
        std::vector<char> s(pcm_data, pcm_data + num_samples);
        asio::post(io_decoder_,
                   std::bind(&WebSocketServer::do_decoder, this, std::move(s),
                             std::move(hdl), std::move(msg_data->msg)));
        // TODO
      } else {
        // for offline, we add receive data to end of the sample data vector
        sample_data_p->insert(sample_data_p->end(), pcm_data,
                              pcm_data + num_samples);
      }
      break;
    }
    default:
@@ -229,6 +301,11 @@
    asr_hanlde = FunOfflineInit(model_path, thread_num);
    LOG(INFO) << "model successfully inited";
    LOG(INFO) << "initAsr run check_and_clean_connection";
    std::thread clean_thread(&WebSocketServer::check_and_clean_connection,this);
    clean_thread.detach();
    LOG(INFO) << "initAsr run check_and_clean_connection finished";
  } catch (const std::exception& e) {
    LOG(INFO) << e.what();
  }
funasr/runtime/websocket/websocket-server.h
@@ -46,13 +46,17 @@
    context_ptr;
typedef struct {
  std::string msg;
  float snippet_time;
    std::string msg="";
    std::string stamp="";
    std::string tpass_msg="";
    float snippet_time=0;
} FUNASR_RECOG_RESULT;
typedef struct {
  nlohmann::json msg;
  std::shared_ptr<std::vector<char>> samples;
  std::shared_ptr<std::vector<std::vector<float>>> hotwords_embedding=NULL;
  std::shared_ptr<websocketpp::lib::mutex> thread_lock; // lock for each connection
} FUNASR_MESSAGE;
// See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about
@@ -106,7 +110,10 @@
    }
  }
  void do_decoder(const std::vector<char>& buffer,
                  websocketpp::connection_hdl& hdl, const nlohmann::json& msg);
                  websocketpp::connection_hdl& hdl,
                  websocketpp::lib::mutex& thread_lock,
                  std::vector<std::vector<float>> &hotwords_embedding,
                  std::string wav_name, std::string wav_format);
  void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
  void on_message(websocketpp::connection_hdl hdl, message_ptr msg);