Dev hw (#878)
* merge from hw (#872)
* hotwords
* Contextual Paraformer onnx export
* update
* update
* quant inference
* add clas hotword support
* update websocket-server
* update websocket-server
* add catch for hotword
* update websocket-server
* update paraformer
* update websocket-server
* add wait for funasr-wss-client
* fix core by adding clean_thread
* fix wav_name
* update funasr-wss-client
* update websocket-server
* Update SDK_tutorial_online_zh.md
---------
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
* Update websocket_protocol_zh.md
* Update websocket_protocol.md
* Update SDK_tutorial_zh.md
* Update SDK_tutorial.md
* Update SDK_advanced_guide_online_zh.md
* Update SDK_advanced_guide_online.md
* Update SDK_advanced_guide_offline_zh.md
* Update SDK_advanced_guide_offline_zh.md
* Update SDK_advanced_guide_offline.md
* Update SDK_advanced_guide_offline.md
* Update docker_offline_cpu_zh_lists
* update docs
* update
---------
Co-authored-by: shixian.shi <shixian.shi@alibaba-inc.com>
| | |
| | | import json |
| | | from typing import Union, Dict |
| | | from pathlib import Path |
| | | |
| | | import os |
| | | import logging |
| | | import torch |
| | | |
| | | from funasr.export.models import get_model |
| | | import numpy as np |
| | | import random |
| | | import logging |
| | | import numpy as np |
| | | from pathlib import Path |
| | | from typing import Union, Dict, List |
| | | from funasr.export.models import get_model |
| | | from funasr.utils.types import str2bool, str2triple_str |
| | | # torch_version = float(".".join(torch.__version__.split(".")[:2])) |
| | | # assert torch_version > 1.9 |
| | |
| | | |
| | | # export encoder1 |
| | | self.export_config["model_name"] = "model" |
| | | models = get_model( |
| | | model = get_model( |
| | | model, |
| | | self.export_config, |
| | | ) |
| | | if not isinstance(models, tuple): |
| | | models = (models,) |
| | | |
| | | for i, model in enumerate(models): |
| | | if isinstance(model, List): |
| | | for m in model: |
| | | m.eval() |
| | | if self.onnx: |
| | | self._export_onnx(m, verbose, export_dir) |
| | | else: |
| | | self._export_torchscripts(m, verbose, export_dir) |
| | | print("output dir: {}".format(export_dir)) |
| | | else: |
| | | model.eval() |
| | | # self._export_onnx(model, verbose, export_dir) |
| | | if self.onnx: |
| | | self._export_onnx(model, verbose, export_dir) |
| | | else: |
| | | self._export_torchscripts(model, verbose, export_dir) |
| | | |
| | | print("output dir: {}".format(export_dir)) |
| | | |
| | | |
| | |
| | | # model_script = torch.jit.script(model) |
| | | model_script = model #torch.jit.trace(model) |
| | | model_path = os.path.join(path, f'{model.model_name}.onnx') |
| | | if not os.path.exists(model_path): |
| | | # if not os.path.exists(model_path): |
| | | torch.onnx.export( |
| | | model_script, |
| | | dummy_input, |
| | |
| | | from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export |
| | | from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_encoder_predictor as ParaformerOnline_encoder_predictor_export |
| | | from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_decoder as ParaformerOnline_decoder_export |
| | | from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_backbone as ContextualParaformer_backbone_export |
| | | from funasr.export.models.e2e_asr_contextual_paraformer import ContextualParaformer_embedder as ContextualParaformer_embedder_export |
| | | from funasr.models.e2e_asr_contextual_paraformer import NeatContextualParaformer |
| | | |
| | | |
| | | def get_model(model, export_config=None): |
| | | if isinstance(model, BiCifParaformer): |
| | | if isinstance(model, NeatContextualParaformer): |
| | | backbone = ContextualParaformer_backbone_export(model, **export_config) |
| | | embedder = ContextualParaformer_embedder_export(model, **export_config) |
| | | return [embedder, backbone] |
| | | elif isinstance(model, BiCifParaformer): |
| | | return BiCifParaformer_export(model, **export_config) |
| | | elif isinstance(model, ParaformerOnline): |
| | | return (ParaformerOnline_encoder_predictor_export(model, model_name="model"), |
| New file |
| | |
| | | import os |
| | | import torch |
| | | import torch.nn as nn |
| | | |
| | | from funasr.export.utils.torch_function import MakePadMask |
| | | from funasr.export.utils.torch_function import sequence_mask |
| | | from funasr.modules.attention import MultiHeadedAttentionSANMDecoder |
| | | from funasr.export.models.modules.multihead_att import MultiHeadedAttentionSANMDecoder as MultiHeadedAttentionSANMDecoder_export |
| | | from funasr.modules.attention import MultiHeadedAttentionCrossAtt |
| | | from funasr.export.models.modules.multihead_att import MultiHeadedAttentionCrossAtt as MultiHeadedAttentionCrossAtt_export |
| | | from funasr.modules.positionwise_feed_forward import PositionwiseFeedForwardDecoderSANM |
| | | from funasr.export.models.modules.feedforward import PositionwiseFeedForwardDecoderSANM as PositionwiseFeedForwardDecoderSANM_export |
| | | from funasr.export.models.modules.decoder_layer import DecoderLayerSANM as DecoderLayerSANM_export |
| | | |
| | | |
class ContextualSANMDecoder(nn.Module):
    """Export wrapper around a contextual (hotword-biased) SANM decoder.

    The constructor rewrites the wrapped ``model`` in place: every layer of
    ``decoders``, ``decoders2`` (when present) and ``decoders3`` is replaced
    by ``DecoderLayerSANM_export``, and the feed-forward / self-attention /
    cross-attention sub-modules are swapped for their ``*_export``
    counterparts when they are instances of the corresponding training class.
    The ``bias_decoder`` and ``last_decoder`` sub-modules get the same
    treatment but are not re-wrapped as layers.

    NOTE(review): ``get_dummy_inputs``, ``get_input_names``,
    ``get_output_names`` and ``get_dynamic_axes`` describe a
    ``tgt / memory / pre_acoustic_embeds / cache_*`` signature that does not
    match the parameters of ``forward`` below -- confirm whether they are
    still used by any exporter.
    """

    def __init__(self, model,
                 max_seq_len=512,
                 model_name='decoder',
                 onnx: bool = True,):
        """
        Args:
            model: decoder to wrap; it is modified in place.
            max_seq_len: forwarded to the padding-mask helper.
            model_name: base name used for the exported file.
            onnx: selects ``MakePadMask`` (True) or ``sequence_mask`` (False).
        """
        super().__init__()
        # self.embed = model.embed #Embedding(model.embed, max_seq_len)
        self.model = model
        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

        for i, d in enumerate(self.model.decoders):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
            if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
                d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
            self.model.decoders[i] = DecoderLayerSANM_export(d)

        if self.model.decoders2 is not None:
            for i, d in enumerate(self.model.decoders2):
                if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                    d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
                if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
                    d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
                self.model.decoders2[i] = DecoderLayerSANM_export(d)

        for i, d in enumerate(self.model.decoders3):
            if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
                d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
            self.model.decoders3[i] = DecoderLayerSANM_export(d)

        self.output_layer = model.output_layer
        self.after_norm = model.after_norm
        self.model_name = model_name

        # bias decoder
        if isinstance(self.model.bias_decoder.src_attn, MultiHeadedAttentionCrossAtt):
            self.model.bias_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.bias_decoder.src_attn)
        self.bias_decoder = self.model.bias_decoder
        # last decoder
        if isinstance(self.model.last_decoder.src_attn, MultiHeadedAttentionCrossAtt):
            self.model.last_decoder.src_attn = MultiHeadedAttentionCrossAtt_export(self.model.last_decoder.src_attn)
        if isinstance(self.model.last_decoder.self_attn, MultiHeadedAttentionSANMDecoder):
            self.model.last_decoder.self_attn = MultiHeadedAttentionSANMDecoder_export(self.model.last_decoder.self_attn)
        if isinstance(self.model.last_decoder.feed_forward, PositionwiseFeedForwardDecoderSANM):
            self.model.last_decoder.feed_forward = PositionwiseFeedForwardDecoderSANM_export(self.model.last_decoder.feed_forward)
        self.last_decoder = self.model.last_decoder
        self.bias_output = self.model.bias_output
        self.dropout = self.model.dropout

    def _cache_num(self):
        """Return ``len(decoders)`` plus ``len(decoders2)``.

        ``decoders2`` is treated as optional everywhere else in this class,
        so a ``None`` value contributes zero layers instead of raising.
        """
        second_stack = self.model.decoders2
        extra = len(second_stack) if second_stack is not None else 0
        return len(self.model.decoders) + extra

    def prepare_mask(self, mask):
        """Derive the two mask forms used by the decoder layers.

        Args:
            mask: a 2-D or 3-D tensor where 1 marks a valid position.

        Returns:
            ``(mask_3d_btd, mask_4d_bhlt)``: the input with a trailing
            singleton axis, and an additive mask equal to
            ``(1 - mask) * -10000.0`` with singleton axes inserted.

        Raises:
            ValueError: if ``mask`` is neither 2-D nor 3-D.
        """
        mask_3d_btd = mask[:, :, None]
        if len(mask.shape) == 2:
            mask_4d_bhlt = 1 - mask[:, None, None, :]
        elif len(mask.shape) == 3:
            mask_4d_bhlt = 1 - mask[:, None, :]
        else:
            raise ValueError(
                "prepare_mask expects a 2-D or 3-D mask, got {} dims".format(len(mask.shape))
            )
        mask_4d_bhlt = mask_4d_bhlt * -10000.0

        return mask_3d_btd, mask_4d_bhlt

    def forward(
            self,
            hs_pad: torch.Tensor,
            hlens: torch.Tensor,
            ys_in_pad: torch.Tensor,
            ys_in_lens: torch.Tensor,
            bias_embed: torch.Tensor,
    ):
        """Run the decoder stacks with hotword biasing.

        Args:
            hs_pad: memory tensor attended to by the decoder layers.
            hlens: lengths used to build the memory mask.
            ys_in_pad: decoder input tensor.
            ys_in_lens: lengths used to build the target mask.
            bias_embed: hotword embeddings; its size along axis 1 is used
                as the contextual length for every batch item.

        Returns:
            ``(x, ys_in_lens)`` where ``x`` is the result of
            ``output_layer(after_norm(...))``.
        """
        tgt = ys_in_pad
        tgt_mask = self.make_pad_mask(ys_in_lens)
        tgt_mask, _ = self.prepare_mask(tgt_mask)
        # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]

        memory = hs_pad
        memory_mask = self.make_pad_mask(hlens)
        _, memory_mask = self.prepare_mask(memory_mask)
        # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]

        x = tgt
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders(
            x, tgt_mask, memory, memory_mask
        )

        _, _, x_self_attn, x_src_attn = self.last_decoder(
            x, tgt_mask, memory, memory_mask
        )

        # contextual paraformer related
        contextual_length = torch.Tensor([bias_embed.shape[1]]).int().repeat(hs_pad.shape[0])
        # torch.Tensor(...) ignores the inputs' device; move the lengths next
        # to hs_pad so they live where the other length tensors do.
        contextual_length = contextual_length.to(hs_pad.device)
        # contextual_mask = myutils.sequence_mask(contextual_length, device=memory.device)[:, None, :]
        contextual_mask = self.make_pad_mask(contextual_length)
        contextual_mask, _ = self.prepare_mask(contextual_mask)
        contextual_mask = contextual_mask.transpose(2, 1).unsqueeze(1)
        cx, tgt_mask, _, _, _ = self.bias_decoder(x_self_attn, tgt_mask, bias_embed, memory_mask=contextual_mask)

        if self.bias_output is not None:
            x = torch.cat([x_src_attn, cx], dim=2)
            x = self.bias_output(x.transpose(1, 2)).transpose(1, 2)  # 2D -> D
            x = x_self_attn + self.dropout(x)

        if self.model.decoders2 is not None:
            x, tgt_mask, memory, memory_mask, _ = self.model.decoders2(
                x, tgt_mask, memory, memory_mask
            )
        x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
            x, tgt_mask, memory, memory_mask
        )
        x = self.after_norm(x)
        x = self.output_layer(x)

        return x, ys_in_lens

    def get_dummy_inputs(self, enc_size):
        """Build ``(tgt, memory, pre_acoustic_embeds, cache)`` sample tensors.

        ``cache`` holds one zero tensor per layer counted by ``_cache_num``,
        sized from the first decoder layer's ``size`` and
        ``self_attn.kernel_size``.
        """
        tgt = torch.LongTensor([0]).unsqueeze(0)
        memory = torch.randn(1, 100, enc_size)
        pre_acoustic_embeds = torch.randn(1, 1, enc_size)
        cache_num = self._cache_num()
        cache = [
            torch.zeros((1, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size))
            for _ in range(cache_num)
        ]
        return (tgt, memory, pre_acoustic_embeds, cache)

    def is_optimizable(self):
        """Always returns True."""
        return True

    def get_input_names(self):
        """Return the fixed input names followed by one ``cache_<i>`` per layer."""
        cache_num = self._cache_num()
        return ['tgt', 'memory', 'pre_acoustic_embeds'] \
            + ['cache_%d' % i for i in range(cache_num)]

    def get_output_names(self):
        """Return ``'y'`` followed by one ``out_cache_<i>`` per layer."""
        cache_num = self._cache_num()
        return ['y'] \
            + ['out_cache_%d' % i for i in range(cache_num)]

    def get_dynamic_axes(self):
        """Return the dynamic-axis mapping for the named inputs and caches."""
        ret = {
            'tgt': {
                0: 'tgt_batch',
                1: 'tgt_length'
            },
            'memory': {
                0: 'memory_batch',
                1: 'memory_length'
            },
            'pre_acoustic_embeds': {
                0: 'acoustic_embeds_batch',
                1: 'acoustic_embeds_length',
            }
        }
        cache_num = self._cache_num()
        ret.update({
            'cache_%d' % d: {
                0: 'cache_%d_batch' % d,
                2: 'cache_%d_length' % d
            }
            for d in range(cache_num)
        })
        return ret

    def get_model_config(self, path):
        """Return a config dict pointing at ``<path>/<model_name>.onnx``."""
        return {
            "dec_type": "XformerDecoder",
            "model_path": os.path.join(path, f'{self.model_name}.onnx'),
            "n_layers": self._cache_num(),
            "odim": self.model.decoders[0].size
        }
| New file |
| | |
| | | from audioop import bias |
| | | import logging |
| | | import torch |
| | | import torch.nn as nn |
| | | import numpy as np |
| | | |
| | | from funasr.export.utils.torch_function import MakePadMask |
| | | from funasr.export.utils.torch_function import sequence_mask |
| | | from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt |
| | | from funasr.models.encoder.conformer_encoder import ConformerEncoder |
| | | from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export |
| | | from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export |
| | | from funasr.models.predictor.cif import CifPredictorV2 |
| | | from funasr.export.models.predictor.cif import CifPredictorV2 as CifPredictorV2_export |
| | | from funasr.models.decoder.sanm_decoder import ParaformerSANMDecoder |
| | | from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN |
| | | from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export |
| | | from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export |
| | | from funasr.export.models.decoder.contextual_decoder import ContextualSANMDecoder as ContextualSANMDecoder_export |
| | | from funasr.models.decoder.contextual_decoder import ContextualParaformerDecoder |
| | | |
| | | |
class ContextualParaformer_backbone(nn.Module):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317

    Export wrapper that chains encoder -> predictor -> decoder and takes the
    hotword embeddings (``bias_embed``) as an explicit third input.
    """

    def __init__(
            self,
            model,
            max_seq_len=512,
            feats_dim=560,
            model_name='model',
            **kwargs,
    ):
        """
        Args:
            model: source model exposing ``encoder``, ``predictor`` and ``decoder``.
            max_seq_len: forwarded to the padding-mask helper.
            feats_dim: feature dimension used by ``get_dummy_inputs``.
            model_name: base name used for the exported file.
            **kwargs: only ``onnx`` is read (defaults to False).

        Raises:
            NotImplementedError: if the encoder, predictor or decoder type has
                no export counterpart. Previously such a model was accepted
                here and only failed later, inside ``forward``, with an
                ``AttributeError`` on the missing sub-module.
        """
        super().__init__()
        onnx = False
        if "onnx" in kwargs:
            onnx = kwargs["onnx"]

        if isinstance(model.encoder, SANMEncoder):
            self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
        elif isinstance(model.encoder, ConformerEncoder):
            self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
        else:
            raise NotImplementedError(
                "unsupported encoder type for export: {}".format(type(model.encoder).__name__)
            )

        if isinstance(model.predictor, CifPredictorV2):
            self.predictor = CifPredictorV2_export(model.predictor)
        else:
            raise NotImplementedError(
                "unsupported predictor type for export: {}".format(type(model.predictor).__name__)
            )

        # decoder
        if isinstance(model.decoder, ContextualParaformerDecoder):
            self.decoder = ContextualSANMDecoder_export(model.decoder, onnx=onnx)
        elif isinstance(model.decoder, ParaformerSANMDecoder):
            self.decoder = ParaformerSANMDecoder_export(model.decoder, onnx=onnx)
        elif isinstance(model.decoder, ParaformerDecoderSAN):
            self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
        else:
            raise NotImplementedError(
                "unsupported decoder type for export: {}".format(type(model.decoder).__name__)
            )

        self.feats_dim = feats_dim
        self.model_name = model_name

        if onnx:
            self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
        else:
            self.make_pad_mask = sequence_mask(max_seq_len, flip=False)

    def forward(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            bias_embed: torch.Tensor,
    ):
        """Return ``(log_softmax(decoder_out), pre_token_length)``.

        ``pre_token_length`` is the predictor's token length floored and
        cast to int32.
        """
        # a. To device
        batch = {"speech": speech, "speech_lengths": speech_lengths}
        # batch = to_device(batch, device=self.device)

        enc, enc_len = self.encoder(**batch)
        mask = self.make_pad_mask(enc_len)[:, None, :]
        pre_acoustic_embeds, pre_token_length, alphas, pre_peak_index = self.predictor(enc, mask)
        pre_token_length = pre_token_length.floor().type(torch.int32)

        # bias_embed = bias_embed. squeeze(0).repeat([enc.shape[0], 1, 1])

        decoder_out, _ = self.decoder(enc, enc_len, pre_acoustic_embeds, pre_token_length, bias_embed)
        decoder_out = torch.log_softmax(decoder_out, dim=-1)
        # sample_ids = decoder_out.argmax(dim=-1)
        return decoder_out, pre_token_length

    def get_dummy_inputs(self):
        """Return random ``(speech, speech_lengths, bias_embed)`` for a batch of 2."""
        speech = torch.randn(2, 30, self.feats_dim)
        speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
        bias_embed = torch.randn(2, 1, 512)
        return (speech, speech_lengths, bias_embed)

    def get_dummy_inputs_txt(self, txt_file: str = "/mnt/workspace/data_fbank/0207/12345.wav.fea.txt"):
        """Load features from a text file and return ``(speech, speech_lengths)``.

        NOTE(review): unlike ``get_dummy_inputs`` this returns no
        ``bias_embed``, so the result cannot be fed to ``forward`` as-is, and
        the default path is machine-specific -- confirm whether this helper
        is still needed.
        """
        fbank = np.loadtxt(txt_file)
        fbank_lengths = np.array([fbank.shape[0], ], dtype=np.int32)
        speech = torch.from_numpy(fbank[None, :, :].astype(np.float32))
        speech_lengths = torch.from_numpy(fbank_lengths.astype(np.int32))
        return (speech, speech_lengths)

    def get_input_names(self):
        """Input names, in ``forward`` argument order."""
        return ['speech', 'speech_lengths', 'bias_embed']

    def get_output_names(self):
        """Output names, in ``forward`` return order."""
        return ['logits', 'token_num']

    def get_dynamic_axes(self):
        """Dynamic-axis mapping for the named inputs and outputs."""
        return {
            'speech': {
                0: 'batch_size',
                1: 'feats_length'
            },
            'speech_lengths': {
                0: 'batch_size',
            },
            'bias_embed': {
                0: 'batch_size',
                1: 'num_hotwords'
            },
            'logits': {
                0: 'batch_size',
                1: 'logits_length'
            },
        }
| | | |
| | | |
class ContextualParaformer_embedder(nn.Module):
    """Export wrapper producing hotword embeddings.

    Wraps the source model's ``bias_embed`` and ``bias_encoder`` modules.
    The exported artefact is named ``<model_name>_eb``.
    """

    def __init__(self,
                 model,
                 max_seq_len=512,
                 feats_dim=560,
                 model_name='model',
                 **kwargs,):
        super().__init__()
        self.embedding = model.bias_embed
        # The shared encoder is switched to sequence-first on the source
        # model itself, before it is attached here.
        bias_encoder = model.bias_encoder
        bias_encoder.batch_first = False
        self.bias_encoder = bias_encoder
        self.feats_dim = feats_dim
        self.model_name = "{}_eb".format(model_name)

    def forward(self, hotword):
        """Embed ``hotword`` ids, swap the first two axes, and return the
        first output of ``bias_encoder``."""
        embedded = self.embedding(hotword)
        seq_first = embedded.transpose(0, 1)  # batch second
        hw_embed, (_h, _c) = self.bias_encoder(seq_first)
        return hw_embed

    def get_dummy_inputs(self):
        """Return a 6 x 10 int32 tensor of sample hotword ids.

        The value is returned as a bare tensor, not wrapped in a tuple.
        """
        sample_rows = [
            [10, 11, 12, 13, 14, 10, 11, 12, 13, 14],
            [100, 101, 0, 0, 0, 0, 0, 0, 0, 0],
            [1, 0, 0, 0, 0, 0, 0, 0, 0, 0],
        ]
        # The same three rows appear twice.
        hotword = torch.tensor(sample_rows + sample_rows, dtype=torch.int32)
        return hotword

    def get_input_names(self):
        """Name of the single input."""
        return ['hotword']

    def get_output_names(self):
        """Name of the single output."""
        return ['hw_embed']

    def get_dynamic_axes(self):
        """Axis 0 of both the input and the output is ``num_hotwords``."""
        axes = {}
        axes['hotword'] = {0: 'num_hotwords'}
        axes['hw_embed'] = {0: 'num_hotwords'}
        return axes
| | |
| | | --vad-dir damo/speech_fsmn_vad_zh-cn-16k-common-onnx \ |
| | | --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \ |
| | | --punc-dir damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx > log.out 2>&1 & |
| | | |
| | | # To disable SSL, add: --certfile 0 |
| | | # To deploy the timestamp or hotword model, set --model-dir to the corresponding model: |
| | | # damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx (timestamp) |
| | | # damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404-onnx (hotword) |
| | | ``` |
| | | |
| | | More details about the script run_server.sh: |
| | |
| | | --port: Port number that the server listens on. Default is 10095. |
| | | --decoder-thread-num: Number of inference threads that the server starts. Default is 8. |
| | | --io-thread-num: Number of IO threads that the server starts. Default is 1. |
| | | --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. |
| | | --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. |
| | | --certfile <string>: SSL certificate file. Default is ../../../ssl_key/server.crt. To disable SSL, set it to "". |
| | | --keyfile <string>: SSL key file. Default is ../../../ssl_key/server.key. To disable SSL, set it to "". |
| | | ``` |
| | | |
| | | The FunASR-wss-server also supports loading models from a local path (see Preparing Model Resources for detailed instructions on preparing local model resources). Here is an example: |
| | |
| | | --output_dir: the path to the recognition result output. |
| | | --ssl: whether to use SSL encryption. The default is to use SSL. |
| | | --mode: offline mode. |
| | | --hotword: if the AM is a hotword model, set the hotwords either as a *.txt file (one hotword per line) or as a space-separated string (e.g. 阿里巴巴 达摩院). |
| | | ``` |
| | | |
| | | ### c++-client |
| | |
| | | --output_dir: the path to the recognition result output. |
| | | --ssl: whether to use SSL encryption. The default is to use SSL. |
| | | --mode: offline mode. |
| | | --hotword: if the AM is a hotword model, set the hotwords either as a *.txt file (one hotword per line) or as a space-separated string (e.g. 阿里巴巴 达摩院). |
| | | ``` |
| | | |
| | | ### Custom client |
| | |
| | | |
| | | ```text |
| | | # First communication |
| | | {"mode": "offline", "wav_name": wav_name, "is_speaking": True} |
| | | {"mode": "offline", "wav_name": wav_name, "is_speaking": True, "hotwords": "hotword1|hotword2"} |
| | | # Send wav data |
| | | Bytes data |
| | | # Send end flag |
| | |
| | | --punc-dir damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx > log.out 2>&1 & |
| | | |
| | | # 如果您想关闭ssl,增加参数:--certfile 0 |
| | | # 如果您想使用时间戳或者热词模型进行部署,请设置--model-dir为对应模型: |
| | | # damo/speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404-onnx(时间戳) |
| | | # 或者 damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404-onnx(热词) |
| | | |
| | | ``` |
| | | 服务端详细参数介绍可参考[服务端参数介绍](#服务端参数介绍) |
| | | ### 客户端测试与使用 |
| | |
| | | --port 10095 部署端口号 |
| | | --mode offline表示离线文件转写 |
| | | --audio_in 需要进行转写的音频文件,支持文件路径,文件列表wav.scp |
| | | --output_dir 识别结果保存路径 |
| | | --thread_num 设置并发发送线程数,默认为1 |
| | | --ssl 设置是否开启ssl证书校验,默认1开启,设置为0关闭 |
| | | --hotword 如果模型为热词模型,可以设置热词: *.txt(每行一个热词) 或者空格分隔的热词字符串 (例如: 阿里巴巴 达摩院) |
| | | ``` |
| | | |
| | | ### cpp-client |
| | |
| | | --server-ip 为FunASR runtime-SDK服务部署机器ip,默认为本机ip(127.0.0.1),如果client与服务不在同一台服务器,需要改为部署机器ip |
| | | --port 10095 部署端口号 |
| | | --wav-path 需要进行转写的音频文件,支持文件路径 |
| | | --hotword 如果模型为热词模型,可以设置热词: *.txt(每行一个热词) 或者空格分隔的热词字符串 (例如: 阿里巴巴 达摩院) |
| | | ``` |
| | | |
| | | ### Html网页版 |
| | |
| | | --port 服务端监听的端口号,默认为 10095 |
| | | --decoder-thread-num 服务端启动的推理线程数,默认为 8 |
| | | --io-thread-num 服务端启动的IO线程数,默认为 1 |
| | | --certfile ssl的证书文件,默认为:../../../ssl_key/server.crt |
| | | --keyfile ssl的密钥文件,默认为:../../../ssl_key/server.key |
| | | --certfile ssl的证书文件,默认为:../../../ssl_key/server.crt,如果需要关闭ssl,参数设置为"" |
| | | --keyfile ssl的密钥文件,默认为:../../../ssl_key/server.key,如果需要关闭ssl,参数设置为"" |
| | | ``` |
| | | |
| | | funasr-wss-server同时也支持从本地路径加载模型(本地模型资源准备详见[模型资源准备](#模型资源准备))示例如下: |
| | |
| | | --port 服务端监听的端口号,默认为 10095 |
| | | --decoder-thread-num 服务端启动的推理线程数,默认为 8 |
| | | --io-thread-num 服务端启动的IO线程数,默认为 1 |
| | | --certfile ssl的证书文件,默认为:../../../ssl_key/server.crt |
| | | --keyfile ssl的密钥文件,默认为:../../../ssl_key/server.key |
| | | --certfile ssl的证书文件,默认为:../../../ssl_key/server.crt,如果需要关闭ssl,参数设置为"" |
| | | --keyfile ssl的密钥文件,默认为:../../../ssl_key/server.key,如果需要关闭ssl,参数设置为"" |
| | | ``` |
| | | |
| | | ## 模型资源准备 |
| | |
| | | --model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \ |
| | | --online-model-dir damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx \ |
| | | --punc-dir damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727-onnx > log.out 2>&1 & |
| | | |
| | | # To disable SSL, add: --certfile 0 |
| | | ``` |
| | | For a more detailed description of server parameters, please refer to [Server Introduction]() |
| | | ### Client Testing and Usage |
| | |
| | | ```shell |
| | | sudo bash funasr-runtime-deploy-offline-cpu-zh.sh install --workspace /root/funasr-runtime-resources |
| | | ``` |
| | | Note: If you need to deploy the timestamp model or hotword model, select the corresponding model in step 2 of the installation and deployment process, where 1 is the paraformer-large model, 2 is the paraformer-large timestamp model, and 3 is the paraformer-large hotword model. |
| | | |
| | | ### Client Testing and Usage |
| | | |
| | |
| | | --audio_in is the audio file that needs to be transcribed, supporting file paths and file list wav.scp |
| | | --thread_num sets the number of concurrent sending threads, default is 1 |
| | | --ssl sets whether to enable SSL certificate verification, default is 1 to enable, and 0 to disable |
| | | --hotword: if the AM is a hotword model, set the hotwords either as a *.txt file (one hotword per line) or as a space-separated string (e.g. 阿里巴巴 达摩院). |
| | | ``` |
| | | |
| | | ### cpp-client |
| | |
| | | --wav-path specifies the audio file to be transcribed, and supports file paths. |
| | | --thread_num sets the number of concurrent send threads, with a default value of 1. |
| | | --ssl sets whether to enable SSL certificate verification, with a default value of 1 for enabling and 0 for disabling. |
| | | --hotword: if the AM is a hotword model, set the hotwords either as a *.txt file (one hotword per line) or as a space-separated string (e.g. 阿里巴巴 达摩院). |
| | | ``` |
| | | |
| | | ### html-client |
| | |
| | | ```shell |
| | | sudo bash funasr-runtime-deploy-offline-cpu-zh.sh install --workspace ./funasr-runtime-resources |
| | | ``` |
| | | 注:如果需要部署时间戳模型或者热词模型,在安装部署步骤2时选择对应模型,其中1为paraformer-large模型,2为paraformer-large 时间戳模型,3为paraformer-large 热词模型 |
| | | |
| | | ### 客户端测试与使用 |
| | | |
| | |
| | | --audio_in 需要进行转写的音频文件,支持文件路径,文件列表wav.scp |
| | | --thread_num 设置并发发送线程数,默认为1 |
| | | --ssl 设置是否开启ssl证书校验,默认1开启,设置为0关闭 |
| | | --hotword 如果模型为热词模型,可以设置热词: *.txt(每行一个热词) 或者空格分隔的热词字符串 (例如: 阿里巴巴 达摩院) |
| | | ``` |
| | | |
| | | ### cpp-client |
| | |
| | | --wav-path 需要进行转写的音频文件,支持文件路径 |
| | | --thread_num 设置并发发送线程数,默认为1 |
| | | --ssl 设置是否开启ssl证书校验,默认1开启,设置为0关闭 |
| | | --hotword 如果模型为热词模型,可以设置热词: *.txt(每行一个热词) 或者空格分隔的热词字符串 (例如: 阿里巴巴 达摩院) |
| | | ``` |
| | | |
| | | ### html-client |
| | |
| | | DOCKER: |
| | | funasr-runtime-sdk-cpu-0.2.0 |
| | | funasr-runtime-sdk-cpu-0.1.0 |
| | | DEFAULT_ASR_MODEL: |
| | | damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx |
| | |
| | | #### Initial Communication |
| | | The message (which needs to be serialized in JSON) is: |
| | | ```text |
| | | {"mode": "offline", "wav_name": "wav_name", "is_speaking": True,"wav_format":"pcm"} |
| | | {"mode": "offline", "wav_name": "wav_name", "wav_format":"pcm", "is_speaking": True, "hotwords":"阿里巴巴 达摩院 阿里云"} |
| | | ``` |
| | | Parameter explanation: |
| | | ```text |
| | |
| | | `wav_format`: the audio and video file extension, such as pcm, mp3, mp4, etc. |
| | | `is_speaking`: False indicates the end of a sentence, such as a VAD segmentation point or the end of a WAV file |
| | | `audio_fs`: when the input audio is in PCM format, the audio sampling rate parameter needs to be added |
| | | `hotwords`: if the AM is a hotword model, the hotwords must be sent to the server as a single string, with a space (" ") separating hotwords. For example: "阿里巴巴 达摩院 阿里云" |
| | | ``` |
| | | |
| | | #### Sending Audio Data |
| | |
| | | #### Sending Recognition Results |
| | | The message (serialized in JSON) is: |
| | | ```text |
| | | {"mode": "offline", "wav_name": "wav_name", "text": "asr outputs", "is_final": True} |
| | | {"mode": "offline", "wav_name": "wav_name", "text": "asr outputs", "is_final": True, "timestamp":"[[100,200], [200,500]]"} |
| | | ``` |
| | | Parameter explanation: |
| | | ```text |
| | |
| | | `wav_name`: the name of the audio file to be transcribed |
| | | `text`: the text output of speech recognition |
| | | `is_final`: indicating the end of recognition |
| | | `timestamp`:If AM is a timestamp model, it will return this field, indicating the timestamp, in the format of "[[100,200], [200,500]]" |
| | | ``` |
| | | |
| | | ## Real-time Speech Recognition |
| | |
| | | #### Initial Communication |
| | | The message (which needs to be serialized in JSON) is: |
| | | ```text |
| | | {"mode": "2pass", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5] |
| | | {"mode": "2pass", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5]} |
| | | ``` |
| | | Parameter explanation: |
| | | ```text |
| | |
| | | #### 首次通信 |
| | | message为(需要用json序列化): |
| | | ```text |
| | | {"mode": "offline", "wav_name": "wav_name", "is_speaking": True,"wav_format":"pcm"} |
| | | {"mode": "offline", "wav_name": "wav_name", "wav_format":"pcm", "is_speaking": True, "hotwords":"阿里巴巴 达摩院 阿里云"} |
| | | ``` |
| | | 参数介绍: |
| | | ```text |
| | |
| | | `wav_format`:表示音视频文件后缀名,可选pcm、mp3、mp4等 |
| | | `is_speaking`:False 表示断句尾点,例如,vad切割点,或者一条wav结束 |
| | | `audio_fs`:当输入音频为pcm数据时,需要加上音频采样率参数 |
| | | `hotwords`:如果AM为热词模型,需要向服务端发送热词数据,格式为字符串,热词之间用" "分隔,例如 "阿里巴巴 达摩院 阿里云" |
| | | ``` |
| | | |
| | | #### 发送音频数据 |
| | |
| | | #### 发送识别结果 |
| | | message为(采用json序列化) |
| | | ```text |
| | | {"mode": "offline", "wav_name": "wav_name", "text": "asr outputs", "is_final": True} |
| | | {"mode": "offline", "wav_name": "wav_name", "text": "asr outputs", "is_final": True,"timestamp":"[[100,200], [200,500]]"} |
| | | ``` |
| | | 参数介绍: |
| | | ```text |
| | |
| | | `wav_name`:表示需要推理音频文件名 |
| | | `text`:表示语音识别输出文本 |
| | | `is_final`:表示识别结束 |
| | | `timestamp`:如果AM为时间戳模型,会返回此字段,表示时间戳,格式为 "[[100,200], [200,500]]"(ms) |
| | | ``` |
| | | |
| | | ## 实时语音识别 |
| | |
| | | #### 首次通信 |
| | | message为(需要用json序列化): |
| | | ```text |
| | | {"mode": "2pass", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5] |
| | | {"mode": "2pass", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5]} |
| | | ``` |
| | | 参数介绍: |
| | | ```text |
| | |
| | | |
| | | TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains model.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string"); |
| | | TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t"); |
| | | TCLAP::ValueArg<std::int32_t> thread_num_("", THREAD_NUM, "multi-thread num for rtf", false, 1, "int32_t"); |
| | |
| | | |
| | | TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains model.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string"); |
| | | TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t"); |
| | | |
| | |
| | | |
| | | TCLAP::CmdLine cmd("funasr-onnx-offline-punc", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> txt_path("", TXT_PATH, "txt file path, one sentence per line", true, "", "string"); |
| | | |
| | | cmd.add(model_dir); |
| | |
| | | std::mutex mtx; |
| | | |
| | | void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wav_ids, |
| | | float* total_length, long* total_time, int core_id) { |
| | | float* total_length, long* total_time, int core_id, string hotwords) { |
| | | |
| | | struct timeval start, end; |
| | | long seconds = 0; |
| | | float n_total_length = 0.0f; |
| | | long n_total_time = 0; |
| | | std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(asr_handle, hotwords); |
| | | |
| | | // warm up |
| | | for (size_t i = 0; i < 1; i++) |
| | | { |
| | | FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, 16000); |
| | | FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000); |
| | | if(result){ |
| | | FunASRFreeResult(result); |
| | | } |
| | |
| | | } |
| | | |
| | | gettimeofday(&start, NULL); |
| | | FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, 16000); |
| | | FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[i].c_str(), RASR_NONE, NULL, hotwords_embedding, 16000); |
| | | |
| | | gettimeofday(&end, NULL); |
| | | seconds = (end.tv_sec - start.tv_sec); |
| | |
| | | |
| | | TCLAP::CmdLine cmd("funasr-onnx-offline-rtf", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); |
| | | |
| | | TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string"); |
| | | TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t"); |
| | | TCLAP::ValueArg<std::string> hotword("", HOTWORD, "*.txt(one hotword perline) or hotwords seperate by | (could be: 阿里巴巴 达摩院)", false, "", "string"); |
| | | |
| | | cmd.add(model_dir); |
| | | cmd.add(quantize); |
| | |
| | | cmd.add(punc_quant); |
| | | cmd.add(wav_path); |
| | | cmd.add(thread_num); |
| | | cmd.add(hotword); |
| | | cmd.parse(argc, argv); |
| | | |
| | | std::map<std::string, std::string> model_path; |
| | |
| | | long seconds = (end.tv_sec - start.tv_sec); |
| | | long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); |
| | | LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s"; |
| | | |
| | | // read hotwords |
| | | std::string hotword_ = hotword.getValue(); |
| | | std::string hotwords_; |
| | | |
| | | if(is_target_file(hotword_, "txt")){ |
| | | ifstream in(hotword_); |
| | | if (!in.is_open()) { |
| | | LOG(ERROR) << "Failed to open file: " << model_path.at(HOTWORD) ; |
| | | return 0; |
| | | } |
| | | string line; |
| | | while(getline(in, line)) |
| | | { |
| | | hotwords_ +=line+HOTWORD_SEP; |
| | | } |
| | | in.close(); |
| | | }else{ |
| | | hotwords_ = hotword_; |
| | | } |
| | | |
| | | // read wav_path |
| | | vector<string> wav_list; |
| | |
| | | int rtf_threds = thread_num.getValue(); |
| | | for (int i = 0; i < rtf_threds; i++) |
| | | { |
| | | threads.emplace_back(thread(runReg, asr_handle, wav_list, wav_ids, &total_length, &total_time, i)); |
| | | threads.emplace_back(thread(runReg, asr_handle, wav_list, wav_ids, &total_length, &total_time, i, hotwords_)); |
| | | } |
| | | |
| | | for (auto& thread : threads) |
| | |
| | | |
| | | TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); |
| | | |
| | | TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string"); |
| | | |
| | |
| | | |
| | | TCLAP::CmdLine cmd("funasr-onnx-offline", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); |
| | | |
| | | TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> hotword("", HOTWORD, "*.txt(one hotword perline) or hotwords seperate by space (could be: 阿里巴巴 达摩院)", false, "", "string"); |
| | | |
| | | cmd.add(model_dir); |
| | | cmd.add(quantize); |
| | |
| | | cmd.add(punc_dir); |
| | | cmd.add(punc_quant); |
| | | cmd.add(wav_path); |
| | | cmd.add(hotword); |
| | | cmd.parse(argc, argv); |
| | | |
| | | std::map<std::string, std::string> model_path; |
| | |
| | | long seconds = (end.tv_sec - start.tv_sec); |
| | | long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); |
| | | LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s"; |
| | | |
| | | // read hotwords |
| | | std::string hotword_ = hotword.getValue(); |
| | | std::string hotwords_; |
| | | |
| | | if(is_target_file(hotword_, "txt")){ |
| | | ifstream in(hotword_); |
| | | if (!in.is_open()) { |
| | | LOG(ERROR) << "Failed to open file: " << model_path.at(HOTWORD) ; |
| | | return 0; |
| | | } |
| | | string line; |
| | | while(getline(in, line)) |
| | | { |
| | | hotwords_ +=line+HOTWORD_SEP; |
| | | } |
| | | in.close(); |
| | | }else{ |
| | | hotwords_ = hotword_; |
| | | } |
| | | |
| | | // read wav_path |
| | | vector<string> wav_list; |
| | |
| | | |
| | | float snippet_time = 0.0f; |
| | | long taking_micros = 0; |
| | | std::vector<std::vector<float>> hotwords_embedding = CompileHotwordEmbedding(asr_hanlde, hotwords_); |
| | | for (int i = 0; i < wav_list.size(); i++) { |
| | | auto& wav_file = wav_list[i]; |
| | | auto& wav_id = wav_ids[i]; |
| | | gettimeofday(&start, NULL); |
| | | FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, 16000); |
| | | FUNASR_RESULT result=FunOfflineInfer(asr_hanlde, wav_file.c_str(), RASR_NONE, NULL, hotwords_embedding, 16000); |
| | | gettimeofday(&end, NULL); |
| | | seconds = (end.tv_sec - start.tv_sec); |
| | | taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec); |
| | |
| | | FLAGS_logtostderr = true; |
| | | |
| | | TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr online model path, which contains model.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); |
| | | |
| | | TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string"); |
| | | |
| | |
| | | |
| | | TCLAP::CmdLine cmd("funasr-onnx-online-punc", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the punc model path, which contains model.onnx, punc.yaml", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> txt_path("", TXT_PATH, "txt file path, one sentence per line", true, "", "string"); |
| | | |
| | | cmd.add(model_dir); |
| | |
| | | |
| | | TCLAP::CmdLine cmd("funasr-onnx-online-rtf", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "true (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "true (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string"); |
| | | TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "true (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string"); |
| | | |
| | | TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string"); |
| | | TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t"); |
| | |
| | | |
| | | TCLAP::CmdLine cmd("funasr-onnx-offline-vad", ' ', "1.0"); |
| | | TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", true, "", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string"); |
| | | TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "true", "string"); |
| | | |
| | | TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string"); |
| | | |
| | |
| | | #define TXT_PATH "txt-path" |
| | | #define THREAD_NUM "thread-num" |
| | | #define PORT_ID "port-id" |
| | | #define HOTWORD_SEP " " |
| | | |
| | | // #define VAD_MODEL_PATH "vad-model" |
| | | // #define VAD_CMVN_PATH "vad-cmvn" |
| | |
| | | // #define PUNC_CONFIG_PATH "punc-config" |
| | | |
| | | #define MODEL_NAME "model.onnx" |
| | | // hotword embedding compile model |
| | | #define MODEL_EB_NAME "model_eb.onnx" |
| | | #define QUANT_MODEL_NAME "model_quant.onnx" |
| | | #define VAD_CMVN_NAME "vad.mvn" |
| | | #define VAD_CONFIG_NAME "vad.yaml" |
| | | #define AM_CMVN_NAME "am.mvn" |
| | | #define AM_CONFIG_NAME "config.yaml" |
| | | #define PUNC_CONFIG_NAME "punc.yaml" |
| | | #define MODEL_SEG_DICT "seg_dict" |
| | | #define HOTWORD "hotword" |
| | | |
| | | #define ENCODER_NAME "model.onnx" |
| | | #define QUANT_ENCODER_NAME "model_quant.onnx" |
| | |
| | | //OfflineStream |
| | | _FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num); |
| | | // buffer |
| | | _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000, std::string wav_format="pcm"); |
| | | _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, int sampling_rate=16000, std::string wav_format="pcm"); |
| | | // file, support wav & pcm |
| | | _FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000); |
| | | _FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, int sampling_rate=16000); |
| | | _FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords); |
| | | _FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle); |
| | | |
| | | //2passStream |
| | |
| | | virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){}; |
| | | virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){}; |
| | | virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){}; |
| | | virtual std::string Forward(float *din, int len, bool input_finished){return "";}; |
| | | virtual std::string Forward(float *din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}){return "";}; |
| | | virtual std::string Rescoring() = 0; |
| | | virtual void InitHwCompiler(const std::string &hw_model, int thread_num){}; |
| | | virtual void InitSegDict(const std::string &seg_dict_model){}; |
| | | virtual std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords){}; |
| | | }; |
| | | |
| | | Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE); |
| | |
| | | file(GLOB files1 "*.cpp") |
| | | set(files ${files1}) |
| | | |
| | | message("files: "${files}) |
| | | |
| | | add_library(funasr SHARED ${files}) |
| | | |
| | | if(WIN32) |
| | |
| | | include_directories(${FFMPEG_DIR}/include) |
| | | endif() |
| | | |
| | | #message("CXX_FLAGS "${CMAKE_CXX_FLAGS}) |
| | | include_directories(${CMAKE_SOURCE_DIR}/include) |
| | | target_link_libraries(funasr PUBLIC onnxruntime ${EXTRA_LIBS}) |
| New file |
| | |
| | | /** |
| | | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | * MIT License (https://opensource.org/licenses/MIT) |
| | | */ |
| | | #include "encode_converter.h" |
| | | #include <assert.h> |
| | | |
| | | |
| | | namespace funasr { |
| | | using namespace std; |
| | | |
| | | U16CHAR_T UTF16[8]; |
| | | U8CHAR_T UTF8[8]; |
| | | |
| | | size_t MyUtf8ToUtf16(const U8CHAR_T* pu8, size_t ilen, U16CHAR_T* pu16); |
| | | size_t MyUtf16ToUtf8(const U16CHAR_T* pu16, U8CHAR_T* pu8); |
| | | |
| | | |
| | | void EncodeConverter::SwapEndian(U16CHAR_T* pbuf, size_t len) |
| | | { |
| | | for (size_t i = 0; i < len; i++) { |
| | | pbuf[i] = ((pbuf[i] >> 8) | (pbuf[i] << 8)); |
| | | } |
| | | } |
| | | |
| | | |
| | | size_t MyUtf16ToUtf8(const U16CHAR_T* pu16, U8CHAR_T* pu8) |
| | | { |
| | | size_t n = 0; |
| | | if (pu16[0] <= 0x007F) |
| | | { |
| | | pu8[0] = (pu16[0] & 0x7F); |
| | | n = 1; |
| | | } |
| | | else if (pu16[0] >= 0x0080 && pu16[0] <= 0x07FF) |
| | | { |
| | | pu8[1] = (0x80 | (pu16[0] & 0x003F)); |
| | | pu8[0] = (0xC0 | ((pu16[0] >> 6) & 0x001F)); |
| | | n = 2; |
| | | } |
| | | else if (pu16[0] >= 0x0800) |
| | | { |
| | | pu8[2] = (0x80 | (pu16[0] & 0x003F)); |
| | | pu8[1] = (0x80 | ((pu16[0] >> 6) & 0x003F)); |
| | | pu8[0] = (0xE0 | ((pu16[0] >> 12) & 0x000F)); |
| | | n = 3; |
| | | } |
| | | |
| | | return n; |
| | | } |
| | | |
| | | #define is2ByteUtf16(u16) ( (u16) >= 0x0080 && (u16) <= 0x07FF ) |
| | | #define is3ByteUtf16(u16) ( (u16) >= 0x0800 ) |
| | | |
| | | size_t EncodeConverter::Utf16ToUtf8(const U16CHAR_T* pu16, U8CHAR_T* pu8) |
| | | { |
| | | size_t n = 0; |
| | | if (pu16[0] <= 0x007F) |
| | | { |
| | | pu8[0] = (pu16[0] & 0x7F); |
| | | n = 1; |
| | | } |
| | | else if (pu16[0] >= 0x0080 && pu16[0] <= 0x07FF) |
| | | { |
| | | pu8[1] = (0x80 | (pu16[0] & 0x003F)); |
| | | pu8[0] = (0xC0 | ((pu16[0] >> 6) & 0x001F)); |
| | | n = 2; |
| | | } |
| | | else if (pu16[0] >= 0x0800) |
| | | { |
| | | pu8[2] = (0x80 | (pu16[0] & 0x003F)); |
| | | pu8[1] = (0x80 | ((pu16[0] >> 6) & 0x003F)); |
| | | pu8[0] = (0xE0 | ((pu16[0] >> 12) & 0x000F)); |
| | | n = 3; |
| | | } |
| | | |
| | | return n; |
| | | } |
| | | |
| | | size_t EncodeConverter::Utf16ToUtf8(const U16CHAR_T* pu16, size_t ilen, |
| | | U8CHAR_T* pu8, size_t olen) |
| | | { |
| | | size_t offset = 0; |
| | | size_t sz = 0; |
| | | /* |
| | | for (size_t i = 0; i < ilen && offset < static_cast<int>(olen) - 3; i++) { |
| | | sz = utf16ToUtf8(pu16 + i, pu8 + offset); |
| | | offset += sz; |
| | | } |
| | | */ |
| | | for (size_t i = 0; i < ilen && static_cast<int>(offset) < static_cast<int>(olen); i++) { |
| | | sz = Utf16ToUtf8(pu16 + i, pu8 + offset); |
| | | if (static_cast<int>(offset + static_cast<int>(sz)) <= static_cast<int>(olen)) |
| | | offset += sz; |
| | | } |
| | | |
| | | // pu8[offset] = '\0'; |
| | | return offset; |
| | | } |
| | | |
| | | u8string EncodeConverter::Utf16ToUtf8(const u16string& u16str) |
| | | { |
| | | size_t buflen = u16str.length()*3 + 1; |
| | | U8CHAR_T* pu8 = new U8CHAR_T[buflen]; |
| | | size_t len = Utf16ToUtf8(u16str.data(), u16str.length(), |
| | | pu8, buflen); |
| | | u8string u8str(pu8, len); |
| | | delete [] pu8; |
| | | |
| | | return u8str; |
| | | } |
| | | |
| | | size_t EncodeConverter::Utf8ToUtf16(const U8CHAR_T* pu8, U16CHAR_T* pu16) |
| | | { |
| | | size_t n = 0; |
| | | if ((pu8[0] & 0xF0) == 0xE0) |
| | | { |
| | | if ((pu8[1] & 0xC0) == 0x80 && |
| | | (pu8[2] & 0xC0) == 0x80) |
| | | { |
| | | pu16[0] = (((pu8[0] & 0x0F) << 4) | ((pu8[1] & 0x3C) >> 2)); |
| | | pu16[0] <<= 8; |
| | | pu16[0] |= (((pu8[1] & 0x03) << 6) | (pu8[2] & 0x3F)); |
| | | } |
| | | else |
| | | { |
| | | pu16[0] = defUniChar; |
| | | } |
| | | n = 3; |
| | | } |
| | | else if ((pu8[0] & 0xE0) == 0xC0) |
| | | { |
| | | if ((pu8[1] & 0xC0) == 0x80) |
| | | { |
| | | pu16[0] = ((pu8[0] & 0x1C) >> 2); |
| | | pu16[0] <<= 8; |
| | | pu16[0] |= (((pu8[0] & 0x03) << 6) | (pu8[1] & 0x3F)); |
| | | } |
| | | else |
| | | { |
| | | pu16[0] = defUniChar; |
| | | } |
| | | n = 2; |
| | | } |
| | | else if ((pu8[0] & 0x80) == 0x00) |
| | | { |
| | | pu16[0] = pu8[0]; |
| | | n = 1; |
| | | } |
| | | |
| | | return n; |
| | | } |
| | | |
| | | size_t MyUtf8ToUtf16(const U8CHAR_T* pu8, size_t ilen, U16CHAR_T* pu16) |
| | | { |
| | | size_t n = 0; |
| | | if ((pu8[0] & 0xF0) == 0xE0 && ilen >= 3) |
| | | { |
| | | if ((pu8[1] & 0xC0) == 0x80 && |
| | | (pu8[2] & 0xC0) == 0x80) |
| | | { |
| | | pu16[0] = (((pu8[0] & 0x0F) << 4) | ((pu8[1] & 0x3C) >> 2)); |
| | | pu16[0] <<= 8; |
| | | pu16[0] |= (((pu8[1] & 0x03) << 6) | (pu8[2] & 0x3F)); |
| | | n = 3; |
| | | } |
| | | else |
| | | { |
| | | pu16[0] = 0x0000; |
| | | n = 1; |
| | | } |
| | | } |
| | | else if ((pu8[0] & 0xE0) == 0xC0 && ilen >= 2) |
| | | { |
| | | if ((pu8[1] & 0xC0) == 0x80) |
| | | { |
| | | pu16[0] = ((pu8[0] & 0x1C) >> 2); |
| | | pu16[0] <<= 8; |
| | | pu16[0] |= (((pu8[0] & 0x03) << 6) | (pu8[1] & 0x3F)); |
| | | n = 2; |
| | | } |
| | | else |
| | | { |
| | | pu16[0] = 0x0000; |
| | | n = 1; |
| | | } |
| | | } |
| | | else if ((pu8[0] & 0x80) == 0x00) |
| | | { |
| | | pu16[0] = pu8[0]; |
| | | n = 1; |
| | | } |
| | | else |
| | | { |
| | | pu16[0] = 0x0000; |
| | | n = 1; |
| | | } |
| | | return n; |
| | | } |
| | | |
| | | size_t EncodeConverter::Utf8ToUtf16(const U8CHAR_T* pu8, size_t ilen, U16CHAR_T* pu16) |
| | | { |
| | | size_t n = 0; |
| | | if ((pu8[0] & 0xF0) == 0xE0 && ilen >= 3) |
| | | { |
| | | if ((pu8[1] & 0xC0) == 0x80 && |
| | | (pu8[2] & 0xC0) == 0x80) |
| | | { |
| | | pu16[0] = (((pu8[0] & 0x0F) << 4) | ((pu8[1] & 0x3C) >> 2)); |
| | | pu16[0] <<= 8; |
| | | pu16[0] |= (((pu8[1] & 0x03) << 6) | (pu8[2] & 0x3F)); |
| | | n = 3; |
| | | if( !is3ByteUtf16(pu16[0]) ) |
| | | { |
| | | pu16[0] = 0x0000; |
| | | n = 1; |
| | | } |
| | | } |
| | | else |
| | | { |
| | | pu16[0] = 0x0000; |
| | | n = 1; |
| | | } |
| | | } |
| | | else if ((pu8[0] & 0xE0) == 0xC0 && ilen >= 2) |
| | | { |
| | | if ((pu8[1] & 0xC0) == 0x80) |
| | | { |
| | | pu16[0] = ((pu8[0] & 0x1C) >> 2); |
| | | pu16[0] <<= 8; |
| | | pu16[0] |= (((pu8[0] & 0x03) << 6) | (pu8[1] & 0x3F)); |
| | | n = 2; |
| | | if( !is2ByteUtf16(pu16[0]) ) |
| | | { |
| | | pu16[0] = 0x0000; |
| | | n = 1; |
| | | } |
| | | } |
| | | else |
| | | { |
| | | pu16[0] = 0x0000; |
| | | n = 1; |
| | | } |
| | | } |
| | | else if ((pu8[0] & 0x80) == 0x00) |
| | | { |
| | | pu16[0] = pu8[0]; |
| | | n = 1; |
| | | } |
| | | else |
| | | { |
| | | pu16[0] = 0x0000; |
| | | n = 1; |
| | | } |
| | | |
| | | return n; |
| | | /* |
| | | size_t n = 0; |
| | | if ((pu8[0] & 0xF0) == 0xE0) |
| | | { |
| | | if (ilen >= 3 && (pu8[1] & 0xC0) == 0x80 && |
| | | (pu8[2] & 0xC0) == 0x80) |
| | | { |
| | | pu16[0] = (((pu8[0] & 0x0F) << 4) | ((pu8[1] & 0x3C) >> 2)); |
| | | pu16[0] <<= 8; |
| | | pu16[0] |= (((pu8[1] & 0x03) << 6) | (pu8[2] & 0x3F)); |
| | | } |
| | | else |
| | | { |
| | | pu16[0] = defUniChar; |
| | | } |
| | | n = 3; |
| | | } |
| | | else if ((pu8[0] & 0xE0) == 0xC0) |
| | | { |
| | | if( ilen >= 2 && (pu8[1] & 0xC0) == 0x80) |
| | | { |
| | | pu16[0] = ((pu8[0] & 0x1C) >> 2); |
| | | pu16[0] <<= 8; |
| | | pu16[0] |= (((pu8[0] & 0x03) << 6) | (pu8[1] & 0x3F)); |
| | | } |
| | | else |
| | | { |
| | | pu16[0] = defUniChar; |
| | | } |
| | | n = 2; |
| | | } |
| | | else if ((pu8[0] & 0x80) == 0x00) |
| | | { |
| | | pu16[0] = pu8[0]; |
| | | n = 1; |
| | | } |
| | | else |
| | | { |
| | | pu16[0] = defUniChar; |
| | | n = 1; |
| | | for (size_t i = 1; i < ilen; i++) |
| | | { |
| | | if ((pu8[i] & 0xF0) == 0xE0 || (pu8[i] & 0xE0) == 0xC0 || (pu8[i] & 0x80) == 0x00) |
| | | break; |
| | | n++; |
| | | } |
| | | } |
| | | |
| | | return n; |
| | | */ |
| | | } |
| | | |
| | | size_t EncodeConverter::Utf8ToUtf16(const U8CHAR_T* pu8, size_t ilen, |
| | | U16CHAR_T* pu16, size_t olen) |
| | | { |
| | | int offset = 0; |
| | | size_t sz = 0; |
| | | for (size_t i = 0; i < ilen && offset < static_cast<int>(olen); offset ++) |
| | | { |
| | | sz = Utf8ToUtf16(pu8 + i, ilen - i, pu16 + offset); |
| | | i += sz; |
| | | if (sz == 0) { |
| | | // failed |
| | | // assert(sz != 0); |
| | | break; |
| | | } |
| | | } |
| | | // pu16[offset] = '\0'; |
| | | |
| | | return offset; |
| | | } |
| | | |
| | | u16string EncodeConverter::Utf8ToUtf16(const u8string& u8str) |
| | | { |
| | | U16CHAR_T* p16 = new U16CHAR_T[u8str.length() + 1]; |
| | | size_t len = Utf8ToUtf16(u8str.data(), u8str.length(), |
| | | p16, u8str.length() + 1); |
| | | u16string u16str(p16, len); |
| | | delete[] p16; |
| | | |
| | | return u16str; |
| | | } |
| | | |
| | | // Check whether the buffer consists solely of 1-, 2- or 3-byte UTF-8 sequences. |
| | | // 4-byte sequences and stray continuation bytes are rejected, as before. |
| | | // Continuation bytes are only inspected when they lie inside the buffer, so a |
| | | // truncated trailing sequence is reported as invalid instead of being read |
| | | // out of bounds. |
| | | bool EncodeConverter::IsUTF8(const U8CHAR_T* pu8, size_t ilen) |
| | | { |
| | |     size_t i; |
| | |     size_t n = 0; |
| | |     for (i = 0; i < ilen; i += n) |
| | |     { |
| | |         const size_t remaining = ilen - i; |
| | |         if ((pu8[i] & 0xF0) == 0xE0 && |
| | |             remaining >= 3 && |
| | |             (pu8[i + 1] & 0xC0) == 0x80 && |
| | |             (pu8[i + 2] & 0xC0) == 0x80) |
| | |         { |
| | |             n = 3; |
| | |         } |
| | |         else if ((pu8[i] & 0xE0) == 0xC0 && |
| | |                  remaining >= 2 && |
| | |                  (pu8[i + 1] & 0xC0) == 0x80) |
| | |         { |
| | |             n = 2; |
| | |         } |
| | |         else if ((pu8[i] & 0x80) == 0x00) |
| | |         { |
| | |             n = 1; |
| | |         } |
| | |         else |
| | |         { |
| | |             break; |
| | |         } |
| | |     } |
| | | |
| | |     // Valid only if every byte was consumed by a well-formed sequence. |
| | |     return i == ilen; |
| | | } |
| | | |
| | | // String overload: validates the bytes held by `u8str`. |
| | | bool EncodeConverter::IsUTF8(const u8string& u8str) |
| | | { |
| | |     const U8CHAR_T* bytes = u8str.data(); |
| | |     const size_t byte_count = u8str.length(); |
| | |     return IsUTF8(bytes, byte_count); |
| | | } |
| | | |
| | | // Count the characters (1-, 2- or 3-byte sequences) in a UTF-8 buffer. |
| | | // Returns the character count when the whole buffer is well-formed, and 0 when |
| | | // an invalid or truncated sequence is met. The previous code had the two return |
| | | // values swapped and could read past the end of the buffer on a truncated tail. |
| | | size_t EncodeConverter::GetUTF8Len(const U8CHAR_T* pu8, size_t ilen) |
| | | { |
| | |     size_t i; |
| | |     size_t n = 0; |
| | |     size_t rlen = 0; |
| | |     for (i = 0; i < ilen; i += n, rlen ++) |
| | |     { |
| | |         const size_t remaining = ilen - i; |
| | |         if ((pu8[i] & 0xF0) == 0xE0 && |
| | |             remaining >= 3 && |
| | |             (pu8[i + 1] & 0xC0) == 0x80 && |
| | |             (pu8[i + 2] & 0xC0) == 0x80) |
| | |         { |
| | |             n = 3; |
| | |         } |
| | |         else if ((pu8[i] & 0xE0) == 0xC0 && |
| | |                  remaining >= 2 && |
| | |                  (pu8[i + 1] & 0xC0) == 0x80) |
| | |         { |
| | |             n = 2; |
| | |         } |
| | |         else if ((pu8[i] & 0x80) == 0x00) |
| | |         { |
| | |             n = 1; |
| | |         } |
| | |         else |
| | |         { |
| | |             break; |
| | |         } |
| | |     } |
| | | |
| | |     // i == ilen means every byte belonged to a well-formed sequence. |
| | |     if (i == ilen) |
| | |         return rlen; |
| | |     else |
| | |         return 0; |
| | | } |
| | | |
| | | // String overload: forwards the bytes held by `u8str` to the buffer version. |
| | | size_t EncodeConverter::GetUTF8Len(const u8string& u8str) |
| | | { |
| | |     const U8CHAR_T* bytes = u8str.data(); |
| | |     const size_t byte_count = u8str.length(); |
| | |     return GetUTF8Len(bytes, byte_count); |
| | | } |
| | | |
| | | |
| | | // Number of UTF-8 bytes needed to encode `ilen` UTF-16 units: |
| | | // 1 byte up to 0x007F, 2 bytes up to 0x07FF, 3 bytes for everything above. |
| | | size_t EncodeConverter::Utf16ToUtf8Len(const U16CHAR_T* pu16, size_t ilen) |
| | | { |
| | |     int byte_total = 0; |
| | |     const U16CHAR_T* const stop = pu16 + ilen; |
| | |     for (const U16CHAR_T* cur = pu16; cur != stop; ++cur) { |
| | |         const U16CHAR_T unit = *cur; |
| | |         byte_total += (unit <= 0x007F) ? 1 : ((unit <= 0x07FF) ? 2 : 3); |
| | |     } |
| | | |
| | |     return byte_total; |
| | | } |
| | | |
| | | // Decode the first UTF-8 character at `sc`. |
| | | // `len` receives the value returned by the single-character decoder, and the |
| | | // decoded UTF-16 unit is returned. |
| | | uint16_t EncodeConverter::ToUni(const char* sc, int &len) |
| | | { |
| | |     uint16_t units[2]; |
| | |     const size_t used = Utf8ToUtf16((const U8CHAR_T*)sc, units); |
| | |     len = (int)used; |
| | |     return units[0]; |
| | | } |
| | | |
| | | // True when every decoded UTF-16 unit of the UTF-8 input lies in 0x4e00..0x9fff. |
| | | // A NULL or empty input yields false. |
| | | bool EncodeConverter::IsAllChineseCharactor(const U8CHAR_T* pu8, size_t ilen) { |
| | |     if (pu8 == NULL || ilen <= 0) { |
| | |         return false; |
| | |     } |
| | | |
| | |     // One unit per input byte is always enough room for the decoded text. |
| | |     U16CHAR_T* units = new U16CHAR_T[ilen + 1]; |
| | |     const size_t unit_count = Utf8ToUtf16(pu8, ilen, units, ilen + 1); |
| | |     bool all_in_range = true; |
| | |     for (size_t k = 0; all_in_range && k < unit_count; ++k) { |
| | |         const U16CHAR_T u = units[k]; |
| | |         all_in_range = !(u < 0x4e00 || u > 0x9fff); |
| | |     } |
| | |     delete[] units; |
| | |     return all_in_range; |
| | | } |
| | | |
| | | // True when at least one non-zero byte of the buffer satisfies isalpha(). |
| | | // A NULL or empty buffer yields false. |
| | | bool EncodeConverter::HasAlpha(const U8CHAR_T* pu8, size_t ilen) { |
| | |     if (pu8 == NULL || ilen <= 0) { |
| | |         return false; |
| | |     } |
| | |     const U8CHAR_T* const stop = pu8 + ilen; |
| | |     for (const U8CHAR_T* cur = pu8; cur != stop; ++cur) { |
| | |         const U8CHAR_T c = *cur; |
| | |         if (c > 0 && isalpha(c)) { |
| | |             return true; |
| | |         } |
| | |     } |
| | |     return false; |
| | | } |
| | | |
| | | |
| | | // True when every byte of the buffer is non-zero and satisfies isalpha(). |
| | | // A NULL or empty buffer yields false. |
| | | bool EncodeConverter::IsAllAlpha(const U8CHAR_T* pu8, size_t ilen) { |
| | |     if (pu8 == NULL || ilen <= 0) { |
| | |         return false; |
| | |     } |
| | |     const U8CHAR_T* const stop = pu8 + ilen; |
| | |     for (const U8CHAR_T* cur = pu8; cur != stop; ++cur) { |
| | |         const U8CHAR_T c = *cur; |
| | |         if (c == 0 || !isalpha(c)) { |
| | |             return false; |
| | |         } |
| | |     } |
| | |     return true; |
| | | } |
| | | |
| | | // True when the buffer contains at least one letter and every byte is a |
| | | // non-zero letter or punctuation character. NULL or empty input yields false. |
| | | bool EncodeConverter::IsAllAlphaAndPunct(const U8CHAR_T* pu8, size_t ilen) { |
| | |     if (pu8 == NULL || ilen <= 0) { |
| | |         return false; |
| | |     } |
| | |     if (!HasAlpha(pu8, ilen)) { |
| | |         return false; |
| | |     } |
| | | |
| | |     const U8CHAR_T* const stop = pu8 + ilen; |
| | |     for (const U8CHAR_T* cur = pu8; cur != stop; ++cur) { |
| | |         const U8CHAR_T c = *cur; |
| | |         const bool accepted = c > 0 && (isalpha(c) || (ispunct(c))); |
| | |         if (!accepted) { |
| | |             return false; |
| | |         } |
| | |     } |
| | |     return true; |
| | | } |
| | | |
| | | // True when the buffer contains at least one letter and every byte is a |
| | | // non-zero alphanumeric character or an apostrophe. NULL or empty input yields false. |
| | | bool EncodeConverter::IsAllAlphaAndDigit(const U8CHAR_T* pu8, size_t ilen) { |
| | |     if (pu8 == NULL || ilen <= 0) { |
| | |         return false; |
| | |     } |
| | |     if (!HasAlpha(pu8, ilen)) { |
| | |         return false; |
| | |     } |
| | | |
| | |     const U8CHAR_T* const stop = pu8 + ilen; |
| | |     for (const U8CHAR_T* cur = pu8; cur != stop; ++cur) { |
| | |         const U8CHAR_T c = *cur; |
| | |         const bool accepted = c > 0 && (isalnum(c) || isalpha(c) || c == '\''); |
| | |         if (!accepted) { |
| | |             return false; |
| | |         } |
| | |     } |
| | |     return true; |
| | | } |
| | | // True when every byte is a non-zero alphanumeric character, a blank |
| | | // (per isblank()) or an apostrophe. NULL or empty input yields false. |
| | | bool EncodeConverter::IsAllAlphaAndDigitAndBlank(const U8CHAR_T* pu8, size_t ilen) { |
| | |     if (pu8 == NULL || ilen <= 0) { |
| | |         return false; |
| | |     } |
| | |     const U8CHAR_T* const stop = pu8 + ilen; |
| | |     for (const U8CHAR_T* cur = pu8; cur != stop; ++cur) { |
| | |         const U8CHAR_T c = *cur; |
| | |         const bool accepted = c > 0 && (isalnum(c) || isalpha(c) || isblank(c) || c == '\''); |
| | |         if (!accepted) { |
| | |             return false; |
| | |         } |
| | |     } |
| | |     return true; |
| | | } |
| | | // True when `str` is non-empty and passes any of IsAllAlpha, |
| | | // IsAllAlphaAndPunct or IsAllAlphaAndDigit. |
| | | bool EncodeConverter::NeedAddTailBlank(std::string str) { |
| | |     const size_t byte_count = str.size(); |
| | |     U8CHAR_T *bytes = (U8CHAR_T*)str.data(); |
| | |     if (bytes == NULL || byte_count <= 0) { |
| | |         return false; |
| | |     } |
| | |     return IsAllAlpha(bytes, byte_count) |
| | |         || IsAllAlphaAndPunct(bytes, byte_count) |
| | |         || IsAllAlphaAndDigit(bytes, byte_count); |
| | | } |
| | | // Merge adjacent pieces according to `merge_mask`: a piece whose mask entry is 1 |
| | | // (and which is not the first piece) is appended to the previously emitted piece; |
| | | // any other piece starts a new entry. `str_vec_input` is replaced by the merged |
| | | // list, and a copy of it is returned. |
| | | // Only indices valid in BOTH vectors are visited, so a mask longer than the |
| | | // input can no longer index past the end of `str_vec_input`. |
| | | std::vector<std::string> EncodeConverter::MergeEnglishWord(std::vector<std::string> &str_vec_input, |
| | |                                                          std::vector<int> &merge_mask) { |
| | |     const size_t count = merge_mask.size() < str_vec_input.size() |
| | |                              ? merge_mask.size() |
| | |                              : str_vec_input.size(); |
| | |     std::vector<std::string> output; |
| | |     output.reserve(count); |
| | |     for (size_t i = 0; i < count; i++) { |
| | |         if (merge_mask[i] == 1 && i > 0) { |
| | |             // i > 0 guarantees output already holds at least one entry |
| | |             output.back() += str_vec_input[i]; |
| | |         } else { |
| | |             output.push_back(str_vec_input[i]); |
| | |         } |
| | |     } |
| | |     str_vec_input.swap(output); |
| | |     return str_vec_input; |
| | | } |
| | | // Split a UTF-8 string into one std::string per character and append them to |
| | | // `output`. The sequence length is taken from the lead byte alone (1..6 bytes). |
| | | // Returns the resulting size of `output`. |
| | | // The loop stops as soon as the index reaches OR passes the end: a truncated |
| | | // trailing sequence used to jump past input.length(), after which `i != length` |
| | | // stayed true, input[i] was read out of bounds and substr() threw. |
| | | size_t EncodeConverter::Utf8ToCharset(const std::string &input, std::vector<std::string> &output) { |
| | |     const size_t total = input.length(); |
| | |     for (size_t i = 0, len = 0; i < total; i += len) { |
| | |         const unsigned char byte = static_cast<unsigned char>(input[i]); |
| | |         if (byte >= 0xFC) // length 6 |
| | |             len = 6; |
| | |         else if (byte >= 0xF8) |
| | |             len = 5; |
| | |         else if (byte >= 0xF0) |
| | |             len = 4; |
| | |         else if (byte >= 0xE0) |
| | |             len = 3; |
| | |         else if (byte >= 0xC0) |
| | |             len = 2; |
| | |         else |
| | |             len = 1; |
| | |         // substr clamps to the end of the string, so a truncated tail is kept as-is |
| | |         output.push_back(input.substr(i, len)); |
| | |     } |
| | |     return output.size(); |
| | | } |
| | | } |
| New file |
| | |
| | | /** |
| | | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | * MIT License (https://opensource.org/licenses/MIT) |
| | | */ |
| | | #ifndef __WS__ENCODE_CONVERTER_H__ |
| | | #define __WS__ENCODE_CONVERTER_H__ |
| | | |
| | | #include <string> |
| | | #include <stdint.h> |
| | | #include <vector> |
| | | #ifdef _MSC_VER |
| | | #include <windows.h> |
| | | #endif // _MSC_VER |
| | | |
| | | namespace funasr { |
| | | typedef unsigned char U8CHAR_T; |
| | | typedef unsigned short U16CHAR_T; |
| | | typedef std::basic_string<U8CHAR_T> u8string; |
| | | typedef std::basic_string<U16CHAR_T> u16string; |
| | | |
| | | // Collection of static helpers for converting between UTF-8 and UTF-16 buffers |
| | | // and for classifying text (Chinese / alphabetic / digits / punctuation). |
| | | // All members are static; the class carries no state. |
| | | class EncodeConverter { |
| | | public: |
| | | // Default replacement code unit (U+25A1). |
| | | static const U16CHAR_T defUniChar = 0x25a1; //WHITE SQUARE |
| | | |
| | | public: |
| | | // NOTE(review): presumably swaps the byte order of `len` UTF-16 units in place; |
| | | // the definition is not visible here, confirm against the .cpp. |
| | | static void SwapEndian(U16CHAR_T* pbuf, size_t len); |
| | | |
| | | // Single-character overload (no length arguments). |
| | | static size_t Utf16ToUtf8(const U16CHAR_T* pu16, U8CHAR_T* pu8); |
| | | |
| | | ///< @param pu16 UTF16 string (input), ilen units |
| | | ///< @param pu8 UTF8 string (output), room for olen bytes |
| | | static size_t Utf16ToUtf8(const U16CHAR_T* pu16, size_t ilen, |
| | | U8CHAR_T* pu8, size_t olen); |
| | | |
| | | // String overload of the UTF-16 to UTF-8 conversion. |
| | | static u8string Utf16ToUtf8(const u16string& u16str); |
| | | |
| | | // Single-character overload without an input length. |
| | | static size_t Utf8ToUtf16(const U8CHAR_T* pu8, U16CHAR_T* pu16); |
| | | |
| | | // Single-character overload bounded by `ilen` input bytes. |
| | | static size_t Utf8ToUtf16(const U8CHAR_T* pu8, size_t ilen, U16CHAR_T* pu16); |
| | | |
| | | ///< @param pu8 UTF8 string (input), ilen bytes |
| | | ///< @param pu16 UTF16 string (output), room for olen units |
| | | static size_t Utf8ToUtf16(const U8CHAR_T* pu8, size_t ilen, |
| | | U16CHAR_T* pu16, size_t olen); |
| | | |
| | | // String overload of the UTF-8 to UTF-16 conversion. |
| | | static u16string Utf8ToUtf16(const u8string& u8str); |
| | | |
| | | ///< @param pu8 string |
| | | ///< @param ilen length of pu8 in bytes |
| | | ///< @return if string is encoded as UTF8 - true, otherwise false |
| | | static bool IsUTF8(const U8CHAR_T* pu8, size_t ilen); |
| | | |
| | | ///< @param u8str string |
| | | ///< @return if string is encoded as UTF8 - true, otherwise false |
| | | static bool IsUTF8(const u8string& u8str); |
| | | |
| | | ///< @param pu8 UTF8 string |
| | | ///< @param ilen length of pu8 in bytes |
| | | ///< @return the word number of UTF8 |
| | | static size_t GetUTF8Len(const U8CHAR_T* pu8, size_t ilen); |
| | | |
| | | ///< @param u8str UTF8 string |
| | | ///< @return the word number of UTF8 |
| | | static size_t GetUTF8Len(const u8string& u8str); |
| | | |
| | | ///< @param pu16 UTF16 string |
| | | ///< @param ilen UTF16 length |
| | | ///< @return UTF8 string length |
| | | static size_t Utf16ToUtf8Len(const U16CHAR_T* pu16, size_t ilen); |
| | | |
| | | // Decodes one UTF-8 character from `sc`; `len` is an output parameter. |
| | | static uint16_t ToUni(const char* sc, int &len); |
| | | |
| | | // True when the UTF-16 unit falls in one of the two ranges below. |
| | | static bool IsChineseCharacter(U16CHAR_T &u16) { |
| | | return (u16 >= 0x4e00 && u16 <= 0x9fff) // common |
| | | || (u16 >= 0x3400 && u16 <= 0x4dff); // rare, extension A |
| | | } |
| | | |
| | | // whether the string is all Chinese |
| | | static bool IsAllChineseCharactor(const U8CHAR_T* pu8, size_t ilen); |
| | | // Byte-wise classification helpers over a buffer of ilen bytes. |
| | | static bool HasAlpha(const U8CHAR_T* pu8, size_t ilen); |
| | | static bool NeedAddTailBlank(std::string str); |
| | | static bool IsAllAlpha(const U8CHAR_T* pu8, size_t ilen); |
| | | static bool IsAllAlphaAndPunct(const U8CHAR_T* pu8, size_t ilen); |
| | | static bool IsAllAlphaAndDigit(const U8CHAR_T* pu8, size_t ilen); |
| | | static bool IsAllAlphaAndDigitAndBlank(const U8CHAR_T* pu8, size_t ilen); |
| | | // Merges entries of str_vec_input as directed by merge_mask; both are taken by non-const reference. |
| | | static std::vector<std::string> MergeEnglishWord(std::vector<std::string> &str_vec_input, |
| | | std::vector<int> &merge_mask); |
| | | // Splits `input` into per-character strings appended to `output`. |
| | | static size_t Utf8ToCharset(const std::string &input, std::vector<std::string> &output); |
| | | |
| | | #ifdef _MSC_VER |
| | | // convert to the local ansi page |
| | | // Two-step conversion: UTF-8 -> wide characters (CP_UTF8), then wide -> CP_ACP. |
| | | static std::string UTF8ToLocaleAnsi(const std::string& strUTF8) { |
| | | // First call only queries the required wide-character count. |
| | | int len = MultiByteToWideChar(CP_UTF8, 0, strUTF8.c_str(), -1, NULL, 0); |
| | | unsigned short*wszGBK = new unsigned short[len + 1]; |
| | | memset(wszGBK, 0, len * 2 + 2); |
| | | MultiByteToWideChar(CP_UTF8, 0, (LPCCH)strUTF8.c_str(), -1, (LPWSTR)wszGBK, len); |
| | | |
| | | // Same query-then-convert pattern for the wide -> ANSI step. |
| | | len = WideCharToMultiByte(CP_ACP, 0, (LPCWCH)wszGBK, -1, NULL, 0, NULL, NULL); |
| | | char *szGBK = new char[len + 1]; |
| | | memset(szGBK, 0, len + 1); |
| | | WideCharToMultiByte(CP_ACP, 0, (LPCWCH)wszGBK, -1, szGBK, len, NULL, NULL); |
| | | std::string strTemp(szGBK); |
| | | delete[]szGBK; |
| | | delete[]wszGBK; |
| | | return strTemp; |
| | | } |
| | | #endif |
| | | }; |
| | | } |
| | | |
| | | #endif //__WS__ENCODE_CONVERTER_H__ |
| | |
| | | #include "precomp.h" |
| | | #include <vector> |
| | | #ifdef __cplusplus |
| | | |
| | | extern "C" { |
| | |
| | | } |
| | | |
| | | // APIs for Offline-stream Infer |
| | | _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate, std::string wav_format) |
| | | _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, int sampling_rate, std::string wav_format) |
| | | { |
| | | funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle; |
| | | if (!offline_stream) |
| | |
| | | int n_total = audio.GetQueueSize(); |
| | | float start_time = 0.0; |
| | | while (audio.Fetch(buff, len, flag, start_time) > 0) { |
| | | string msg = (offline_stream->asr_handle)->Forward(buff, len, true); |
| | | string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb); |
| | | std::vector<std::string> msg_vec = funasr::split(msg, '|'); |
| | | p_result->msg += msg_vec[0]; |
| | | //timestamp |
| | | if(msg_vec.size() > 1){ |
| | | std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ','); |
| | | std::string cur_stamp = ""; |
| | | std::string cur_stamp = "["; |
| | | for(int i=0; i<msg_stamp.size()-1; i+=2){ |
| | | float begin = std::stof(msg_stamp[i])+start_time; |
| | | float end = std::stof(msg_stamp[i+1])+start_time; |
| | | cur_stamp += "["+std::to_string(begin)+","+std::to_string(end)+"],"; |
| | | cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"]"; |
| | | if(i != msg_stamp.size()-2){ |
| | | cur_stamp +=","; |
| | | } |
| | | p_result->stamp += cur_stamp; |
| | | } |
| | | p_result->stamp += cur_stamp + "]"; |
| | | } |
| | | n_step++; |
| | | if (fn_callback) |
| | |
| | | return p_result; |
| | | } |
| | | |
| | | _FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate) |
| | | _FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, const std::vector<std::vector<float>> &hw_emb, int sampling_rate) |
| | | { |
| | | funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle; |
| | | if (!offline_stream) |
| | |
| | | int n_total = audio.GetQueueSize(); |
| | | float start_time = 0.0; |
| | | while (audio.Fetch(buff, len, flag, start_time) > 0) { |
| | | string msg = (offline_stream->asr_handle)->Forward(buff, len, true); |
| | | string msg = (offline_stream->asr_handle)->Forward(buff, len, true, hw_emb); |
| | | std::vector<std::string> msg_vec = funasr::split(msg, '|'); |
| | | p_result->msg += msg_vec[0]; |
| | | //timestamp |
| | | if(msg_vec.size() > 1){ |
| | | std::vector<std::string> msg_stamp = funasr::split(msg_vec[1], ','); |
| | | std::string cur_stamp = ""; |
| | | std::string cur_stamp = "["; |
| | | for(int i=0; i<msg_stamp.size()-1; i+=2){ |
| | | float begin = std::stof(msg_stamp[i])+start_time; |
| | | float end = std::stof(msg_stamp[i+1])+start_time; |
| | | cur_stamp += "["+std::to_string(begin)+","+std::to_string(end)+"],"; |
| | | cur_stamp += "["+std::to_string((int)(1000*begin))+","+std::to_string((int)(1000*end))+"]"; |
| | | if(i != msg_stamp.size()-2){ |
| | | cur_stamp +=","; |
| | | } |
| | | p_result->stamp += cur_stamp; |
| | | } |
| | | p_result->stamp += cur_stamp + "]"; |
| | | } |
| | | |
| | | n_step++; |
| | | if (fn_callback) |
| | | fn_callback(n_step, n_total); |
| | |
| | | return p_result; |
| | | } |
| | | |
| | | // C API entry point: compile a hotword string into embeddings using the |
| | | // offline stream's ASR handle. A null handle yields an empty result. |
| | | _FUNASRAPI const std::vector<std::vector<float>> CompileHotwordEmbedding(FUNASR_HANDLE handle, std::string &hotwords) { |
| | |     funasr::OfflineStream* stream = (funasr::OfflineStream*)handle; |
| | |     if (stream) { |
| | |         return (stream->asr_handle)->CompileHotwordEmbedding(hotwords); |
| | |     } |
| | |     return std::vector<std::vector<float>>(); |
| | | } |
| | | |
| | | // APIs for 2pass-stream Infer |
| | | _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished, int sampling_rate, std::string wav_format, ASR_TYPE mode) |
| | | { |
| | |
| | | string am_model_path; |
| | | string am_cmvn_path; |
| | | string am_config_path; |
| | | string hw_compile_model_path; |
| | | string seg_dict_path; |
| | | |
| | | asr_handle = make_unique<Paraformer>(); |
| | | bool enable_hotword = false; |
| | | hw_compile_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_EB_NAME); |
| | | seg_dict_path = PathAppend(model_path.at(MODEL_DIR), MODEL_SEG_DICT); |
| | | if (access(hw_compile_model_path.c_str(), F_OK) == 0) { // if model_eb.onnx exist, hotword enabled |
| | | enable_hotword = true; |
| | | asr_handle->InitHwCompiler(hw_compile_model_path, thread_num); |
| | | asr_handle->InitSegDict(seg_dict_path); |
| | | } |
| | | if (enable_hotword) { |
| | | am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME); |
| | | if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){ |
| | | am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME); |
| | | } |
| | | } else { |
| | | am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME); |
| | | if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){ |
| | | am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME); |
| | | } |
| | | } |
| | | am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME); |
| | | am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME); |
| | | |
| | | asr_handle = make_unique<Paraformer>(); |
| | | asr_handle->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num); |
| | | } |
| | | |
| | | |
| | | // PUNC model |
| | | if(model_path.find(PUNC_DIR) != model_path.end()){ |
| | | string punc_model_path; |
| | |
| | | return result; |
| | | } |
| | | |
| | | string ParaformerOnline::Forward(float* din, int len, bool input_finished) |
| | | string ParaformerOnline::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb) |
| | | { |
| | | std::vector<std::vector<float>> wav_feats; |
| | | std::vector<float> waves(din, din+len); |
| | |
| | | void AddOverlapChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished); |
| | | |
| | | string ForwardChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished); |
| | | string Forward(float* din, int len, bool input_finished); |
| | | string Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb={{0.0}}); |
| | | string Rescoring(); |
| | | // 2pass |
| | | std::string online_res; |
| | |
| | | */ |
| | | |
| | | #include "precomp.h" |
| | | #include "paraformer.h" |
| | | #include "encode_converter.h" |
| | | #include <cstddef> |
| | | |
| | | using namespace std; |
| | | |
| | | namespace funasr { |
| | | |
| | | Paraformer::Paraformer() |
| | | :env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options_{}{ |
| | | :use_hotword(false), |
| | | env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options_{}, |
| | | hw_env_(ORT_LOGGING_LEVEL_ERROR, "paraformer_hw"),hw_session_options{} { |
| | | } |
| | | |
| | | // offline |
| | |
| | | m_strInputNames.push_back(strName.c_str()); |
| | | GetInputName(m_session_.get(), strName,1); |
| | | m_strInputNames.push_back(strName); |
| | | if (use_hotword) { |
| | | GetInputName(m_session_.get(), strName, 2); |
| | | m_strInputNames.push_back(strName); |
| | | } |
| | | |
| | | size_t numOutputNodes = m_session_->GetOutputCount(); |
| | | for(int index=0; index<numOutputNodes; index++){ |
| | |
| | | } |
| | | } |
| | | |
| | | // Load the hotword-compiler ONNX model and cache its first input and output names. |
| | | // @param hw_model    path of the hotword compiler model file |
| | | // @param thread_num  intra-op thread count for the session |
| | | // On a load failure the error is logged and the process terminates with a |
| | | // non-zero status (it used to exit with 0, which reports success to the parent). |
| | | // A successful call marks this instance as a hotword model (use_hotword = true). |
| | | void Paraformer::InitHwCompiler(const std::string &hw_model, int thread_num) { |
| | |     hw_session_options.SetIntraOpNumThreads(thread_num); |
| | |     hw_session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL); |
| | |     // DisableCpuMemArena can improve performance |
| | |     hw_session_options.DisableCpuMemArena(); |
| | | |
| | |     try { |
| | |         hw_m_session = std::make_unique<Ort::Session>(hw_env_, hw_model.c_str(), hw_session_options); |
| | |         LOG(INFO) << "Successfully load model from " << hw_model; |
| | |     } catch (std::exception const &e) { |
| | |         LOG(ERROR) << "Error when load hw compiler onnx model: " << e.what(); |
| | |         exit(EXIT_FAILURE); |
| | |     } |
| | | |
| | |     string strName; |
| | |     GetInputName(hw_m_session.get(), strName); |
| | |     hw_m_strInputNames.push_back(strName); |
| | |     //GetInputName(hw_m_session.get(), strName,1); |
| | |     //hw_m_strInputNames.push_back(strName); |
| | | |
| | |     GetOutputName(hw_m_session.get(), strName); |
| | |     hw_m_strOutputNames.push_back(strName); |
| | | |
| | |     // The const char* views point into the string vectors above, which must |
| | |     // therefore not be modified afterwards. |
| | |     for (auto& item : hw_m_strInputNames) |
| | |         hw_m_szInputNames.push_back(item.c_str()); |
| | |     for (auto& item : hw_m_strOutputNames) |
| | |         hw_m_szOutputNames.push_back(item.c_str()); |
| | |     // if init hotword compiler is called, this is a hotword paraformer model |
| | |     use_hotword = true; |
| | | } |
| | | |
| | | // Load the segmentation dictionary from `seg_dict_model`. |
| | | // Any previously loaded dictionary is released first, so calling this more |
| | | // than once no longer leaks the old instance. |
| | | void Paraformer::InitSegDict(const std::string &seg_dict_model) { |
| | |     // Build the new dictionary before dropping the old one, so the member is |
| | |     // never left dangling if construction throws. |
| | |     SegDict* fresh = new SegDict(seg_dict_model.c_str()); |
| | |     delete seg_dict; |
| | |     seg_dict = fresh; |
| | | } |
| | | |
| | | // Release the owned vocabulary and segmentation dictionary. |
| | | Paraformer::~Paraformer() |
| | | { |
| | |     // deleting a null pointer is a no-op, so no explicit checks are needed |
| | |     delete vocab; |
| | |     delete seg_dict; |
| | | } |
| | | |
| | | void Paraformer::Reset() |
| | |
| | | int32_t feature_dim = fbank_opts_.mel_opts.num_bins; |
| | | vector<float> features(frames * feature_dim); |
| | | float *p = features.data(); |
| | | //std::cout << "samples " << len << std::endl; |
| | | //std::cout << "fbank frames " << frames << std::endl; |
| | | //std::cout << "fbank dim " << feature_dim << std::endl; |
| | | //std::cout << "feature size " << features.size() << std::endl; |
| | | |
| | | for (int32_t i = 0; i != frames; ++i) { |
| | | const float *f = fbank_.GetFrame(i); |
| | |
| | | } |
| | | } |
| | | |
| | | string Paraformer::Forward(float* din, int len, bool input_finished) |
| | | string Paraformer::Forward(float* din, int len, bool input_finished, const std::vector<std::vector<float>> &hw_emb) |
| | | { |
| | | |
| | | int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins; |
| | |
| | | |
| | | int32_t feat_dim = lfr_m*in_feat_dim; |
| | | int32_t num_frames = wav_feats.size() / feat_dim; |
| | | //std::cout << "feat in: " << num_frames << " " << feat_dim << std::endl; |
| | | |
| | | #ifdef _WIN_X86 |
| | | Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); |
| | |
| | | input_onnx.emplace_back(std::move(onnx_feats)); |
| | | input_onnx.emplace_back(std::move(onnx_feats_len)); |
| | | |
| | | std::vector<float> embedding; |
| | | try{ |
| | | if (use_hotword) { |
| | | if(hw_emb.size()<=0){ |
| | | LOG(ERROR) << "hw_emb is null"; |
| | | return ""; |
| | | } |
| | | //PrintMat(hw_emb, "input_clas_emb"); |
| | | const int64_t hotword_shape[3] = {1, hw_emb.size(), hw_emb[0].size()}; |
| | | embedding.reserve(hw_emb.size() * hw_emb[0].size()); |
| | | for (auto item : hw_emb) { |
| | | embedding.insert(embedding.end(), item.begin(), item.end()); |
| | | } |
| | | //LOG(INFO) << "hotword shape " << hotword_shape[0] << " " << hotword_shape[1] << " " << hotword_shape[2] << " size " << embedding.size(); |
| | | Ort::Value onnx_hw_emb = Ort::Value::CreateTensor<float>( |
| | | m_memoryInfo, embedding.data(), embedding.size(), hotword_shape, 3); |
| | | |
| | | input_onnx.emplace_back(std::move(onnx_hw_emb)); |
| | | } |
| | | }catch (std::exception const &e) |
| | | { |
| | | LOG(ERROR)<<e.what(); |
| | | return ""; |
| | | } |
| | | |
| | | string result; |
| | | try { |
| | | auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size()); |
| | | std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape(); |
| | | //LOG(INFO) << "paraformer out shape " << outputShape[0] << " " << outputShape[1] << " " << outputShape[2]; |
| | | |
| | | int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>()); |
| | | float* floatData = outputTensor[0].GetTensorMutableData<float>(); |
| | |
| | | }else{ |
| | | result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]); |
| | | } |
| | | // int pos = 0; |
| | | // std::vector<std::vector<float>> logits; |
| | | // for (int j = 0; j < outputShape[1]; j++) |
| | | // { |
| | | // std::vector<float> vec_token; |
| | | // vec_token.insert(vec_token.begin(), floatData + pos, floatData + pos + outputShape[2]); |
| | | // logits.push_back(vec_token); |
| | | // pos += outputShape[2]; |
| | | // } |
| | | // //PrintMat(logits, "logits_out"); |
| | | // result = GreedySearch(floatData, *encoder_out_lens, outputShape[2]); |
| | | } |
| | | catch (std::exception const &e) |
| | | { |
| | |
| | | return result; |
| | | } |
| | | |
| | | |
| | | // Compile a space-separated hotword list into one embedding vector per hotword, |
| | | // followed by one embedding for the trailing "blank" entry (token id 1). |
| | | // |
| | | // @param hotwords  space-separated hotwords; may be empty |
| | | // @return one row per kept hotword plus the blank row; a single all-zero row of |
| | | //         size encoder_size when hotwords are disabled; an empty result when the |
| | | //         compiler model fails or returns an unexpected shape. |
| | | // |
| | | // Changes from the previous version: |
| | | //  - token ids are written for at most max_hotword_len tokens (longer hotwords |
| | | //    used to write past the end of the per-hotword id vector); |
| | | //  - hotwords that yield no tokens are skipped (their length 0 used to produce a |
| | | //    negative read offset into the model output); |
| | | //  - the output shape is checked at run time instead of via assert only; |
| | | //  - debug printing of tokens to stdout is removed. |
| | | std::vector<std::vector<float>> Paraformer::CompileHotwordEmbedding(std::string &hotwords) { |
| | |     int embedding_dim = encoder_size; |
| | |     std::vector<std::vector<float>> hw_emb; |
| | |     if (!use_hotword) { |
| | |         std::vector<float> vec(embedding_dim, 0); |
| | |         hw_emb.push_back(vec); |
| | |         return hw_emb; |
| | |     } |
| | |     const int max_hotword_len = 10; |
| | |     std::vector<int32_t> hotword_matrix;  // row-major, one row of max_hotword_len ids per hotword |
| | |     std::vector<int32_t> lengths;         // real token count of each row |
| | |     if (!hotwords.empty()) { |
| | |         const std::vector<std::string> hotword_array = split(hotwords, ' '); |
| | |         hotword_matrix.reserve((hotword_array.size() + 1) * max_hotword_len); |
| | |         lengths.reserve(hotword_array.size() + 1); |
| | |         for (const auto &hotword : hotword_array) { |
| | |             std::vector<std::string> chars; |
| | |             if (EncodeConverter::IsAllChineseCharactor((const U8CHAR_T*)hotword.c_str(), hotword.size())) { |
| | |                 KeepChineseCharacterAndSplit(hotword, chars); |
| | |             } else if (seg_dict != nullptr) { |
| | |                 // for english |
| | |                 std::vector<std::string> words = split(hotword, ' '); |
| | |                 for (const auto &word : words) { |
| | |                     std::vector<string> tokens = seg_dict->GetTokensByWord(word); |
| | |                     chars.insert(chars.end(), tokens.begin(), tokens.end()); |
| | |                 } |
| | |             } |
| | |             const int vector_len = std::min(max_hotword_len, (int)chars.size()); |
| | |             if (vector_len == 0) { |
| | |                 // nothing to embed for this hotword |
| | |                 LOG(WARNING) << "skip hotword without tokens: " << hotword; |
| | |                 continue; |
| | |             } |
| | |             std::vector<int32_t> hw_vector(max_hotword_len, 0); |
| | |             for (int i = 0; i < vector_len; i++) { |
| | |                 hw_vector[i] = vocab->GetIdByToken(chars[i]); |
| | |             } |
| | |             lengths.push_back(vector_len); |
| | |             hotword_matrix.insert(hotword_matrix.end(), hw_vector.begin(), hw_vector.end()); |
| | |         } |
| | |     } |
| | |     // trailing blank entry: a single token with id 1 |
| | |     std::vector<int32_t> blank_vec(max_hotword_len, 0); |
| | |     blank_vec[0] = 1; |
| | |     hotword_matrix.insert(hotword_matrix.end(), blank_vec.begin(), blank_vec.end()); |
| | |     lengths.push_back(1); |
| | |     const int hotword_size = static_cast<int>(lengths.size()); |
| | | |
| | | #ifdef _WIN_X86 |
| | |     Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU); |
| | | #else |
| | |     Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault); |
| | | #endif |
| | | |
| | |     const int64_t input_shape_[2] = {hotword_size, max_hotword_len}; |
| | |     Ort::Value onnx_hotword = Ort::Value::CreateTensor<int32_t>(m_memoryInfo, |
| | |         (int32_t*)hotword_matrix.data(), |
| | |         hotword_size * max_hotword_len, |
| | |         input_shape_, |
| | |         2); |
| | |     LOG(INFO) << "clas shape " << hotword_size << " " << max_hotword_len << std::endl; |
| | | |
| | |     std::vector<Ort::Value> input_onnx; |
| | |     input_onnx.emplace_back(std::move(onnx_hotword)); |
| | | |
| | |     std::vector<std::vector<float>> result; |
| | |     try { |
| | |         auto outputTensor = hw_m_session->Run(Ort::RunOptions{nullptr}, hw_m_szInputNames.data(), input_onnx.data(), input_onnx.size(), hw_m_szOutputNames.data(), hw_m_szOutputNames.size()); |
| | |         std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape(); |
| | |         float* floatData = outputTensor[0].GetTensorMutableData<float>(); // shape [max_hotword_len, hotword_size, dim] |
| | | |
| | |         // The indexing below relies on this exact layout; refuse anything else. |
| | |         if (outputShape.size() != 3 || outputShape[0] != max_hotword_len || outputShape[1] != hotword_size) { |
| | |             LOG(ERROR) << "unexpected hotword embedding output shape"; |
| | |             return result; |
| | |         } |
| | |         embedding_dim = outputShape[2]; |
| | | |
| | |         // get embedding by real hotword length: row (lengths[j] - 1) of hotword j |
| | |         result.reserve(hotword_size); |
| | |         for (int j = 0; j < hotword_size; j++) |
| | |         { |
| | |             const int64_t start_pos = (static_cast<int64_t>(lengths[j] - 1) * hotword_size + j) * embedding_dim; |
| | |             result.emplace_back(floatData + start_pos, floatData + start_pos + embedding_dim); |
| | |         } |
| | |     } |
| | |     catch (std::exception const &e) |
| | |     { |
| | |         LOG(ERROR)<<e.what(); |
| | |     } |
| | |     //PrintMat(result, "clas_embedding_output"); |
| | |     return result; |
| | | } |
| | | |
| | | string Paraformer::Rescoring() |
| | | { |
| | | LOG(ERROR)<<"Not Imp!!!!!!"; |
| | |
| | | */ |
| | | private: |
| | | Vocab* vocab = nullptr; |
| | | SegDict* seg_dict = nullptr; |
| | | //const float scale = 22.6274169979695; |
| | | const float scale = 1.0; |
| | | |
| | |
| | | void LoadCmvn(const char *filename); |
| | | vector<float> ApplyLfr(const vector<float> &in); |
| | | void ApplyCmvn(vector<float> *v); |
| | | |
| | | std::shared_ptr<Ort::Session> hw_m_session = nullptr; |
| | | Ort::Env hw_env_; |
| | | Ort::SessionOptions hw_session_options; |
| | | vector<string> hw_m_strInputNames, hw_m_strOutputNames; |
| | | vector<const char*> hw_m_szInputNames; |
| | | vector<const char*> hw_m_szOutputNames; |
| | | bool use_hotword; |
| | | |
| | | public: |
| | | Paraformer(); |
| | |
| | | void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num); |
| | | // 2pass |
| | | void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num); |
| | | void InitHwCompiler(const std::string &hw_model, int thread_num); |
| | | void InitSegDict(const std::string &seg_dict_model); |
| | | std::vector<std::vector<float>> CompileHotwordEmbedding(std::string &hotwords); |
| | | void Reset(); |
| | | vector<float> FbankKaldi(float sample_rate, const float* waves, int len); |
| | | string Forward(float* din, int len, bool input_finished=true); |
| | | string Forward(float* din, int len, bool input_finished=true, const std::vector<std::vector<float>> &hw_emb={{0.0}}); |
| | | string GreedySearch( float* in, int n_len, int64_t token_nums, bool is_stamp=false, std::vector<float> us_alphas={0}, std::vector<float> us_cif_peak={0}); |
| | | void TimestampOnnx(std::vector<float> &us_alphas, vector<float> us_cif_peak, vector<string>& char_list, std::string &res_str, |
| | | vector<vector<float>> &timestamp_list, float begin_time = 0.0, float total_offset = -1.5); |
| | | string PostProcess(std::vector<string> &raw_char, std::vector<std::vector<float>> &timestamp_list); |
| | | |
| | | string Rescoring(); |
| | | |
| | | knf::FbankOptions fbank_opts_; |
| | |
| | | #include "ct-transformer-online.h" |
| | | #include "e2e-vad.h" |
| | | #include "fsmn-vad.h" |
| | | #include "encode_converter.h" |
| | | #include "vocab.h" |
| | | #include "audio.h" |
| | | #include "fsmn-vad-online.h" |
| | | #include "tensor.h" |
| | | #include "util.h" |
| | | #include "seg_dict.h" |
| | | #include "resample.h" |
| | | #include "paraformer.h" |
| | | #include "paraformer-online.h" |
| New file |
| | |
| | | /** |
| | | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | * MIT License (https://opensource.org/licenses/MIT) |
| | | */ |
| | | #include "precomp.h" |
| | | //#include "util.h" |
| | | //#include "seg_dict.h" |
| | | #include <glog/logging.h> |
| | | |
| | | #include <fstream> |
| | | #include <iostream> |
| | | #include <list> |
| | | #include <sstream> |
| | | #include <string> |
| | | |
| | | using namespace std; |
| | | |
| | | namespace funasr { |
| | | // Load a segmentation dictionary from a text file. |
| | | // Each usable line is "<word>\t<space-separated tokens>"; lines with fewer than |
| | | // two tab-separated fields are ignored. If the file cannot be opened an error |
| | | // is logged and the dictionary stays empty. |
| | | SegDict::SegDict(const char *filename) |
| | | { |
| | |     ifstream dict_file(filename); |
| | |     if (!dict_file) { |
| | |         LOG(ERROR) << filename << " open failed !!"; |
| | |         return; |
| | |     } |
| | |     for (string row; getline(dict_file, row); ) { |
| | |         std::vector<string> fields = split(row, '\t'); |
| | |         if (fields.size() <= 1) { |
| | |             continue; |
| | |         } |
| | |         // first field is the word, second field holds its tokens |
| | |         seg_dict[fields[0]] = split(fields[1], ' '); |
| | |     } |
| | |     LOG(INFO) << "load seg dict successfully"; |
| | | } |
| | | // Return the token list stored for `word`, or an empty list when the word is |
| | | // not in the dictionary. |
| | | std::vector<std::string> SegDict::GetTokensByWord(const std::string &word) { |
| | |     const auto hit = seg_dict.find(word); |
| | |     if (hit == seg_dict.end()) { |
| | |         return std::vector<string>(); |
| | |     } |
| | |     return hit->second; |
| | | } |
| | | |
| | | // Nothing to release explicitly: the map member cleans up after itself. |
| | | SegDict::~SegDict() = default; |
| | | |
| | | |
| | | } // namespace funasr |
| New file |
| | |
| | | /** |
| | | * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved. |
| | | * MIT License (https://opensource.org/licenses/MIT) |
| | | */ |
| | | #ifndef SEG_DICT_H |
| | | #define SEG_DICT_H |
| | | |
| | | #include <stdint.h> |
| | | #include <string> |
| | | #include <vector> |
| | | #include <map> |
| | | using namespace std; |
| | | |
| | | namespace funasr { |
// Lookup table from a word to its sequence of tokens, built from a file.
class SegDict {
  public:
    // Builds the table from the file at `filename`.
    SegDict(const char *filename);
    ~SegDict();

    // Returns the tokens associated with `word`.
    std::vector<std::string> GetTokensByWord(const std::string &word);

  private:
    // word -> token sequence
    std::map<std::string, std::vector<std::string>> seg_dict;
};
| | | |
| | | } // namespace funasr |
| | | #endif |
| | |
| | | return (extension == target); |
| | | } |
| | | |
| | | void KeepChineseCharacterAndSplit(const std::string &input_str, |
| | | std::vector<std::string> &chinese_characters) { |
| | | chinese_characters.resize(0); |
| | | std::vector<U16CHAR_T> u16_buf; |
| | | u16_buf.resize(std::max(u16_buf.size(), input_str.size() + 1)); |
| | | U16CHAR_T* pu16 = u16_buf.data(); |
| | | U8CHAR_T * pu8 = (U8CHAR_T*)input_str.data(); |
| | | size_t ilen = input_str.size(); |
| | | size_t len = EncodeConverter::Utf8ToUtf16(pu8, ilen, pu16, ilen + 1); |
| | | for (size_t i = 0; i < len; i++) { |
| | | if (EncodeConverter::IsChineseCharacter(pu16[i])) { |
| | | U8CHAR_T u8buf[4]; |
| | | size_t n = EncodeConverter::Utf16ToUtf8(pu16 + i, u8buf); |
| | | u8buf[n] = '\0'; |
| | | chinese_characters.push_back((const char*)u8buf); |
| | | } |
| | | } |
| | | } |
| | | |
| | | std::vector<std::string> split(const std::string &s, char delim) { |
| | | std::vector<std::string> elems; |
| | | std::stringstream ss(s); |
| | |
| | | return elems; |
| | | } |
| | | |
// Debug helper: writes `name` followed by ':' on its own line, then every row
// of `mat` on a separate line with each element followed by one space.
// Rows and elements are visited by const reference so nothing is copied, and
// std::cout is flushed once at the end rather than after every line.
// NOTE(review): a function template normally has to be visible wherever it is
// instantiated - confirm how other translation units are expected to use it.
template<typename T>
void PrintMat(const std::vector<std::vector<T>> &mat, const std::string &name) {
    std::cout << name << ":" << '\n';
    for (const auto &row : mat) {
        for (const auto &value : row) {
            std::cout << value << " ";
        }
        std::cout << '\n';
    }
    std::cout.flush();
}
| | | } // namespace funasr |
| | |
| | | string PathAppend(const string &p1, const string &p2); |
| | | bool is_target_file(const std::string& filename, const std::string target); |
| | | |
| | | void KeepChineseCharacterAndSplit(const std::string &input_str, |
| | | std::vector<std::string> &chinese_characters); |
| | | |
| | | std::vector<std::string> split(const std::string &s, char delim); |
| | | |
| | | template<typename T> |
| | | void PrintMat(const std::vector<std::vector<T>> &mat, const std::string &name); |
| | | } // namespace funasr |
| | | #endif |
| | |
| | | exit(-1); |
| | | } |
| | | YAML::Node myList = config["token_list"]; |
| | | int i = 0; |
| | | for (YAML::const_iterator it = myList.begin(); it != myList.end(); ++it) { |
| | | vocab.push_back(it->as<string>()); |
| | | token_id[it->as<string>()] = i; |
| | | i ++; |
| | | } |
| | | } |
| | | |
| | | int Vocab::GetIdByToken(const std::string &token) { |
| | | if (token_id.count(token)) { |
| | | return token_id[token]; |
| | | } |
| | | return 0; |
| | | } |
| | | |
| | | void Vocab::Vector2String(vector<int> in, std::vector<std::string> &preds) |
| | | { |
| | | for (auto it = in.begin(); it != in.end(); it++) { |
| | |
| | | #include <stdint.h> |
| | | #include <string> |
| | | #include <vector> |
| | | #include <map> |
| | | using namespace std; |
| | | |
| | | namespace funasr { |
| | | class Vocab { |
| | | private: |
| | | vector<string> vocab; |
| | | std::map<string, int> token_id; |
| | | bool IsEnglish(string ch); |
| | | void LoadVocabFromYaml(const char* filename); |
| | | |
| | |
| | | bool IsChinese(string ch); |
| | | void Vector2String(vector<int> in, std::vector<std::string> &preds); |
| | | string Vector2StringV2(vector<int> in); |
| | | int GetIdByToken(const std::string &token); |
| | | }; |
| | | |
| | | } // namespace funasr |
| New file |
| | |
| | | from funasr_onnx import ContextualParaformer |
| | | from pathlib import Path |
| | | |
# Demo: run the exported contextual Paraformer model on the example wav with a
# space-separated hotword string and print whatever the model returns.
hotwords = '随机热词 各种热词 魔搭 阿里巴巴'
wav_path = [
    '{}/.cache/modelscope/hub/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404/example/asr_example.wav'.format(Path.home())
]

model_dir = "./export/damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"
model = ContextualParaformer(model_dir, batch_size=1)

result = model(wav_path, hotwords)
print(result)
| | |
| | | # -*- encoding: utf-8 -*- |
| | | from .paraformer_bin import Paraformer |
| | | from .paraformer_bin import Paraformer, ContextualParaformer |
| | | from .vad_bin import Fsmn_vad |
| | | from .vad_bin import Fsmn_vad_online |
| | | from .punc_bin import CT_Transformer |
| | |
| | | from typing import List, Union, Tuple |
| | | |
| | | import copy |
| | | import torch |
| | | import librosa |
| | | import numpy as np |
| | | |
| | |
| | | from .utils.postprocess_utils import sentence_postprocess |
| | | from .utils.frontend import WavFrontend |
| | | from .utils.timestamp_utils import time_stamp_lfr6_onnx |
| | | from .utils.utils import pad_list, make_pad_mask |
| | | |
| | | logging = get_logger() |
| | | |
| | |
| | | # texts = sentence_postprocess(token) |
| | | return token |
| | | |
| | | |
class ContextualParaformer(Paraformer):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
    https://arxiv.org/abs/2206.08317

    Hotword-aware Paraformer running on two ONNX sessions: ``model_eb*.onnx``
    turns the hotword ids into bias embeddings and ``model*.onnx`` consumes the
    acoustic features together with those embeddings.
    """
    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 device_id: Union[str, int] = "-1",
                 plot_timestamp_to: str = "",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        """Load vocabulary, frontend and both ONNX sessions from ``model_dir``.

        Args:
            model_dir: local model directory, or a modelscope model name that is
                downloaded when the path does not exist locally.
            batch_size: number of waveforms decoded per backbone call.
            device_id: forwarded to ``OrtInferSession``.
            plot_timestamp_to: stored on the instance as-is.
            quantize: pick the ``*_quant.onnx`` files instead of the plain ones.
            intra_op_num_threads: forwarded to ``OrtInferSession``.
            cache_dir: forwarded to ``snapshot_download``.

        Raises:
            ValueError: ``model_dir`` is neither an existing path nor
                downloadable from modelscope.
        """
        if not Path(model_dir).exists():
            from modelscope.hub.snapshot_download import snapshot_download
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception as e:
                # Raising a plain string is itself a TypeError in Python 3, so
                # wrap the message in a real exception and keep the cause.
                raise ValueError(
                    "model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)
                ) from e

        if quantize:
            model_bb_file = os.path.join(model_dir, 'model_quant.onnx')
            model_eb_file = os.path.join(model_dir, 'model_eb_quant.onnx')
        else:
            model_bb_file = os.path.join(model_dir, 'model.onnx')
            model_eb_file = os.path.join(model_dir, 'model_eb.onnx')

        # token -> id, one token per line; the id is the line index.
        token_list_file = os.path.join(model_dir, 'tokens.txt')
        self.vocab = {}
        # Explicit encoding: the platform default is not UTF-8 everywhere.
        with open(Path(token_list_file), 'r', encoding='utf-8') as fin:
            for i, line in enumerate(fin.readlines()):
                self.vocab[line.strip()] = i

        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontend(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.ort_infer_bb = OrtInferSession(model_bb_file, device_id, intra_op_num_threads=intra_op_num_threads)
        self.ort_infer_eb = OrtInferSession(model_eb_file, device_id, intra_op_num_threads=intra_op_num_threads)

        self.batch_size = batch_size
        self.plot_timestamp_to = plot_timestamp_to
        if "predictor_bias" in config['model_conf'].keys():
            self.pred_bias = config['model_conf']['predictor_bias']
        else:
            self.pred_bias = 0

    def __call__(self,
                 wav_content: Union[str, np.ndarray, List[str]],
                 hotwords: str,
                 **kwargs) -> List:
        """Recognise ``wav_content`` biased towards ``hotwords``.

        Args:
            wav_content: input accepted by ``self.load_data``.
            hotwords: whitespace-separated hotwords.

        Returns:
            One ``{'preds': ...}`` dict per decoded waveform; batches for which
            the backbone raises ``ONNXRuntimeError`` contribute no entries.
        """
        # make hotword list
        hotwords, hotwords_length = self.proc_hotword(hotwords)
        [bias_embed] = self.eb_infer(hotwords, hotwords_length)
        # index from bias_embed: for every hotword keep the embedding at the
        # position given by its entry in hotwords_length
        bias_embed = bias_embed.transpose(1, 0, 2)
        _ind = np.arange(0, len(hotwords)).tolist()
        bias_embed = bias_embed[_ind, hotwords_length.cpu().numpy().tolist()]
        waveform_list = self.load_data(wav_content, self.frontend.opts.frame_opts.samp_freq)
        waveform_nums = len(waveform_list)
        asr_res = []
        for beg_idx in range(0, waveform_nums, self.batch_size):
            end_idx = min(waveform_nums, beg_idx + self.batch_size)
            feats, feats_len = self.extract_feat(waveform_list[beg_idx:end_idx])
            # Tile into a per-batch variable; overwriting bias_embed here would
            # add another leading axis on every later batch.
            batch_bias_embed = np.expand_dims(bias_embed, axis=0)
            batch_bias_embed = np.repeat(batch_bias_embed, feats.shape[0], axis=0)
            try:
                outputs = self.bb_infer(feats, feats_len, batch_bias_embed)
                am_scores, valid_token_lens = outputs[0], outputs[1]
            except ONNXRuntimeError:
                logging.warning("input wav is silence or noise")
                continue

            preds = self.decode(am_scores, valid_token_lens)
            for pred in preds:
                pred = sentence_postprocess(pred)
                asr_res.append({'preds': pred})
        return asr_res

    def proc_hotword(self, hotwords):
        """Convert a hotword string into padded id rows plus last-char indices.

        Returns:
            ``(ids, last_index)``: ``ids`` holds one row per hotword padded with
            0 to length 10, followed by one extra row containing only id 1;
            ``last_index`` is an int32 tensor with ``len(word) - 1`` per
            hotword and 0 for the extra row.

        Raises:
            KeyError: a hotword contains a character missing from ``tokens.txt``.
        """
        # split() without an argument skips empty pieces, so repeated or
        # trailing spaces no longer create hotwords with a length of -1.
        hotwords = hotwords.split()
        hotwords_length = [len(i) - 1 for i in hotwords]
        hotwords_length.append(0)
        hotwords_length = torch.Tensor(hotwords_length).to(torch.int32)

        def word_map(word):
            return torch.tensor([self.vocab[i] for i in word])
        hotword_int = [word_map(i) for i in hotwords]
        hotword_int.append(torch.tensor([1]))
        # NOTE(review): max_len is fixed at 10, so longer hotwords do not fit
        # into the padded tensor - confirm the intended limit.
        hotwords = pad_list(hotword_int, pad_value=0, max_len=10)
        return hotwords, hotwords_length

    def bb_infer(self, feats: np.ndarray,
                 feats_len: np.ndarray, bias_embed) -> Tuple[np.ndarray, np.ndarray]:
        """Run the backbone session on features, their lengths and bias embeddings."""
        outputs = self.ort_infer_bb([feats, feats_len, bias_embed])
        return outputs

    def eb_infer(self, hotwords, hotwords_length):
        """Run the embedding session on int32 hotword ids and their indices."""
        outputs = self.ort_infer_eb([hotwords.to(torch.int32).numpy(), hotwords_length.to(torch.int32).numpy()])
        return outputs

    def decode(self, am_scores: np.ndarray, token_nums: int) -> List[str]:
        """Decode every (scores, token count) pair with ``decode_one``."""
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        """Greedy-decode one score matrix into a list of tokens."""
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with mask tokens to ensure compatibility with sos/eos tokens
        # asr_model.sos:1  asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove blank symbol id, which is assumed to be 0
        token_int = list(filter(lambda x: x not in (0, 2), token_int))

        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num-self.pred_bias]
        return token
| | |
| | | from typing import Any, Dict, Iterable, List, NamedTuple, Set, Tuple, Union |
| | | |
| | | import re |
| | | import torch |
| | | import numpy as np |
| | | import yaml |
| | | try: |
| | |
| | | logger_initialized = {} |
| | | |
| | | |
def pad_list(xs, pad_value, max_len=None):
    """Stack tensors of different lengths into one padded batch tensor.

    Args:
        xs: list of tensors that may differ in size along dim 0.
        pad_value: value written to every position not covered by an input.
        max_len: size of dim 1 of the result; defaults to the largest
            ``x.size(0)`` found in ``xs``.

    Returns:
        A tensor created from ``xs[0]`` with shape
        ``(len(xs), max_len, *xs[0].size()[1:])``.
    """
    target_len = max_len if max_len is not None else max(x.size(0) for x in xs)
    trailing_dims = xs[0].size()[1:]
    padded = xs[0].new_empty((len(xs), target_len, *trailing_dims)).fill_(pad_value)

    for row, seq in enumerate(xs):
        padded[row, : seq.size(0)] = seq

    return padded
| | | |
| | | |
def make_pad_mask(lengths, xs=None, length_dim=-1, maxlen=None):
    """Build a boolean mask that is True at positions beyond each length.

    Args:
        lengths: list of lengths, or an object whose ``tolist()`` yields one.
        xs: optional reference tensor; when given, the mask is expanded to its
            shape and moved to its device.
        length_dim: dimension of ``xs`` that holds the sequence axis; must not
            be 0.
        maxlen: explicit mask width; only allowed when ``xs`` is None.

    Returns:
        A ``(batch, maxlen)`` bool tensor, or a tensor shaped like ``xs``.

    Raises:
        ValueError: ``length_dim`` is 0.
    """
    if length_dim == 0:
        raise ValueError("length_dim cannot be 0: {}".format(length_dim))

    length_list = lengths if isinstance(lengths, list) else lengths.tolist()
    batch = int(len(length_list))

    if maxlen is not None:
        assert xs is None
        assert maxlen >= int(max(length_list))
    elif xs is not None:
        maxlen = xs.size(length_dim)
    else:
        maxlen = int(max(length_list))

    positions = torch.arange(0, maxlen, dtype=torch.int64).unsqueeze(0).expand(batch, maxlen)
    limits = positions.new(length_list).unsqueeze(-1)
    mask = positions >= limits

    if xs is None:
        return mask

    assert xs.size(0) == batch, (xs.size(0), batch)
    target_dim = xs.dim() + length_dim if length_dim < 0 else length_dim
    # Keep the batch axis and the length axis, insert singleton axes elsewhere:
    # ind = (:, None, ..., None, :, , None, ..., None)
    ind = tuple(
        slice(None) if axis in (0, target_dim) else None for axis in range(xs.dim())
    )
    return mask[ind].expand_as(xs).to(xs.device)
| | | |
| | | |
| | | class TokenIDConverter(): |
| | | def __init__(self, token_list: Union[List, str], |
| | | ): |
| | |
| | | # install openssl first apt-get install libssl-dev |
| | | find_package(OpenSSL REQUIRED) |
| | | |
| | | #message("CXX_FLAGS "${CMAKE_CXX_FLAGS}) |
| | | add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp") |
| | | add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp") |
| | | add_executable(funasr-wss-client "funasr-wss-client.cpp") |
| | |
| | | */ |
| | | void WaitABit() { |
| | | #ifdef WIN32 |
| | | Sleep(1000); |
| | | Sleep(500); |
| | | #else |
| | | sleep(1); |
| | | usleep(500); |
| | | #endif |
| | | } |
| | | std::atomic<int> wav_index(0); |
| | |
| | | case websocketpp::frame::opcode::text: |
| | | total_num=total_num+1; |
| | | LOG(INFO)<< "Thread: " << this_thread::get_id() <<",on_message = " << payload; |
| | | LOG(INFO) << "total_num=" << total_num << " wav_index=" <<wav_index; |
| | | if((total_num+1)==wav_index) |
| | | { |
| | | LOG(INFO) << "close client"; |
| | | websocketpp::lib::error_code ec; |
| | | m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec); |
| | | if (ec){ |
| | |
| | | } |
| | | |
| | | // This method will block until the connection is complete |
| | | void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids) { |
| | | void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, std::string hotwords) { |
| | | // Create a new connection to the given URI |
| | | websocketpp::lib::error_code ec; |
| | | typename websocketpp::client<T>::connection_ptr con = |
| | |
| | | // Create a thread to run the ASIO io_service event loop |
| | | websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run, |
| | | &m_client); |
| | | bool send_hotword = true; |
| | | while(true){ |
| | | int i = wav_index.fetch_add(1); |
| | | if (i >= wav_list.size()) { |
| | | break; |
| | | } |
| | | send_wav_data(wav_list[i], wav_ids[i]); |
| | | send_wav_data(wav_list[i], wav_ids[i], hotwords, send_hotword); |
| | | if(send_hotword){ |
| | | send_hotword = false; |
| | | } |
| | | } |
| | | WaitABit(); |
| | | |
| | |
| | | m_done = true; |
| | | } |
| | | // send wav to server |
| | | void send_wav_data(string wav_path, string wav_id) { |
| | | void send_wav_data(string wav_path, string wav_id, string hotwords, bool send_hotword) { |
| | | uint64_t count = 0; |
| | | std::stringstream val; |
| | | |
| | |
| | | jsonbegin["wav_name"] = wav_id; |
| | | jsonbegin["wav_format"] = wav_format; |
| | | jsonbegin["is_speaking"] = true; |
| | | if(send_hotword){ |
| | | LOG(INFO) << "hotwords: "<< hotwords; |
| | | jsonbegin["hotwords"] = hotwords; |
| | | } |
| | | m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text, |
| | | ec); |
| | | |
| | |
| | | jsonresult["is_speaking"] = false; |
| | | m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text, |
| | | ec); |
| | | // WaitABit(); |
| | | std::this_thread::sleep_for(std::chrono::milliseconds(100)); |
| | | } |
| | | websocketpp::client<T> m_client; |
| | | |
| | |
| | | TCLAP::ValueArg<int> is_ssl_( |
| | | "", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection", |
| | | false, 1, "int"); |
| | | TCLAP::ValueArg<std::string> hotword_("", HOTWORD, "*.txt(one hotword perline) or hotwords seperate by space (could be: 阿里巴巴 达摩院)", false, "", "string"); |
| | | |
| | | cmd.add(server_ip_); |
| | | cmd.add(port_); |
| | | cmd.add(wav_path_); |
| | | cmd.add(thread_num_); |
| | | cmd.add(is_ssl_); |
| | | cmd.add(hotword_); |
| | | cmd.parse(argc, argv); |
| | | |
| | | std::string server_ip = server_ip_.getValue(); |
| | |
| | | } else { |
| | | uri = "ws://" + server_ip + ":" + port; |
| | | } |
| | | |
| | | // read hotwords |
| | | std::string hotword = hotword_.getValue(); |
| | | std::string hotwords_; |
| | | |
| | | if(IsTargetFile(hotword, "txt")){ |
| | | ifstream in(hotword); |
| | | if (!in.is_open()) { |
| | | LOG(ERROR) << "Failed to open file: " << hotword; |
| | | return 0; |
| | | } |
| | | string line; |
| | | while(getline(in, line)) |
| | | { |
| | | hotwords_ +=line+HOTWORD_SEP; |
| | | } |
| | | in.close(); |
| | | }else{ |
| | | hotwords_ = hotword; |
| | | } |
| | | |
| | | |
| | | // read wav_path |
| | | std::vector<string> wav_list; |
| | |
| | | } |
| | | |
| | | for (size_t i = 0; i < threads_num; i++) { |
| | | client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl]() { |
| | | client_threads.emplace_back([uri, wav_list, wav_ids, is_ssl, hotwords_]() { |
| | | if (is_ssl == 1) { |
| | | WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl); |
| | | |
| | | c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1)); |
| | | |
| | | c.run(uri, wav_list, wav_ids); |
| | | c.run(uri, wav_list, wav_ids, hotwords_); |
| | | } else { |
| | | WebsocketClient<websocketpp::config::asio_client> c(is_ssl); |
| | | |
| | | c.run(uri, wav_list, wav_ids); |
| | | c.run(uri, wav_list, wav_ids, hotwords_); |
| | | } |
| | | }); |
| | | } |
| | |
| | | server server_; // server for websocket |
| | | wss_server wss_server_; |
| | | if (is_ssl) { |
| | | LOG(INFO)<< "SSL is opened!"; |
| | | wss_server_.init_asio(&io_server); // init asio |
| | | wss_server_.set_reuse_addr( |
| | | true); // reuse address as we create multiple threads |
| | |
| | | websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model |
| | | |
| | | } else { |
| | | LOG(INFO)<< "SSL is closed!"; |
| | | server_.init_asio(&io_server); // init asio |
| | | server_.set_reuse_addr( |
| | | true); // reuse address as we create multiple threads |
| | |
| | | python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir ./ " + " --model_revision " + model_path["model-revision"]; |
| | | down_asr_path = s_asr_path; |
| | | }else{ |
| | | size_t found = s_asr_path.find("speech_paraformer-large-vad-punc_asr_nat-zh-cn-16k-common-vocab8404"); |
| | | if (found != std::string::npos) { |
| | | model_path["model-revision"]="v1.2.4"; |
| | | }else{ |
| | | found = s_asr_path.find("speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404"); |
| | | if (found != std::string::npos) { |
| | | model_path["model-revision"]="v1.0.3"; |
| | | model_path[QUANTIZE]=false; |
| | | s_asr_quant = false; |
| | | } |
| | | } |
| | | |
| | | // modelscope |
| | | LOG(INFO) << "Download model: " << s_asr_path << " from modelscope: "; |
| | | python_cmd_asr = python_cmd + " --model-name " + s_asr_path + " --export-dir " + s_download_model_dir + " --model_revision " + model_path["model-revision"]; |
| | |
| | | server server_; // server for websocket |
| | | wss_server wss_server_; |
| | | if (is_ssl) { |
| | | LOG(INFO)<< "SSL is opened!"; |
| | | wss_server_.init_asio(&io_server); // init asio |
| | | wss_server_.set_reuse_addr( |
| | | true); // reuse address as we create multiple threads |
| | |
| | | websocket_srv.initAsr(model_path, s_model_thread_num); // init asr model |
| | | |
| | | } else { |
| | | LOG(INFO)<< "SSL is closed!"; |
| | | server_.init_asio(&io_server); // init asio |
| | | server_.set_reuse_addr( |
| | | true); // reuse address as we create multiple threads |
| | |
| | | ```
|
| | |
|
| | | API-reference:
|
| | |
|
| | | ```text
|
| | | --server-ip: The IP address of the machine where FunASR runtime-SDK service is deployed. The default value is the IP address of the local machine (127.0.0.1). If the client and service are not on the same server, it needs to be changed to the IP address of the deployment machine.
|
| | | --port: The port number of the deployed service is 10095.
|
| | |
| | | // feed buffer to asr engine for decoder
|
| | | void WebSocketServer::do_decoder(const std::vector<char>& buffer,
|
| | | websocketpp::connection_hdl& hdl,
|
| | | const nlohmann::json& msg) {
|
| | | websocketpp::lib::mutex& thread_lock,
|
| | | std::vector<std::vector<float>> &hotwords_embedding,
|
| | | std::string wav_name,
|
| | | std::string wav_format) {
|
| | | scoped_lock guard(thread_lock);
|
| | | try {
|
| | | int num_samples = buffer.size(); // the size of the buf
|
| | |
|
| | | if (!buffer.empty()) {
|
| | | // feed data to asr engine
|
| | | if (!buffer.empty() && hotwords_embedding.size() >0 ) {
|
| | | std::string asr_result;
|
| | | std::string stamp_res;
|
| | | try{
|
| | | FUNASR_RESULT Result = FunOfflineInferBuffer(
|
| | | asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, 16000, msg["wav_format"]);
|
| | | asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, hotwords_embedding, 16000, wav_format);
|
| | |
|
| | | std::string asr_result =
|
| | | ((FUNASR_RECOG_RESULT*)Result)->msg; // get decode result
|
| | | asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg; // get decode result
|
| | | stamp_res = ((FUNASR_RECOG_RESULT*)Result)->stamp;
|
| | | FunASRFreeResult(Result);
|
| | | }catch (std::exception const& e) {
|
| | | LOG(ERROR) << e.what();
|
| | | return;
|
| | | }
|
| | |
|
| | | websocketpp::lib::error_code ec;
|
| | | nlohmann::json jsonresult; // result json
|
| | | jsonresult["text"] = asr_result; // put result in 'text'
|
| | | jsonresult["mode"] = "offline";
|
| | |
|
| | | jsonresult["wav_name"] = msg["wav_name"];
|
| | | if(stamp_res != ""){
|
| | | jsonresult["timestamp"] = stamp_res;
|
| | | }
|
| | | jsonresult["wav_name"] = wav_name;
|
| | |
|
| | | // send the json to client
|
| | | if (is_ssl) {
|
| | |
| | | }
|
| | |
|
| | | LOG(INFO) << "buffer.size=" << buffer.size() << ",result json=" << jsonresult.dump();
|
| | | if (!isonline) {
|
| | | // close the client if it is not online asr
|
| | | // server_->close(hdl, websocketpp::close::status::normal, "DONE", ec);
|
| | | // fout.close();
|
| | | }
|
| | | }
|
| | |
|
| | | } catch (std::exception const& e) {
|
| | |
| | |
|
| | | void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
|
| | | scoped_lock guard(m_lock); // for threads safty
|
| | | check_and_clean_connection(); // remove closed connection
|
| | |
|
| | | std::shared_ptr<FUNASR_MESSAGE> data_msg =
|
| | | std::make_shared<FUNASR_MESSAGE>(); // put a new data vector for new
|
| | | // connection
|
| | | data_msg->samples = std::make_shared<std::vector<char>>();
|
| | | data_msg->thread_lock = std::make_shared<websocketpp::lib::mutex>();
|
| | | data_msg->msg = nlohmann::json::parse("{}");
|
| | | data_msg->msg["wav_format"] = "pcm";
|
| | | data_map.emplace(hdl, data_msg);
|
| | |
| | |
|
| | | void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
|
| | | scoped_lock guard(m_lock);
|
| | | data_map.erase(hdl); // remove data vector when connection is closed
|
| | |
|
| | | std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
|
| | | auto it_data = data_map.find(hdl);
|
| | | if (it_data != data_map.end()) {
|
| | | data_msg = it_data->second;
|
| | | } else {
|
| | | return;
|
| | | }
|
| | | unique_lock guard_decoder(*(data_msg->thread_lock));
|
| | | data_msg->msg["is_eof"]=true;
|
| | | guard_decoder.unlock();
|
| | | // data_map.erase(hdl); // remove data vector when connection is closed
|
| | |
|
| | | LOG(INFO) << "on_close, active connections: " << data_map.size();
|
| | | }
|
| | |
|
| | | // remove closed connection
|
| | | void remove_hdl(
|
| | | websocketpp::connection_hdl hdl,
|
| | | std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
|
| | | std::owner_less<websocketpp::connection_hdl>>& data_map) {
|
| | | |
| | | std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
|
| | | auto it_data = data_map.find(hdl);
|
| | | if (it_data != data_map.end()) {
|
| | | data_msg = it_data->second;
|
| | | } else {
|
| | | return;
|
| | | }
|
| | | unique_lock guard_decoder(*(data_msg->thread_lock));
|
| | | if (data_msg->msg["is_eof"]==true) {
|
| | | data_map.erase(hdl);
|
| | | LOG(INFO) << "remove one connection";
|
| | | }
|
| | | guard_decoder.unlock();
|
| | | }
|
| | |
|
| | | void WebSocketServer::check_and_clean_connection() {
|
| | | while(true){
|
| | | std::this_thread::sleep_for(std::chrono::milliseconds(5000));
|
| | | std::vector<websocketpp::connection_hdl> to_remove; // remove list
|
| | | auto iter = data_map.begin();
|
| | | while (iter != data_map.end()) { // loop to find closed connection
|
| | | websocketpp::connection_hdl hdl = iter->first;
|
| | |
|
| | | try{
|
| | | if (is_ssl) {
|
| | | wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
|
| | | if (con->get_state() != 1) { // session::state::open ==1
|
| | |
| | | to_remove.push_back(hdl);
|
| | | }
|
| | | }
|
| | | }
|
| | | catch (std::exception const &e)
|
| | | {
|
| | | // if connection is close, we set is_eof = true
|
| | | std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
|
| | | auto it_data = data_map.find(hdl);
|
| | | if (it_data != data_map.end()) {
|
| | | data_msg = it_data->second;
|
| | | } else {
|
| | | continue;
|
| | | }
|
| | | unique_lock guard_decoder(*(data_msg->thread_lock));
|
| | | data_msg->msg["is_eof"]=true;
|
| | | guard_decoder.unlock();
|
| | | to_remove.push_back(hdl);
|
| | | LOG(INFO)<<"connection is closed: "<<e.what();
|
| | |
|
| | | }
|
| | | iter++;
|
| | | }
|
| | | for (auto hdl : to_remove) {
|
| | | data_map.erase(hdl);
|
| | | LOG(INFO)<< "remove one connection ";
|
| | | remove_hdl(hdl, data_map);
|
| | | //LOG(INFO) << "remove one connection ";
|
| | | }
|
| | | }
|
| | | }
|
| | |
|
| | | void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
|
| | | message_ptr msg) {
|
| | | unique_lock lock(m_lock);
|
| | |
| | | msg_data = it_data->second;
|
| | | }
|
| | | std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
|
| | | std::shared_ptr<websocketpp::lib::mutex> thread_lock_p = msg_data->thread_lock;
|
| | |
|
| | | lock.unlock();
|
| | | if (sample_data_p == nullptr) {
|
| | |
| | | }
|
| | |
|
| | | const std::string& payload = msg->get_payload(); // get msg type
|
| | |
|
| | | unique_lock guard_decoder(*(thread_lock_p)); // mutex for one connection
|
| | | switch (msg->get_opcode()) {
|
| | | case websocketpp::frame::opcode::text: {
|
| | | nlohmann::json jsonresult = nlohmann::json::parse(payload);
|
| | |
| | | if (jsonresult["wav_format"] != nullptr) {
|
| | | msg_data->msg["wav_format"] = jsonresult["wav_format"];
|
| | | }
|
| | | if(msg_data->hotwords_embedding == NULL){
|
| | | if (jsonresult["hotwords"] != nullptr) {
|
| | | msg_data->msg["hotwords"] = jsonresult["hotwords"];
|
| | | if (!msg_data->msg["hotwords"].empty()) {
|
| | | std::string hw = msg_data->msg["hotwords"];
|
| | | LOG(INFO)<<"hotwords: " << hw;
|
| | | std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
|
| | | msg_data->hotwords_embedding =
|
| | | std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
|
| | | }
|
| | | }else{
|
| | | std::string hw = "";
|
| | | LOG(INFO)<<"hotwords: " << hw;
|
| | | std::vector<std::vector<float>> new_hotwords_embedding= CompileHotwordEmbedding(asr_hanlde, hw);
|
| | | msg_data->hotwords_embedding =
|
| | | std::make_shared<std::vector<std::vector<float>>>(new_hotwords_embedding);
|
| | | }
|
| | | }
|
| | |
|
| | | if (jsonresult["is_speaking"] == false ||
|
| | | jsonresult["is_finished"] == true) {
|
| | | LOG(INFO) << "client done";
|
| | |
|
| | | if (isonline) {
|
| | | // do_close(ws);
|
| | | } else {
|
| | | // add padding to the end of the wav data
|
| | | // std::vector<short> padding(static_cast<short>(0.3 * 16000));
|
| | | // sample_data_p->insert(sample_data_p->end(), padding.data(),
|
| | | // padding.data() + padding.size());
|
| | | // for offline, send all receive data to decoder engine
|
| | | std::vector<std::vector<float>> hotwords_embedding_(*(msg_data->hotwords_embedding));
|
| | | asio::post(io_decoder_,
|
| | | std::bind(&WebSocketServer::do_decoder, this,
|
| | | std::move(*(sample_data_p.get())),
|
| | | std::move(hdl), std::move(msg_data->msg)));
|
| | | }
|
| | | std::move(hdl), |
| | | std::ref(*thread_lock_p),
|
| | | std::move(hotwords_embedding_),
|
| | | msg_data->msg["wav_name"],
|
| | | msg_data->msg["wav_format"]));
|
| | | }
|
| | | break;
|
| | | }
|
| | |
| | | // recived binary data
|
| | | const auto* pcm_data = static_cast<const char*>(payload.data());
|
| | | int32_t num_samples = payload.size();
|
| | | //LOG(INFO) << "recv binary num_samples " << num_samples;
|
| | |
|
| | | if (isonline) {
|
| | | // if online TODO(zhaoming) still not done
|
| | | std::vector<char> s(pcm_data, pcm_data + num_samples);
|
| | | asio::post(io_decoder_,
|
| | | std::bind(&WebSocketServer::do_decoder, this, std::move(s),
|
| | | std::move(hdl), std::move(msg_data->msg)));
|
| | | // TODO
|
| | | } else {
|
| | | // for offline, we add receive data to end of the sample data vector
|
| | | sample_data_p->insert(sample_data_p->end(), pcm_data,
|
| | | pcm_data + num_samples);
|
| | | }
|
| | |
|
| | | break;
|
| | | }
|
| | | default:
|
| | |
| | | asr_hanlde = FunOfflineInit(model_path, thread_num);
|
| | | LOG(INFO) << "model successfully inited";
|
| | |
|
| | | LOG(INFO) << "initAsr run check_and_clean_connection";
|
| | | std::thread clean_thread(&WebSocketServer::check_and_clean_connection,this); |
| | | clean_thread.detach();
|
| | | LOG(INFO) << "initAsr run check_and_clean_connection finished";
|
| | |
|
| | | } catch (const std::exception& e) {
|
| | | LOG(INFO) << e.what();
|
| | | }
|
| | |
| | | context_ptr; |
| | | |
| | | typedef struct { |
| | | std::string msg; |
| | | float snippet_time; |
| | | std::string msg=""; |
| | | std::string stamp=""; |
| | | std::string tpass_msg=""; |
| | | float snippet_time=0; |
| | | } FUNASR_RECOG_RESULT; |
| | | |
| | | typedef struct { |
| | | nlohmann::json msg; |
| | | std::shared_ptr<std::vector<char>> samples; |
| | | std::shared_ptr<std::vector<std::vector<float>>> hotwords_embedding=NULL; |
| | | std::shared_ptr<websocketpp::lib::mutex> thread_lock; // lock for each connection |
| | | } FUNASR_MESSAGE; |
| | | |
| | | // See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about |
| | |
| | | } |
| | | } |
| | | void do_decoder(const std::vector<char>& buffer, |
| | | websocketpp::connection_hdl& hdl, const nlohmann::json& msg); |
| | | websocketpp::connection_hdl& hdl, |
| | | websocketpp::lib::mutex& thread_lock, |
| | | std::vector<std::vector<float>> &hotwords_embedding, |
| | | std::string wav_name, std::string wav_format); |
| | | |
| | | void initAsr(std::map<std::string, std::string>& model_path, int thread_num); |
| | | void on_message(websocketpp::connection_hdl hdl, message_ptr msg); |