From b454a1054fadbff0ee963944ff42f66b98317582 Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期二, 08 八月 2023 11:17:43 +0800
Subject: [PATCH] update online runtime, including vad-online, paraformer-online, punc-online,2pass (#815)
---
funasr/runtime/onnxruntime/include/offline-stream.h | 6
funasr/runtime/websocket/funasr-wss-client-2pass.cpp | 430 ++++
funasr/runtime/onnxruntime/src/model.cpp | 57
funasr/runtime/grpc/Readme.md | 220 -
funasr/runtime/onnxruntime/bin/funasr-onnx-online-asr.cpp | 174 +
funasr/runtime/onnxruntime/src/ct-transformer-online.cpp | 9
funasr/runtime/python/grpc/Readme.md | 70
funasr/runtime/onnxruntime/src/fsmn-vad-online.h | 2
funasr/export/models/decoder/sanm_decoder.py | 155 +
funasr/runtime/grpc/paraformer-server.h | 88
funasr/export/export_model.py | 21
funasr/runtime/websocket/funasr-wss-server.cpp | 2
funasr/runtime/onnxruntime/include/tpass-stream.h | 31
funasr/runtime/python/onnxruntime/demo_paraformer_online.py | 30
funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py | 309 +++
funasr/runtime/websocket/readme.md | 12
funasr/runtime/onnxruntime/src/precomp.h | 6
funasr/runtime/websocket/funasr-wss-server-2pass.cpp | 419 ++++
funasr/runtime/onnxruntime/src/audio.cpp | 282 ++
funasr/runtime/onnxruntime/bin/CMakeLists.txt | 16
funasr/runtime/onnxruntime/include/com-define.h | 21
funasr/runtime/grpc/build.sh | 15
funasr/export/models/predictor/cif.py | 21
funasr/export/models/encoder/sanm_encoder.py | 7
funasr/runtime/onnxruntime/include/tpass-online-stream.h | 20
funasr/export/models/modules/multihead_att.py | 8
funasr/runtime/grpc/paraformer-server.cc | 438 ++--
funasr/runtime/grpc/run_server.sh | 12
funasr/runtime/docs/SDK_advanced_guide_online.md | 8
funasr/runtime/python/grpc/requirements.txt | 2
funasr/runtime/python/grpc/proto/paraformer.proto | 33
funasr/runtime/onnxruntime/bin/funasr-onnx-online-rtf.cpp | 278 ++
funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp | 7
funasr/export/models/e2e_asr_paraformer.py | 151 +
.gitignore | 3
funasr/runtime/onnxruntime/src/funasrruntime.cpp | 149 +
funasr/runtime/onnxruntime/include/model.h | 12
funasr/runtime/onnxruntime/src/paraformer-online.h | 111 +
funasr/runtime/python/grpc/grpc_main_client.py | 128
funasr/runtime/onnxruntime/src/commonfunc.h | 3
funasr/export/models/__init__.py | 7
funasr/runtime/docs/SDK_tutorial_online.md | 2
funasr/runtime/websocket/websocket-server-2pass.cpp | 369 +++
funasr/runtime/onnxruntime/include/audio.h | 34
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp | 217 ++
funasr/runtime/onnxruntime/include/funasrruntime.h | 22
funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py | 22
funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp | 310 +++
funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp | 3
funasr/runtime/onnxruntime/src/paraformer.h | 66
/dev/null | 4
funasr/runtime/websocket/CMakeLists.txt | 4
funasr/runtime/grpc/CMakeLists.txt | 87
funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp | 2
funasr/runtime/onnxruntime/src/paraformer-online.cpp | 551 +++++
funasr/runtime/websocket/websocket-server-2pass.h | 148 +
funasr/runtime/onnxruntime/src/tpass-stream.cpp | 87
funasr/runtime/onnxruntime/src/paraformer.cpp | 218 +
funasr/runtime/onnxruntime/src/tpass-online-stream.cpp | 29
59 files changed, 5,192 insertions(+), 756 deletions(-)
diff --git a/.gitignore b/.gitignore
index d47674c..37f39fe 100644
--- a/.gitignore
+++ b/.gitignore
@@ -19,4 +19,5 @@
funasr.egg-info
docs/_build
modelscope
-samples
\ No newline at end of file
+samples
+.ipynb_checkpoints
diff --git a/funasr/export/export_model.py b/funasr/export/export_model.py
index f31f960..8c3108b 100644
--- a/funasr/export/export_model.py
+++ b/funasr/export/export_model.py
@@ -55,18 +55,21 @@
# export encoder1
self.export_config["model_name"] = "model"
- model = get_model(
+ models = get_model(
model,
self.export_config,
)
- model.eval()
- # self._export_onnx(model, verbose, export_dir)
- if self.onnx:
- self._export_onnx(model, verbose, export_dir)
- else:
- self._export_torchscripts(model, verbose, export_dir)
-
- print("output dir: {}".format(export_dir))
+ if not isinstance(models, tuple):
+ models = (models,)
+
+ for i, model in enumerate(models):
+ model.eval()
+ if self.onnx:
+ self._export_onnx(model, verbose, export_dir)
+ else:
+ self._export_torchscripts(model, verbose, export_dir)
+
+ print("output dir: {}".format(export_dir))
def _torch_quantize(self, model):
diff --git a/funasr/export/models/__init__.py b/funasr/export/models/__init__.py
index 6177119..fd0a15c 100644
--- a/funasr/export/models/__init__.py
+++ b/funasr/export/models/__init__.py
@@ -1,4 +1,4 @@
-from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer
+from funasr.models.e2e_asr_paraformer import Paraformer, BiCifParaformer, ParaformerOnline
from funasr.export.models.e2e_asr_paraformer import Paraformer as Paraformer_export
from funasr.export.models.e2e_asr_paraformer import BiCifParaformer as BiCifParaformer_export
from funasr.export.models.e2e_asr_conformer import Conformer as Conformer_export
@@ -10,10 +10,15 @@
from funasr.train.abs_model import PunctuationModel
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
from funasr.export.models.CT_Transformer import CT_Transformer_VadRealtime as CT_Transformer_VadRealtime_export
+from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_encoder_predictor as ParaformerOnline_encoder_predictor_export
+from funasr.export.models.e2e_asr_paraformer import ParaformerOnline_decoder as ParaformerOnline_decoder_export
def get_model(model, export_config=None):
if isinstance(model, BiCifParaformer):
return BiCifParaformer_export(model, **export_config)
+ elif isinstance(model, ParaformerOnline):
+ return (ParaformerOnline_encoder_predictor_export(model, model_name="model"),
+ ParaformerOnline_decoder_export(model, model_name="decoder"))
elif isinstance(model, Paraformer):
return Paraformer_export(model, **export_config)
elif isinstance(model, Conformer_export):
diff --git a/funasr/export/models/decoder/sanm_decoder.py b/funasr/export/models/decoder/sanm_decoder.py
index 9084b7f..6966847 100644
--- a/funasr/export/models/decoder/sanm_decoder.py
+++ b/funasr/export/models/decoder/sanm_decoder.py
@@ -157,3 +157,158 @@
"n_layers": len(self.model.decoders) + len(self.model.decoders2),
"odim": self.model.decoders[0].size
}
+
+
+class ParaformerSANMDecoderOnline(nn.Module):
+ def __init__(self, model,
+ max_seq_len=512,
+ model_name='decoder',
+ onnx: bool = True, ):
+ super().__init__()
+ # self.embed = model.embed #Embedding(model.embed, max_seq_len)
+ self.model = model
+ if onnx:
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ else:
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ for i, d in enumerate(self.model.decoders):
+ if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+ d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
+ if isinstance(d.src_attn, MultiHeadedAttentionCrossAtt):
+ d.src_attn = MultiHeadedAttentionCrossAtt_export(d.src_attn)
+ self.model.decoders[i] = DecoderLayerSANM_export(d)
+
+ if self.model.decoders2 is not None:
+ for i, d in enumerate(self.model.decoders2):
+ if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+ d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+ if isinstance(d.self_attn, MultiHeadedAttentionSANMDecoder):
+ d.self_attn = MultiHeadedAttentionSANMDecoder_export(d.self_attn)
+ self.model.decoders2[i] = DecoderLayerSANM_export(d)
+
+ for i, d in enumerate(self.model.decoders3):
+ if isinstance(d.feed_forward, PositionwiseFeedForwardDecoderSANM):
+ d.feed_forward = PositionwiseFeedForwardDecoderSANM_export(d.feed_forward)
+ self.model.decoders3[i] = DecoderLayerSANM_export(d)
+
+ self.output_layer = model.output_layer
+ self.after_norm = model.after_norm
+ self.model_name = model_name
+
+ def prepare_mask(self, mask):
+ mask_3d_btd = mask[:, :, None]
+ if len(mask.shape) == 2:
+ mask_4d_bhlt = 1 - mask[:, None, None, :]
+ elif len(mask.shape) == 3:
+ mask_4d_bhlt = 1 - mask[:, None, :]
+ mask_4d_bhlt = mask_4d_bhlt * -10000.0
+
+ return mask_3d_btd, mask_4d_bhlt
+
+ def forward(
+ self,
+ hs_pad: torch.Tensor,
+ hlens: torch.Tensor,
+ ys_in_pad: torch.Tensor,
+ ys_in_lens: torch.Tensor,
+ *args,
+ ):
+
+ tgt = ys_in_pad
+ tgt_mask = self.make_pad_mask(ys_in_lens)
+ tgt_mask, _ = self.prepare_mask(tgt_mask)
+ # tgt_mask = myutils.sequence_mask(ys_in_lens, device=tgt.device)[:, :, None]
+
+ memory = hs_pad
+ memory_mask = self.make_pad_mask(hlens)
+ _, memory_mask = self.prepare_mask(memory_mask)
+ # memory_mask = myutils.sequence_mask(hlens, device=memory.device)[:, None, :]
+
+ x = tgt
+ out_caches = list()
+ for i, decoder in enumerate(self.model.decoders):
+ in_cache = args[i]
+ x, tgt_mask, memory, memory_mask, out_cache = decoder(
+ x, tgt_mask, memory, memory_mask, cache=in_cache
+ )
+ out_caches.append(out_cache)
+ if self.model.decoders2 is not None:
+ for i, decoder in enumerate(self.model.decoders2):
+ in_cache = args[i+len(self.model.decoders)]
+ x, tgt_mask, memory, memory_mask, out_cache = decoder(
+ x, tgt_mask, memory, memory_mask, cache=in_cache
+ )
+ out_caches.append(out_cache)
+ x, tgt_mask, memory, memory_mask, _ = self.model.decoders3(
+ x, tgt_mask, memory, memory_mask
+ )
+ x = self.after_norm(x)
+ x = self.output_layer(x)
+
+ return x, out_caches
+
+ def get_dummy_inputs(self, enc_size):
+ enc = torch.randn(2, 100, enc_size).type(torch.float32)
+ enc_len = torch.tensor([30, 100], dtype=torch.int32)
+ acoustic_embeds = torch.randn(2, 10, enc_size).type(torch.float32)
+ acoustic_embeds_len = torch.tensor([5, 10], dtype=torch.int32)
+ cache_num = len(self.model.decoders)
+ if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+ cache_num += len(self.model.decoders2)
+ cache = [
+ torch.zeros((2, self.model.decoders[0].size, self.model.decoders[0].self_attn.kernel_size-1), dtype=torch.float32)
+ for _ in range(cache_num)
+ ]
+ return (enc, enc_len, acoustic_embeds, acoustic_embeds_len, *cache)
+
+ def get_input_names(self):
+ cache_num = len(self.model.decoders)
+ if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+ cache_num += len(self.model.decoders2)
+ return ['enc', 'enc_len', 'acoustic_embeds', 'acoustic_embeds_len'] \
+ + ['in_cache_%d' % i for i in range(cache_num)]
+
+ def get_output_names(self):
+ cache_num = len(self.model.decoders)
+ if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+ cache_num += len(self.model.decoders2)
+ return ['logits', 'sample_ids'] \
+ + ['out_cache_%d' % i for i in range(cache_num)]
+
+ def get_dynamic_axes(self):
+ ret = {
+ 'enc': {
+ 0: 'batch_size',
+ 1: 'enc_length'
+ },
+ 'acoustic_embeds': {
+ 0: 'batch_size',
+ 1: 'token_length'
+ },
+ 'enc_len': {
+ 0: 'batch_size',
+ },
+ 'acoustic_embeds_len': {
+ 0: 'batch_size',
+ },
+
+ }
+ cache_num = len(self.model.decoders)
+ if hasattr(self.model, 'decoders2') and self.model.decoders2 is not None:
+ cache_num += len(self.model.decoders2)
+ ret.update({
+ 'in_cache_%d' % d: {
+ 0: 'batch_size',
+ }
+ for d in range(cache_num)
+ })
+ ret.update({
+ 'out_cache_%d' % d: {
+ 0: 'batch_size',
+ }
+ for d in range(cache_num)
+ })
+ return ret
diff --git a/funasr/export/models/e2e_asr_paraformer.py b/funasr/export/models/e2e_asr_paraformer.py
index 52ad320..5697b77 100644
--- a/funasr/export/models/e2e_asr_paraformer.py
+++ b/funasr/export/models/e2e_asr_paraformer.py
@@ -4,7 +4,7 @@
from funasr.export.utils.torch_function import MakePadMask
from funasr.export.utils.torch_function import sequence_mask
-from funasr.models.encoder.sanm_encoder import SANMEncoder
+from funasr.models.encoder.sanm_encoder import SANMEncoder, SANMEncoderChunkOpt
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.export.models.encoder.sanm_encoder import SANMEncoder as SANMEncoder_export
from funasr.export.models.encoder.conformer_encoder import ConformerEncoder as ConformerEncoder_export
@@ -15,6 +15,7 @@
from funasr.models.decoder.transformer_decoder import ParaformerDecoderSAN
from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoder as ParaformerSANMDecoder_export
from funasr.export.models.decoder.transformer_decoder import ParaformerDecoderSAN as ParaformerDecoderSAN_export
+from funasr.export.models.decoder.sanm_decoder import ParaformerSANMDecoderOnline as ParaformerSANMDecoderOnline_export
class Paraformer(nn.Module):
@@ -216,4 +217,150 @@
0: 'batch_size',
1: 'alphas_length'
},
- }
\ No newline at end of file
+ }
+
+
+class ParaformerOnline_encoder_predictor(nn.Module):
+ """
+ Author: Speech Lab, Alibaba Group, China
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ model,
+ max_seq_len=512,
+ feats_dim=560,
+ model_name='model',
+ **kwargs,
+ ):
+ super().__init__()
+ onnx = False
+ if "onnx" in kwargs:
+ onnx = kwargs["onnx"]
+ if isinstance(model.encoder, SANMEncoder) or isinstance(model.encoder, SANMEncoderChunkOpt):
+ self.encoder = SANMEncoder_export(model.encoder, onnx=onnx)
+ elif isinstance(model.encoder, ConformerEncoder):
+ self.encoder = ConformerEncoder_export(model.encoder, onnx=onnx)
+ if isinstance(model.predictor, CifPredictorV2):
+ self.predictor = CifPredictorV2_export(model.predictor)
+
+ self.feats_dim = feats_dim
+ self.model_name = model_name
+
+ if onnx:
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ else:
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ ):
+ # a. To device
+ batch = {"speech": speech, "speech_lengths": speech_lengths, "online": True}
+ # batch = to_device(batch, device=self.device)
+
+ enc, enc_len = self.encoder(**batch)
+ mask = self.make_pad_mask(enc_len)[:, None, :]
+ alphas, _ = self.predictor.forward_cnn(enc, mask)
+
+ return enc, enc_len, alphas
+
+ def get_dummy_inputs(self):
+ speech = torch.randn(2, 30, self.feats_dim)
+ speech_lengths = torch.tensor([6, 30], dtype=torch.int32)
+ return (speech, speech_lengths)
+
+ def get_input_names(self):
+ return ['speech', 'speech_lengths']
+
+ def get_output_names(self):
+ return ['enc', 'enc_len', 'alphas']
+
+ def get_dynamic_axes(self):
+ return {
+ 'speech': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ 'speech_lengths': {
+ 0: 'batch_size',
+ },
+ 'enc': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ 'enc_len': {
+ 0: 'batch_size',
+ },
+ 'alphas': {
+ 0: 'batch_size',
+ 1: 'feats_length'
+ },
+ }
+
+
+class ParaformerOnline_decoder(nn.Module):
+ """
+ Author: Speech Lab, Alibaba Group, China
+ Paraformer: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ https://arxiv.org/abs/2206.08317
+ """
+
+ def __init__(
+ self,
+ model,
+ max_seq_len=512,
+ feats_dim=560,
+ model_name='model',
+ **kwargs,
+ ):
+ super().__init__()
+ onnx = False
+ if "onnx" in kwargs:
+ onnx = kwargs["onnx"]
+
+ if isinstance(model.decoder, ParaformerDecoderSAN):
+ self.decoder = ParaformerDecoderSAN_export(model.decoder, onnx=onnx)
+ elif isinstance(model.decoder, ParaformerSANMDecoder):
+ self.decoder = ParaformerSANMDecoderOnline_export(model.decoder, onnx=onnx)
+
+ self.feats_dim = feats_dim
+ self.model_name = model_name
+ self.enc_size = model.encoder._output_size
+
+ if onnx:
+ self.make_pad_mask = MakePadMask(max_seq_len, flip=False)
+ else:
+ self.make_pad_mask = sequence_mask(max_seq_len, flip=False)
+
+ def forward(
+ self,
+ enc: torch.Tensor,
+ enc_len: torch.Tensor,
+ acoustic_embeds: torch.Tensor,
+ acoustic_embeds_len: torch.Tensor,
+ *args,
+ ):
+ decoder_out, out_caches = self.decoder(enc, enc_len, acoustic_embeds, acoustic_embeds_len, *args)
+ sample_ids = decoder_out.argmax(dim=-1)
+
+ return decoder_out, sample_ids, out_caches
+
+ def get_dummy_inputs(self, ):
+ dummy_inputs = self.decoder.get_dummy_inputs(enc_size=self.enc_size)
+ return dummy_inputs
+
+ def get_input_names(self):
+
+ return self.decoder.get_input_names()
+
+ def get_output_names(self):
+
+ return self.decoder.get_output_names()
+
+ def get_dynamic_axes(self):
+ return self.decoder.get_dynamic_axes()
diff --git a/funasr/export/models/encoder/sanm_encoder.py b/funasr/export/models/encoder/sanm_encoder.py
index f583f56..d1b4b1e 100644
--- a/funasr/export/models/encoder/sanm_encoder.py
+++ b/funasr/export/models/encoder/sanm_encoder.py
@@ -8,6 +8,7 @@
from funasr.export.models.modules.encoder_layer import EncoderLayerSANM as EncoderLayerSANM_export
from funasr.modules.positionwise_feed_forward import PositionwiseFeedForward
from funasr.export.models.modules.feedforward import PositionwiseFeedForward as PositionwiseFeedForward_export
+from funasr.modules.embedding import StreamSinusoidalPositionEncoder
class SANMEncoder(nn.Module):
@@ -21,6 +22,8 @@
):
super().__init__()
self.embed = model.embed
+ if isinstance(self.embed, StreamSinusoidalPositionEncoder):
+ self.embed = None
self.model = model
self.feats_dim = feats_dim
self._output_size = model._output_size
@@ -63,8 +66,10 @@
def forward(self,
speech: torch.Tensor,
speech_lengths: torch.Tensor,
+ online: bool = False
):
- speech = speech * self._output_size ** 0.5
+ if not online:
+ speech = speech * self._output_size ** 0.5
mask = self.make_pad_mask(speech_lengths)
mask = self.prepare_mask(mask)
if self.embed is None:
diff --git a/funasr/export/models/modules/multihead_att.py b/funasr/export/models/modules/multihead_att.py
index 6fce851..4885c4e 100644
--- a/funasr/export/models/modules/multihead_att.py
+++ b/funasr/export/models/modules/multihead_att.py
@@ -64,14 +64,14 @@
return self.linear_out(context_layer) # (batch, time1, d_model)
-def preprocess_for_attn(x, mask, cache, pad_fn):
+def preprocess_for_attn(x, mask, cache, pad_fn, kernel_size):
x = x * mask
x = x.transpose(1, 2)
if cache is None:
x = pad_fn(x)
else:
- x = torch.cat((cache[:, :, 1:], x), dim=2)
- cache = x
+ x = torch.cat((cache, x), dim=2)
+ cache = x[:, :, -(kernel_size-1):]
return x, cache
@@ -90,7 +90,7 @@
self.attn = None
def forward(self, inputs, mask, cache=None):
- x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn)
+ x, cache = preprocess_for_attn(inputs, mask, cache, self.pad_fn, self.kernel_size)
x = self.fsmn_block(x)
x = x.transpose(1, 2)
diff --git a/funasr/export/models/predictor/cif.py b/funasr/export/models/predictor/cif.py
index 5ea4a34..dd5dd36 100644
--- a/funasr/export/models/predictor/cif.py
+++ b/funasr/export/models/predictor/cif.py
@@ -36,6 +36,17 @@
def forward(self, hidden: torch.Tensor,
mask: torch.Tensor,
):
+ alphas, token_num = self.forward_cnn(hidden, mask)
+ mask = mask.transpose(-1, -2).float()
+ mask = mask.squeeze(-1)
+ hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
+ acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
+
+ return acoustic_embeds, token_num, alphas, cif_peak
+
+ def forward_cnn(self, hidden: torch.Tensor,
+ mask: torch.Tensor,
+ ):
h = hidden
context = h.transpose(1, 2)
queries = self.pad(context)
@@ -49,12 +60,8 @@
alphas = alphas * mask
alphas = alphas.squeeze(-1)
token_num = alphas.sum(-1)
-
- mask = mask.squeeze(-1)
- hidden, alphas, token_num = self.tail_process_fn(hidden, alphas, mask=mask)
- acoustic_embeds, cif_peak = cif(hidden, alphas, self.threshold)
-
- return acoustic_embeds, token_num, alphas, cif_peak
+
+ return alphas, token_num
def tail_process_fn(self, hidden, alphas, token_num=None, mask=None):
b, t, d = hidden.size()
@@ -285,4 +292,4 @@
integrate)
fires = torch.stack(list_fires, 1)
- return fires
\ No newline at end of file
+ return fires
diff --git a/funasr/runtime/docs/SDK_advanced_guide_online.md b/funasr/runtime/docs/SDK_advanced_guide_online.md
index c9d6f7e..f5137d7 100644
--- a/funasr/runtime/docs/SDK_advanced_guide_online.md
+++ b/funasr/runtime/docs/SDK_advanced_guide_online.md
@@ -185,7 +185,9 @@
--port: the port number of the server listener.
--wav-path: the audio input. Input can be a path to a wav file or a wav.scp file (a Kaldi-formatted wav list in which each line includes a wav_id followed by a tab and a wav_path).
--is-ssl: whether to use SSL encryption. The default is to use SSL.
---mode: offline mode.
+--mode: 2pass.
+--thread-num 1
+
```
### Custom client
@@ -194,7 +196,9 @@
```text
# First communication
-{"mode": "offline", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5]}# Send wav data
+{"mode": "offline", "wav_name": "wav_name", "is_speaking": True, "wav_format":"pcm", "chunk_size":[5,10,5]}
+# Send wav data
+
Bytes data
# Send end flag
{"is_speaking": False}
diff --git a/funasr/runtime/docs/SDK_tutorial_online.md b/funasr/runtime/docs/SDK_tutorial_online.md
index bf3508f..de7ea72 100644
--- a/funasr/runtime/docs/SDK_tutorial_online.md
+++ b/funasr/runtime/docs/SDK_tutorial_online.md
@@ -76,7 +76,7 @@
After entering the samples/cpp directory, you can test it with CPP. The command is as follows:
```shell
-./funasr-wss-client --server-ip 127.0.0.1 --port 10095 --wav-path ../audio/asr_example.wav
+./funasr-wss-client-2pass --server-ip 127.0.0.1 --port 10095 --wav-path ../audio/asr_example.wav
```
Command parameter description:
diff --git a/funasr/runtime/grpc/CMakeLists.txt b/funasr/runtime/grpc/CMakeLists.txt
index 98c4787..f8e6417 100644
--- a/funasr/runtime/grpc/CMakeLists.txt
+++ b/funasr/runtime/grpc/CMakeLists.txt
@@ -1,51 +1,44 @@
-# Copyright 2018 gRPC authors.
+# Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+# Reserved. MIT License (https://opensource.org/licenses/MIT)
#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-# cmake build file for C++ paraformer example.
-# Assumes protobuf and gRPC have been installed using cmake.
-# See cmake_externalproject/CMakeLists.txt for all-in-one cmake build
-# that automatically builds all the dependencies before building paraformer.
+# 2023 by burkliu(刘柏基) liubaiji@xverse.cn
cmake_minimum_required(VERSION 3.10)
project(ASR C CXX)
+set(CMAKE_CXX_STANDARD 14 CACHE STRING "The C++ version to be used.")
+set(CMAKE_POSITION_INDEPENDENT_CODE ON)
+set(CMAKE_RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/bin)
+set(CMAKE_VERBOSE_MAKEFILE on)
+set(BUILD_TESTING OFF)
+
include(common.cmake)
# Proto file
-get_filename_component(rg_proto "../python/grpc/proto/paraformer.proto" ABSOLUTE)
-get_filename_component(rg_proto_path "${rg_proto}" PATH)
+get_filename_component(rg_proto ../python/grpc/proto/paraformer.proto ABSOLUTE)
+get_filename_component(rg_proto_path ${rg_proto} PATH)
# Generated sources
-set(rg_proto_srcs "${CMAKE_CURRENT_BINARY_DIR}/paraformer.pb.cc")
-set(rg_proto_hdrs "${CMAKE_CURRENT_BINARY_DIR}/paraformer.pb.h")
-set(rg_grpc_srcs "${CMAKE_CURRENT_BINARY_DIR}/paraformer.grpc.pb.cc")
-set(rg_grpc_hdrs "${CMAKE_CURRENT_BINARY_DIR}/paraformer.grpc.pb.h")
+set(rg_proto_srcs ${CMAKE_CURRENT_BINARY_DIR}/paraformer.pb.cc)
+set(rg_proto_hdrs ${CMAKE_CURRENT_BINARY_DIR}/paraformer.pb.h)
+set(rg_grpc_srcs ${CMAKE_CURRENT_BINARY_DIR}/paraformer.grpc.pb.cc)
+set(rg_grpc_hdrs ${CMAKE_CURRENT_BINARY_DIR}/paraformer.grpc.pb.h)
add_custom_command(
- OUTPUT "${rg_proto_srcs}" "${rg_proto_hdrs}" "${rg_grpc_srcs}" "${rg_grpc_hdrs}"
- COMMAND ${_PROTOBUF_PROTOC}
- ARGS --grpc_out "${CMAKE_CURRENT_BINARY_DIR}"
- --cpp_out "${CMAKE_CURRENT_BINARY_DIR}"
- -I "${rg_proto_path}"
- --plugin=protoc-gen-grpc="${_GRPC_CPP_PLUGIN_EXECUTABLE}"
- "${rg_proto}"
- DEPENDS "${rg_proto}")
+ OUTPUT ${rg_proto_srcs} ${rg_proto_hdrs} ${rg_grpc_srcs} ${rg_grpc_hdrs}
+ COMMAND ${_PROTOBUF_PROTOC}
+ ARGS --grpc_out ${CMAKE_CURRENT_BINARY_DIR}
+ --cpp_out ${CMAKE_CURRENT_BINARY_DIR}
+ -I ${rg_proto_path}
+ --plugin=protoc-gen-grpc=${_GRPC_CPP_PLUGIN_EXECUTABLE}
+ ${rg_proto}
+ DEPENDS ${rg_proto})
# Include generated *.pb.h files
-include_directories("${CMAKE_CURRENT_BINARY_DIR}")
+include_directories(${CMAKE_CURRENT_BINARY_DIR})
link_directories(${ONNXRUNTIME_DIR}/lib)
+link_directories(${FFMPEG_DIR}/lib)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
@@ -53,33 +46,21 @@
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp yaml-cpp)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/kaldi-native-fbank/kaldi-native-fbank/csrc csrc)
-add_subdirectory("../onnxruntime/src" onnx_src)
+add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/src src)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog)
-set(BUILD_TESTING OFF)
add_subdirectory(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/glog glog)
# rg_grpc_proto
-add_library(rg_grpc_proto
- ${rg_grpc_srcs}
- ${rg_grpc_hdrs}
- ${rg_proto_srcs}
- ${rg_proto_hdrs})
+add_library(rg_grpc_proto ${rg_grpc_srcs} ${rg_grpc_hdrs} ${rg_proto_srcs} ${rg_proto_hdrs})
-target_link_libraries(rg_grpc_proto
+target_link_libraries(rg_grpc_proto ${_REFLECTION} ${_GRPC_GRPCPP} ${_PROTOBUF_LIBPROTOBUF})
+
+add_executable(paraformer-server paraformer-server.cc)
+target_link_libraries(paraformer-server
+ rg_grpc_proto
+ funasr
+ ${EXTRA_LIBS}
${_REFLECTION}
${_GRPC_GRPCPP}
${_PROTOBUF_LIBPROTOBUF})
-
-foreach(_target
- paraformer-server)
- add_executable(${_target}
- "${_target}.cc")
- target_link_libraries(${_target}
- rg_grpc_proto
- funasr
- ${EXTRA_LIBS}
- ${_REFLECTION}
- ${_GRPC_GRPCPP}
- ${_PROTOBUF_LIBPROTOBUF})
-endforeach()
diff --git a/funasr/runtime/grpc/Readme.md b/funasr/runtime/grpc/Readme.md
index 71bb035..3edb132 100644
--- a/funasr/runtime/grpc/Readme.md
+++ b/funasr/runtime/grpc/Readme.md
@@ -2,17 +2,20 @@
## For the Server
-### Build [onnxruntime](./onnxruntime_cpp.md) as it's document
+### 1. Build [onnxruntime](../websocket/readme.md) as described in its documentation
-### Compile and install grpc v1.52.0 in case of grpc bugs
-```
-export GRPC_INSTALL_DIR=/data/soft/grpc
-export PKG_CONFIG_PATH=$GRPC_INSTALL_DIR/lib/pkgconfig
+### 2. Compile and install grpc v1.52.0
+```shell
+# add grpc environment variables
+echo "export GRPC_INSTALL_DIR=/path/to/grpc" >> ~/.bashrc
+echo "export PKG_CONFIG_PATH=\$GRPC_INSTALL_DIR/lib/pkgconfig" >> ~/.bashrc
+echo "export PATH=\$GRPC_INSTALL_DIR/bin/:\$PKG_CONFIG_PATH:\$PATH" >> ~/.bashrc
+source ~/.bashrc
-git clone -b v1.52.0 --depth=1 https://github.com/grpc/grpc.git
+# install grpc
+git clone --recurse-submodules -b v1.52.0 --depth 1 --shallow-submodules https://github.com/grpc/grpc
+
cd grpc
-git submodule update --init --recursive
-
mkdir -p cmake/build
pushd cmake/build
cmake -DgRPC_INSTALL=ON \
@@ -22,182 +25,57 @@
make
make install
popd
-
-echo "export GRPC_INSTALL_DIR=/data/soft/grpc" >> ~/.bashrc
-echo "export PKG_CONFIG_PATH=\$GRPC_INSTALL_DIR/lib/pkgconfig" >> ~/.bashrc
-echo "export PATH=\$GRPC_INSTALL_DIR/bin/:\$PKG_CONFIG_PATH:\$PATH" >> ~/.bashrc
-source ~/.bashrc
```
-### Compile and start grpc onnx paraformer server
-```
-# set -DONNXRUNTIME_DIR=/path/to/asrmodel/onnxruntime-linux-x64-1.14.0
-./rebuild.sh
+### 3. Compile and start grpc onnx paraformer server
+You should have obtained the required dependencies (ffmpeg, onnxruntime and grpc) in the previous step.
+
+If no, run [download_ffmpeg](../onnxruntime/third_party/download_ffmpeg.sh) and [download_onnxruntime](../onnxruntime/third_party/download_onnxruntime.sh)
+
+```shell
+cd FunASR/funasr/runtime/grpc
+./build.sh
```
-### Start grpc paraformer server
-```
+### 4. Download paraformer model
+To do.
-./cmake/build/paraformer-server --port-id <string> [--punc-quant <string>]
- [--punc-dir <string>] [--vad-quant <string>]
- [--vad-dir <string>] [--quantize <string>]
- --model-dir <string> [--] [--version] [-h]
+### 5. Start grpc paraformer server
+```shell
+# run as default
+./run_server.sh
+
+# or run server directly
+./build/bin/paraformer-server \
+ --port-id <string> \
+ --offline-model-dir <string> \
+ --online-model-dir <string> \
+ --quantize <string> \
+ --vad-dir <string> \
+ --vad-quant <string> \
+ --punc-dir <string> \
+ --punc-quant <string>
+
Where:
- --port-id <string>
- (required) port id
- --model-dir <string>
- (required) the asr model path, which contains model.onnx, config.yaml, am.mvn
- --quantize <string>
- false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
+ --port-id <string> (required) the port server listen to
- --vad-dir <string>
- the vad model path, which contains model.onnx, vad.yaml, vad.mvn
- --vad-quant <string>
- false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
+ --offline-model-dir <string> (required) the offline asr model path
+ --online-model-dir <string> (required) the online asr model path
+ --quantize <string> (optional) false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir
- --punc-dir <string>
- the punc model path, which contains model.onnx, punc.yaml
- --punc-quant <string>
- false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
-
- Required: --port-id <string> --model-dir <string>
- If use vad, please add: --vad-dir <string>
- If use punc, please add: --punc-dir <string>
+ --vad-dir <string> (required) the vad model path
+ --vad-quant <string> (optional) false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir
+
+ --punc-dir <string> (required) the punc model path
+ --punc-quant <string> (optional) false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir
```
## For the client
+Currently we only support python grpc server.
-### Install the requirements as in [grpc-python](./docs/grpc_python.md)
-
-```shell
-git clone https://github.com/alibaba/FunASR.git && cd FunASR
-cd funasr/runtime/python/grpc
-pip install -r requirements_client.txt
-```
-
-### Generate protobuf file
-Run on server, the two generated pb files are both used for server and client
-
-```shell
-# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
-# regenerate it only when you make changes to ./proto/paraformer.proto file.
-python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
-```
-
-### Start grpc client
-```
-# Start client.
-python grpc_main_client_mic.py --host 127.0.0.1 --port 10095
-```
-
-[//]: # (```)
-
-[//]: # (# go to ../python/grpc to find this package)
-
-[//]: # (import paraformer_pb2)
-
-[//]: # ()
-[//]: # ()
-[//]: # (class RecognizeStub:)
-
-[//]: # ( def __init__(self, channel):)
-
-[//]: # ( self.Recognize = channel.stream_stream()
-
-[//]: # ( '/paraformer.ASR/Recognize',)
-
-[//]: # ( request_serializer=paraformer_pb2.Request.SerializeToString,)
-
-[//]: # ( response_deserializer=paraformer_pb2.Response.FromString,)
-
-[//]: # ( ))
-
-[//]: # ()
-[//]: # ()
-[//]: # (async def send(channel, data, speaking, isEnd):)
-
-[//]: # ( stub = RecognizeStub(channel))
-
-[//]: # ( req = paraformer_pb2.Request())
-
-[//]: # ( if data:)
-
-[//]: # ( req.audio_data = data)
-
-[//]: # ( req.user = 'zz')
-
-[//]: # ( req.language = 'zh-CN')
-
-[//]: # ( req.speaking = speaking)
-
-[//]: # ( req.isEnd = isEnd)
-
-[//]: # ( q = queue.SimpleQueue())
-
-[//]: # ( q.put(req))
-
-[//]: # ( return stub.Recognize(iter(q.get, None)))
-
-[//]: # ()
-[//]: # (# send the audio data once)
-
-[//]: # (async def grpc_rec(data, grpc_uri):)
-
-[//]: # ( with grpc.insecure_channel(grpc_uri) as channel:)
-
-[//]: # ( b = time.time())
-
-[//]: # ( response = await send(channel, data, False, False))
-
-[//]: # ( resp = response.next())
-
-[//]: # ( text = '')
-
-[//]: # ( if 'decoding' == resp.action:)
-
-[//]: # ( resp = response.next())
-
-[//]: # ( if 'finish' == resp.action:)
-
-[//]: # ( text = json.loads(resp.sentence)['text'])
-
-[//]: # ( response = await send(channel, None, False, True))
-
-[//]: # ( return {)
-
-[//]: # ( 'text': text,)
-
-[//]: # ( 'time': time.time() - b,)
-
-[//]: # ( })
-
-[//]: # ()
-[//]: # (async def test():)
-
-[//]: # ( # fc = FunAsrGrpcClient('127.0.0.1', 9900))
-
-[//]: # ( # t = await fc.rec(wav.tobytes()))
-
-[//]: # ( # print(t))
-
-[//]: # ( wav, _ = sf.read('z-10s.wav', dtype='int16'))
-
-[//]: # ( uri = '127.0.0.1:9900')
-
-[//]: # ( res = await grpc_rec(wav.tobytes(), uri))
-
-[//]: # ( print(res))
-
-[//]: # ()
-[//]: # ()
-[//]: # (if __name__ == '__main__':)
-
-[//]: # ( asyncio.run(test()))
-
-[//]: # ()
-[//]: # (```)
+Install the requirements as in [grpc-python](../python/grpc/Readme.md)
## Acknowledge
1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
-2. We acknowledge [DeepScience](https://www.deepscience.cn) for contributing the grpc service.
+2. We acknowledge burkliu (刘柏基, liubaiji@xverse.cn) for contributing the grpc service.
diff --git a/funasr/runtime/grpc/build.sh b/funasr/runtime/grpc/build.sh
new file mode 100755
index 0000000..0311ca6
--- /dev/null
+++ b/funasr/runtime/grpc/build.sh
@@ -0,0 +1,15 @@
+#!/bin/bash
+
+mode=debug #[debug|release]
+onnxruntime_dir=`pwd`/../onnxruntime/onnxruntime-linux-x64-1.14.0
+ffmpeg_dir=`pwd`/../onnxruntime/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared
+
+
+rm build -rf
+mkdir -p build
+cd build
+
+cmake -DCMAKE_BUILD_TYPE=$mode ../ -DONNXRUNTIME_DIR=$onnxruntime_dir -DFFMPEG_DIR=$ffmpeg_dir
+cmake --build . -j 4
+
+echo "Build server successfully!"
diff --git a/funasr/runtime/grpc/paraformer-server.cc b/funasr/runtime/grpc/paraformer-server.cc
index 734dadc..0fb047f 100644
--- a/funasr/runtime/grpc/paraformer-server.cc
+++ b/funasr/runtime/grpc/paraformer-server.cc
@@ -1,235 +1,261 @@
-#include <algorithm>
-#include <chrono>
-#include <cmath>
-#include <iostream>
-#include <sstream>
-#include <memory>
-#include <string>
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+/* 2023 by burkliu (刘柏基) liubaiji@xverse.cn */
-#include <grpc/grpc.h>
-#include <grpcpp/server.h>
-#include <grpcpp/server_builder.h>
-#include <grpcpp/server_context.h>
-#include <grpcpp/security/server_credentials.h>
-
-#include "paraformer.grpc.pb.h"
#include "paraformer-server.h"
-#include "tclap/CmdLine.h"
-#include "com-define.h"
-#include "glog/logging.h"
-using grpc::Server;
-using grpc::ServerBuilder;
-using grpc::ServerContext;
-using grpc::ServerReader;
-using grpc::ServerReaderWriter;
-using grpc::ServerWriter;
-using grpc::Status;
+GrpcEngine::GrpcEngine(
+ grpc::ServerReaderWriter<Response, Request>* stream,
+ std::shared_ptr<FUNASR_HANDLE> asr_handler)
+ : stream_(std::move(stream)),
+ asr_handler_(std::move(asr_handler)) {
-using paraformer::Request;
-using paraformer::Response;
-using paraformer::ASR;
-
-ASRServicer::ASRServicer(std::map<std::string, std::string>& model_path) {
- AsrHanlde=FunOfflineInit(model_path, 1);
- std::cout << "ASRServicer init" << std::endl;
- init_flag = 0;
+ request_ = std::make_shared<Request>();
}
-void ASRServicer::clear_states(const std::string& user) {
- clear_buffers(user);
- clear_transcriptions(user);
-}
+void GrpcEngine::DecodeThreadFunc() {
+ FUNASR_HANDLE tpass_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size_);
+ int step = (sampling_rate_ * step_duration_ms_ / 1000) * 2; // int16 = 2bytes;
+ std::vector<std::vector<std::string>> punc_cache(2);
-void ASRServicer::clear_buffers(const std::string& user) {
- if (client_buffers.count(user)) {
- client_buffers.erase(user);
- }
-}
+ bool is_final = false;
+ std::string online_result = "";
+ std::string tpass_result = "";
-void ASRServicer::clear_transcriptions(const std::string& user) {
- if (client_transcription.count(user)) {
- client_transcription.erase(user);
- }
-}
+ LOG(INFO) << "Decoder init, start decoding loop with mode";
-void ASRServicer::disconnect(const std::string& user) {
- clear_states(user);
- std::cout << "Disconnecting user: " << user << std::endl;
-}
+ while (true) {
+ if (audio_buffer_.length() > step || is_end_) {
+ if (audio_buffer_.length() <= step && is_end_) {
+ is_final = true;
+ step = audio_buffer_.length();
+ }
-grpc::Status ASRServicer::Recognize(
- grpc::ServerContext* context,
- grpc::ServerReaderWriter<Response, Request>* stream) {
+ FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
+ tpass_online_handler,
+ audio_buffer_.c_str(),
+ step,
+ punc_cache,
+ is_final,
+ sampling_rate_,
+ encoding_,
+ mode_);
+ audio_buffer_ = audio_buffer_.substr(step);
- Request req;
- while (stream->Read(&req)) {
- if (req.isend()) {
- std::cout << "asr end" << std::endl;
- disconnect(req.user());
- Response res;
- res.set_sentence(
- R"({"success": true, "detail": "asr end"})"
- );
- res.set_user(req.user());
- res.set_action("terminate");
- res.set_language(req.language());
- stream->Write(res);
- } else if (req.speaking()) {
- if (req.audio_data().size() > 0) {
- auto& buf = client_buffers[req.user()];
- buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
- }
- Response res;
- res.set_sentence(
- R"({"success": true, "detail": "speaking"})"
- );
- res.set_user(req.user());
- res.set_action("speaking");
- res.set_language(req.language());
- stream->Write(res);
- } else if (!req.speaking()) {
- if (client_buffers.count(req.user()) == 0 && req.audio_data().size() == 0) {
- Response res;
- res.set_sentence(
- R"({"success": true, "detail": "waiting_for_voice"})"
- );
- res.set_user(req.user());
- res.set_action("waiting");
- res.set_language(req.language());
- stream->Write(res);
- }else {
- auto begin_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
- if (req.audio_data().size() > 0) {
- auto& buf = client_buffers[req.user()];
- buf.insert(buf.end(), req.audio_data().begin(), req.audio_data().end());
- }
- std::string tmp_data = this->client_buffers[req.user()];
- this->clear_states(req.user());
-
- Response res;
- res.set_sentence(
- R"({"success": true, "detail": "decoding data: " + std::to_string(tmp_data.length()) + " bytes"})"
- );
- int data_len_int = tmp_data.length();
- std::string data_len = std::to_string(data_len_int);
- std::stringstream ss;
- ss << R"({"success": true, "detail": "decoding data: )" << data_len << R"( bytes")" << R"("})";
- std::string result = ss.str();
- res.set_sentence(result);
- res.set_user(req.user());
- res.set_action("decoding");
- res.set_language(req.language());
- stream->Write(res);
- if (tmp_data.length() < 800) { //min input_len for asr model
- auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
- std::string delay_str = std::to_string(end_time - begin_time);
- std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", error: data_is_not_long_enough" << std::endl;
- Response res;
- std::stringstream ss;
- std::string asr_result = "";
- ss << R"({"success": true, "detail": "finish_sentence","server_delay_ms":)" << delay_str << R"(,"text":")" << asr_result << R"("})";
- std::string result = ss.str();
- res.set_sentence(result);
- res.set_user(req.user());
- res.set_action("finish");
- res.set_language(req.language());
- stream->Write(res);
- }
- else {
- FUNASR_RESULT Result= FunOfflineInferBuffer(AsrHanlde, tmp_data.c_str(), data_len_int, RASR_NONE, NULL, 16000);
- std::string asr_result = ((FUNASR_RECOG_RESULT*)Result)->msg;
-
- auto end_time = std::chrono::duration_cast<std::chrono::milliseconds>(std::chrono::system_clock::now().time_since_epoch()).count();
- std::string delay_str = std::to_string(end_time - begin_time);
-
- std::cout << "user: " << req.user() << " , delay(ms): " << delay_str << ", text: " << asr_result << std::endl;
- Response res;
- std::stringstream ss;
- ss << R"({"success": true, "detail": "finish_sentence","server_delay_ms":)" << delay_str << R"(,"text":")" << asr_result << R"("})";
- std::string result = ss.str();
- res.set_sentence(result);
- res.set_user(req.user());
- res.set_action("finish");
- res.set_language(req.language());
-
- stream->Write(res);
- }
- }
- }else {
- Response res;
- res.set_sentence(
- R"({"success": false, "detail": "error, no condition matched! Unknown reason."})"
- );
- res.set_user(req.user());
- res.set_action("terminate");
- res.set_language(req.language());
- stream->Write(res);
+ if (result) {
+ std::string online_message = FunASRGetResult(result, 0);
+ online_result += online_message;
+ if(online_message != ""){
+ Response response;
+ response.set_mode(DecodeMode::online);
+ response.set_text(online_message);
+ response.set_is_final(is_final);
+ stream_->Write(response);
+ LOG(INFO) << "send online results: " << online_message;
}
+ std::string tpass_message = FunASRGetTpassResult(result, 0);
+ tpass_result += tpass_message;
+ if(tpass_message != ""){
+ Response response;
+ response.set_mode(DecodeMode::two_pass);
+ response.set_text(tpass_message);
+ response.set_is_final(is_final);
+ stream_->Write(response);
+ LOG(INFO) << "send offline results: " << tpass_message;
+ }
+ FunASRFreeResult(result);
+ }
+
+ if (is_final) {
+ FunTpassOnlineUninit(tpass_online_handler);
+ break;
+ }
}
- return Status::OK;
+    usleep(1000);  // POSIX sleep() takes whole seconds; sleep(0.001) truncates to 0 and busy-waits
+ }
}
-void RunServer(std::map<std::string, std::string>& model_path) {
- std::string port;
- try{
- port = model_path.at(PORT_ID);
- }catch(std::exception const &e){
- printf("Error when read port.\n");
- exit(0);
+void GrpcEngine::OnSpeechStart() {
+ if (request_->chunk_size_size() == 3) {
+ for (int i = 0; i < 3; i++) {
+ chunk_size_[i] = int(request_->chunk_size(i));
}
- std::string server_address;
- server_address = "0.0.0.0:" + port;
- ASRServicer service(model_path);
+ }
+ std::string chunk_size_str;
+ for (int i = 0; i < 3; i++) {
+    chunk_size_str += " " + std::to_string(chunk_size_[i]);
+ }
+ LOG(INFO) << "chunk_size is" << chunk_size_str;
- ServerBuilder builder;
- builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
- builder.RegisterService(&service);
- std::unique_ptr<Server> server(builder.BuildAndStart());
- std::cout << "Server listening on " << server_address << std::endl;
- server->Wait();
+ if (request_->sampling_rate() != 0) {
+ sampling_rate_ = request_->sampling_rate();
+ }
+ LOG(INFO) << "sampling_rate is " << sampling_rate_;
+
+ switch(request_->wav_format()) {
+ case WavFormat::pcm: encoding_ = "pcm";
+ }
+ LOG(INFO) << "encoding is " << encoding_;
+
+ std::string mode_str;
+ switch(request_->mode()) {
+ case DecodeMode::offline:
+ mode_ = ASR_OFFLINE;
+ mode_str = "offline";
+ break;
+ case DecodeMode::online:
+ mode_ = ASR_ONLINE;
+ mode_str = "online";
+ break;
+ case DecodeMode::two_pass:
+ mode_ = ASR_TWO_PASS;
+ mode_str = "two_pass";
+ break;
+ }
+ LOG(INFO) << "decode mode is " << mode_str;
+
+ decode_thread_ = std::make_shared<std::thread>(&GrpcEngine::DecodeThreadFunc, this);
+ is_start_ = true;
}
-void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& model_path)
-{
- if (value_arg.isSet()){
- model_path.insert({key, value_arg.getValue()});
- LOG(INFO)<< key << " : " << value_arg.getValue();
+void GrpcEngine::OnSpeechData() {
+ audio_buffer_ += request_->audio_data();
+}
+
+void GrpcEngine::OnSpeechEnd() {
+ is_end_ = true;
+ LOG(INFO) << "Read all pcm data, wait for decoding thread";
+ if (decode_thread_ != nullptr) {
+ decode_thread_->join();
+ }
+}
+
+void GrpcEngine::operator()() {
+ try {
+ LOG(INFO) << "start engine main loop";
+ while (stream_->Read(request_.get())) {
+ LOG(INFO) << "receive data";
+ if (!is_start_) {
+ OnSpeechStart();
+ }
+ OnSpeechData();
+ if (request_->is_final()) {
+ break;
+ }
}
+ OnSpeechEnd();
+ LOG(INFO) << "Connect finish";
+ } catch (std::exception const& e) {
+ LOG(ERROR) << e.what();
+ }
+}
+
+GrpcService::GrpcService(std::map<std::string, std::string>& config, int onnx_thread)
+ : config_(config) {
+
+ asr_handler_ = std::make_shared<FUNASR_HANDLE>(std::move(FunTpassInit(config_, onnx_thread)));
+ LOG(INFO) << "GrpcService model loaded";
+
+ std::vector<int> chunk_size = {5, 10, 5};
+ FUNASR_HANDLE tmp_online_handler = FunTpassOnlineInit(*asr_handler_, chunk_size);
+ int sampling_rate = 16000;
+ int buffer_len = sampling_rate * 1;
+ std::string tmp_data(buffer_len, '0');
+ std::vector<std::vector<std::string>> punc_cache(2);
+ bool is_final = true;
+ std::string encoding = "pcm";
+ FUNASR_RESULT result = FunTpassInferBuffer(*asr_handler_,
+ tmp_online_handler,
+ tmp_data.c_str(),
+ buffer_len,
+ punc_cache,
+ is_final,
+                                 sampling_rate,
+ encoding,
+ ASR_TWO_PASS);
+ if (result) {
+ FunASRFreeResult(result);
+ }
+ FunTpassOnlineUninit(tmp_online_handler);
+ LOG(INFO) << "GrpcService model warmup";
+}
+
+grpc::Status GrpcService::Recognize(
+ grpc::ServerContext* context,
+ grpc::ServerReaderWriter<Response, Request>* stream) {
+ LOG(INFO) << "Get Recognize request";
+ GrpcEngine engine(
+ stream,
+ asr_handler_
+ );
+
+ std::thread t(std::move(engine));
+ t.join();
+ return grpc::Status::OK;
+}
+
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, std::string key, std::map<std::string, std::string>& config) {
+ if (value_arg.isSet()) {
+ config.insert({key, value_arg.getValue()});
+ LOG(INFO) << key << " : " << value_arg.getValue();
+ }
}
int main(int argc, char* argv[]) {
+ FLAGS_logtostderr = true;
+ google::InitGoogleLogging(argv[0]);
- google::InitGoogleLogging(argv[0]);
- FLAGS_logtostderr = true;
+ TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
+ TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
+ TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
+ TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+ TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
+ TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
+ TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
+ TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
+ TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
+ TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
- TCLAP::CmdLine cmd("paraformer-server", ' ', "1.0");
- TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
- TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
- TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
- TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
- TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
- TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
- TCLAP::ValueArg<std::string> port_id("", PORT_ID, "port id", true, "", "string");
+ cmd.add(offline_model_dir);
+ cmd.add(online_model_dir);
+ cmd.add(quantize);
+ cmd.add(vad_dir);
+ cmd.add(vad_quant);
+ cmd.add(punc_dir);
+ cmd.add(punc_quant);
+ cmd.add(onnx_thread);
+ cmd.add(port_id);
+ cmd.parse(argc, argv);
- cmd.add(model_dir);
- cmd.add(quantize);
- cmd.add(vad_dir);
- cmd.add(vad_quant);
- cmd.add(punc_dir);
- cmd.add(punc_quant);
- cmd.add(port_id);
- cmd.parse(argc, argv);
+ std::map<std::string, std::string> config;
+ GetValue(offline_model_dir, OFFLINE_MODEL_DIR, config);
+ GetValue(online_model_dir, ONLINE_MODEL_DIR, config);
+ GetValue(quantize, QUANTIZE, config);
+ GetValue(vad_dir, VAD_DIR, config);
+ GetValue(vad_quant, VAD_QUANT, config);
+ GetValue(punc_dir, PUNC_DIR, config);
+ GetValue(punc_quant, PUNC_QUANT, config);
+ GetValue(port_id, PORT_ID, config);
- std::map<std::string, std::string> model_path;
- GetValue(model_dir, MODEL_DIR, model_path);
- GetValue(quantize, QUANTIZE, model_path);
- GetValue(vad_dir, VAD_DIR, model_path);
- GetValue(vad_quant, VAD_QUANT, model_path);
- GetValue(punc_dir, PUNC_DIR, model_path);
- GetValue(punc_quant, PUNC_QUANT, model_path);
- GetValue(port_id, PORT_ID, model_path);
+ std::string port;
+ try {
+ port = config.at(PORT_ID);
+ } catch(std::exception const &e) {
+    LOG(ERROR) << "Error when reading port.";
+ exit(0);
+ }
+ std::string server_address;
+ server_address = "0.0.0.0:" + port;
+  GrpcService service(config, onnx_thread.getValue());
- RunServer(model_path);
- return 0;
+ grpc::ServerBuilder builder;
+ builder.AddListeningPort(server_address, grpc::InsecureServerCredentials());
+ builder.RegisterService(&service);
+ std::unique_ptr<grpc::Server> server(builder.BuildAndStart());
+ LOG(INFO) << "Server listening on " << server_address;
+ server->Wait();
+
+ return 0;
}
diff --git a/funasr/runtime/grpc/paraformer-server.h b/funasr/runtime/grpc/paraformer-server.h
index 760ea2a..8753e5c 100644
--- a/funasr/runtime/grpc/paraformer-server.h
+++ b/funasr/runtime/grpc/paraformer-server.h
@@ -1,55 +1,65 @@
-#include <algorithm>
-#include <chrono>
-#include <cmath>
-#include <iostream>
-#include <memory>
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+/* 2023 by burkliu (刘柏基) liubaiji@xverse.cn */
+
#include <string>
+#include <thread>
+#include <unistd.h>
-#include <grpc/grpc.h>
-#include <grpcpp/server.h>
-#include <grpcpp/server_builder.h>
-#include <grpcpp/server_context.h>
-#include <grpcpp/security/server_credentials.h>
-
-#include <unordered_map>
-#include <chrono>
-
+#include "grpcpp/server_builder.h"
#include "paraformer.grpc.pb.h"
#include "funasrruntime.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+#include "glog/logging.h"
-
-using grpc::Server;
-using grpc::ServerBuilder;
-using grpc::ServerContext;
-using grpc::ServerReader;
-using grpc::ServerReaderWriter;
-using grpc::ServerWriter;
-using grpc::Status;
-
-
+using paraformer::WavFormat;
+using paraformer::DecodeMode;
using paraformer::Request;
using paraformer::Response;
using paraformer::ASR;
typedef struct
{
- std::string msg;
- float snippet_time;
-}FUNASR_RECOG_RESULT;
+ std::string msg;
+ float snippet_time;
+} FUNASR_RECOG_RESULT;
-class ASRServicer final : public ASR::Service {
- private:
- int init_flag;
- std::unordered_map<std::string, std::string> client_buffers;
- std::unordered_map<std::string, std::string> client_transcription;
+class GrpcEngine {
+ public:
+ GrpcEngine(grpc::ServerReaderWriter<Response, Request>* stream, std::shared_ptr<FUNASR_HANDLE> asr_handler);
+ void operator()();
+ private:
+ void DecodeThreadFunc();
+ void OnSpeechStart();
+ void OnSpeechData();
+ void OnSpeechEnd();
+
+ grpc::ServerReaderWriter<Response, Request>* stream_;
+ std::shared_ptr<Request> request_;
+ std::shared_ptr<Response> response_;
+ std::shared_ptr<FUNASR_HANDLE> asr_handler_;
+  std::string audio_buffer_;  // NOTE(review): appended by the gRPC read loop and read/consumed by the decode thread with no synchronization — likely data race; confirm and guard with a mutex
+ std::shared_ptr<std::thread> decode_thread_ = nullptr;
+ bool is_start_ = false;
+ bool is_end_ = false;
+
+ std::vector<int> chunk_size_ = {5, 10, 5};
+ int sampling_rate_ = 16000;
+ std::string encoding_;
+ ASR_TYPE mode_ = ASR_TWO_PASS;
+ int step_duration_ms_ = 100;
+};
+
+class GrpcService final : public ASR::Service {
public:
- ASRServicer(std::map<std::string, std::string>& model_path);
- void clear_states(const std::string& user);
- void clear_buffers(const std::string& user);
- void clear_transcriptions(const std::string& user);
- void disconnect(const std::string& user);
+ GrpcService(std::map<std::string, std::string>& config, int num_thread);
grpc::Status Recognize(grpc::ServerContext* context, grpc::ServerReaderWriter<Response, Request>* stream);
- FUNASR_HANDLE AsrHanlde;
-
+
+ private:
+ std::map<std::string, std::string> config_;
+ std::shared_ptr<FUNASR_HANDLE> asr_handler_;
};
diff --git a/funasr/runtime/grpc/rebuild.sh b/funasr/runtime/grpc/rebuild.sh
deleted file mode 100644
index 9b41ed6..0000000
--- a/funasr/runtime/grpc/rebuild.sh
+++ /dev/null
@@ -1,12 +0,0 @@
-#!/bin/bash
-
-rm cmake -rf
-mkdir -p cmake/build
-
-cd cmake/build
-
-cmake -DCMAKE_BUILD_TYPE=release ../.. -DONNXRUNTIME_DIR=/data/asrmodel/onnxruntime-linux-x64-1.14.0
-make
-
-
-echo "Build cmake/build/paraformer_server successfully!"
diff --git a/funasr/runtime/grpc/run_server.sh b/funasr/runtime/grpc/run_server.sh
new file mode 100755
index 0000000..7636a10
--- /dev/null
+++ b/funasr/runtime/grpc/run_server.sh
@@ -0,0 +1,12 @@
+#!/bin/bash
+
+./build/bin/paraformer-server \
+ --port-id 10100 \
+ --offline-model-dir funasr_models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
+ --online-model-dir funasr_models/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online \
+ --quantize true \
+ --vad-dir funasr_models/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
+ --vad-quant true \
+ --punc-dir funasr_models/damo/punc_ct-transformer_zh-cn-common-vad_realtime-vocab272727 \
+ --punc-quant true \
+ 2>&1
diff --git a/funasr/runtime/onnxruntime/bin/CMakeLists.txt b/funasr/runtime/onnxruntime/bin/CMakeLists.txt
index 03c3a64..4870922 100644
--- a/funasr/runtime/onnxruntime/bin/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/bin/CMakeLists.txt
@@ -9,6 +9,9 @@
add_executable(funasr-onnx-online-vad "funasr-onnx-online-vad.cpp")
target_link_libraries(funasr-onnx-online-vad PUBLIC funasr)
+add_executable(funasr-onnx-online-asr "funasr-onnx-online-asr.cpp")
+target_link_libraries(funasr-onnx-online-asr PUBLIC funasr)
+
add_executable(funasr-onnx-offline-punc "funasr-onnx-offline-punc.cpp")
target_link_libraries(funasr-onnx-offline-punc PUBLIC funasr)
@@ -17,3 +20,16 @@
add_executable(funasr-onnx-offline-rtf "funasr-onnx-offline-rtf.cpp")
target_link_libraries(funasr-onnx-offline-rtf PUBLIC funasr)
+
+add_executable(funasr-onnx-2pass "funasr-onnx-2pass.cpp")
+target_link_libraries(funasr-onnx-2pass PUBLIC funasr)
+
+add_executable(funasr-onnx-2pass-rtf "funasr-onnx-2pass-rtf.cpp")
+target_link_libraries(funasr-onnx-2pass-rtf PUBLIC funasr)
+
+add_executable(funasr-onnx-online-rtf "funasr-onnx-online-rtf.cpp")
+target_link_libraries(funasr-onnx-online-rtf PUBLIC funasr)
+
+# include_directories(${FFMPEG_DIR}/include)
+# add_executable(ff "ffmpeg.cpp")
+# target_link_libraries(ff PUBLIC avutil avcodec avformat swresample)
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp
new file mode 100644
index 0000000..c465a2b
--- /dev/null
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-2pass-rtf.cpp
@@ -0,0 +1,310 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#ifndef _WIN32
+#include <sys/time.h>
+#else
+#include <win_func.h>
+#endif
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <map>
+#include <atomic>
+#include <mutex>
+#include <thread>
+#include <glog/logging.h>
+#include "funasrruntime.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+#include "audio.h"
+
+using namespace std;
+
+std::atomic<int> wav_index(0);
+std::mutex mtx;
+
+bool is_target_file(const std::string& filename, const std::string target) {
+ std::size_t pos = filename.find_last_of(".");
+ if (pos == std::string::npos) {
+ return false;
+ }
+ std::string extension = filename.substr(pos + 1);
+ return (extension == target);
+}
+
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
+{
+ model_path.insert({key, value_arg.getValue()});
+ LOG(INFO)<< key << " : " << value_arg.getValue();
+}
+
+
+void runReg(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size, vector<string> wav_list, vector<string> wav_ids,
+ float* total_length, long* total_time, int core_id, ASR_TYPE asr_mode_) {
+
+ struct timeval start, end;
+ long seconds = 0;
+ float n_total_length = 0.0f;
+ long n_total_time = 0;
+
+ // init online features
+ FUNASR_HANDLE tpass_online_handle=FunTpassOnlineInit(tpass_handle, chunk_size);
+
+ // warm up
+ for (size_t i = 0; i < 2; i++)
+ {
+ int32_t sampling_rate_ = 16000;
+ funasr::Audio audio(1);
+ if(is_target_file(wav_list[0].c_str(), "wav")){
+ if(!audio.LoadWav2Char(wav_list[0].c_str(), &sampling_rate_)){
+ LOG(ERROR)<<"Failed to load "<< wav_list[0];
+ exit(-1);
+ }
+ }else if(is_target_file(wav_list[0].c_str(), "pcm")){
+ if (!audio.LoadPcmwav2Char(wav_list[0].c_str(), &sampling_rate_)){
+ LOG(ERROR)<<"Failed to load "<< wav_list[0];
+ exit(-1);
+ }
+ }else{
+ if (!audio.FfmpegLoad(wav_list[0].c_str(), true)){
+ LOG(ERROR)<<"Failed to load "<< wav_list[0];
+ exit(-1);
+ }
+ }
+ char* speech_buff = audio.GetSpeechChar();
+ int buff_len = audio.GetSpeechLen()*2;
+
+ int step = 1600*2;
+ bool is_final = false;
+
+ std::vector<std::vector<string>> punc_cache(2);
+ for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
+ if (sample_offset + step >= buff_len - 1) {
+ step = buff_len - sample_offset;
+ is_final = true;
+ } else {
+ is_final = false;
+ }
+ FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_);
+ if (result)
+ {
+ FunASRFreeResult(result);
+ }
+ }
+ }
+
+ while (true) {
+        // atomically fetch the next wav index and advance it
+ int i = wav_index.fetch_add(1);
+ if (i >= wav_list.size()) {
+ break;
+ }
+ int32_t sampling_rate_ = 16000;
+ funasr::Audio audio(1);
+ if(is_target_file(wav_list[i].c_str(), "wav")){
+ if(!audio.LoadWav2Char(wav_list[i].c_str(), &sampling_rate_)){
+ LOG(ERROR)<<"Failed to load "<< wav_list[i];
+ exit(-1);
+ }
+ }else if(is_target_file(wav_list[i].c_str(), "pcm")){
+ if (!audio.LoadPcmwav2Char(wav_list[i].c_str(), &sampling_rate_)){
+ LOG(ERROR)<<"Failed to load "<< wav_list[i];
+ exit(-1);
+ }
+ }else{
+ if (!audio.FfmpegLoad(wav_list[i].c_str(), true)){
+ LOG(ERROR)<<"Failed to load "<< wav_list[i];
+ exit(-1);
+ }
+ }
+ char* speech_buff = audio.GetSpeechChar();
+ int buff_len = audio.GetSpeechLen()*2;
+
+ int step = 1600*2;
+ bool is_final = false;
+
+ string online_res="";
+ string tpass_res="";
+ std::vector<std::vector<string>> punc_cache(2);
+ for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
+ if (sample_offset + step >= buff_len - 1) {
+ step = buff_len - sample_offset;
+ is_final = true;
+ } else {
+ is_final = false;
+ }
+ gettimeofday(&start, NULL);
+ FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_);
+ gettimeofday(&end, NULL);
+ seconds = (end.tv_sec - start.tv_sec);
+ long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+ n_total_time += taking_micros;
+
+ if (result)
+ {
+ string online_msg = FunASRGetResult(result, 0);
+ online_res += online_msg;
+ if(online_msg != ""){
+ LOG(INFO)<< wav_ids[i] <<" : "<<online_msg;
+ }
+ string tpass_msg = FunASRGetTpassResult(result, 0);
+ tpass_res += tpass_msg;
+ if(tpass_msg != ""){
+ LOG(INFO)<< wav_ids[i] <<" offline results : "<<tpass_msg;
+ }
+ float snippet_time = FunASRGetRetSnippetTime(result);
+ n_total_length += snippet_time;
+ FunASRFreeResult(result);
+ }
+ else
+ {
+ LOG(ERROR) << ("No return data!\n");
+ }
+ }
+ if(asr_mode_ == 2){
+ LOG(INFO) <<"Thread: " << this_thread::get_id() <<" " << wav_ids[i] << " Final online results "<<" : "<<online_res;
+ }
+ if(asr_mode_==1){
+ LOG(INFO) <<"Thread: " << this_thread::get_id() <<" " << wav_ids[i] << " Final online results "<<" : "<<tpass_res;
+ }
+ if(asr_mode_ == 0 || asr_mode_==2){
+ LOG(INFO) <<"Thread: " << this_thread::get_id() <<" " << wav_ids[i] << " Final offline results " <<" : "<<tpass_res;
+ }
+
+ }
+ {
+ lock_guard<mutex> guard(mtx);
+ *total_length += n_total_length;
+ if(*total_time < n_total_time){
+ *total_time = n_total_time;
+ }
+ }
+ FunTpassOnlineUninit(tpass_online_handle);
+}
+
+
+int main(int argc, char** argv)
+{
+ google::InitGoogleLogging(argv[0]);
+ FLAGS_logtostderr = true;
+
+ TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
+ TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
+ TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
+ TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+ TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
+ TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
+ TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
+ TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
+ TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
+ TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
+
+ TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
+
+ cmd.add(offline_model_dir);
+ cmd.add(online_model_dir);
+ cmd.add(quantize);
+ cmd.add(vad_dir);
+ cmd.add(vad_quant);
+ cmd.add(punc_dir);
+ cmd.add(punc_quant);
+ cmd.add(wav_path);
+ cmd.add(asr_mode);
+ cmd.add(onnx_thread);
+ cmd.parse(argc, argv);
+
+ std::map<std::string, std::string> model_path;
+ GetValue(offline_model_dir, OFFLINE_MODEL_DIR, model_path);
+ GetValue(online_model_dir, ONLINE_MODEL_DIR, model_path);
+ GetValue(quantize, QUANTIZE, model_path);
+ GetValue(vad_dir, VAD_DIR, model_path);
+ GetValue(vad_quant, VAD_QUANT, model_path);
+ GetValue(punc_dir, PUNC_DIR, model_path);
+ GetValue(punc_quant, PUNC_QUANT, model_path);
+ GetValue(wav_path, WAV_PATH, model_path);
+ GetValue(asr_mode, ASR_MODE, model_path);
+
+ struct timeval start, end;
+ gettimeofday(&start, NULL);
+ int thread_num = onnx_thread.getValue();
+ int asr_mode_ = -1;
+ if(model_path[ASR_MODE] == "offline"){
+ asr_mode_ = 0;
+ }else if(model_path[ASR_MODE] == "online"){
+ asr_mode_ = 1;
+ }else if(model_path[ASR_MODE] == "2pass"){
+ asr_mode_ = 2;
+ }else{
+ LOG(ERROR) << "Wrong asr-mode : " << model_path[ASR_MODE];
+ exit(-1);
+ }
+ FUNASR_HANDLE tpass_hanlde=FunTpassInit(model_path, thread_num);
+
+ if (!tpass_hanlde)
+ {
+ LOG(ERROR) << "FunTpassInit init failed";
+ exit(-1);
+ }
+
+ gettimeofday(&end, NULL);
+ long seconds = (end.tv_sec - start.tv_sec);
+ long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+ LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
+
+ // read wav_path
+ vector<string> wav_list;
+ vector<string> wav_ids;
+ string default_id = "wav_default_id";
+ string wav_path_ = model_path.at(WAV_PATH);
+
+ if(is_target_file(wav_path_, "scp")){
+ ifstream in(wav_path_);
+ if (!in.is_open()) {
+ LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
+ return 0;
+ }
+ string line;
+ while(getline(in, line))
+ {
+ istringstream iss(line);
+ string column1, column2;
+ iss >> column1 >> column2;
+ wav_list.emplace_back(column2);
+ wav_ids.emplace_back(column1);
+ }
+ in.close();
+ }else{
+ wav_list.emplace_back(wav_path_);
+ wav_ids.emplace_back(default_id);
+ }
+
+ std::vector<int> chunk_size = {5,10,5};
+    // multi-threaded RTF test
+ float total_length = 0.0f;
+ long total_time = 0;
+ std::vector<std::thread> threads;
+
+ int rtf_threds = 5;
+ for (int i = 0; i < rtf_threds; i++)
+ {
+ threads.emplace_back(thread(runReg, tpass_hanlde, chunk_size, wav_list, wav_ids, &total_length, &total_time, i, (ASR_TYPE)asr_mode_));
+ }
+
+ for (auto& thread : threads)
+ {
+ thread.join();
+ }
+
+ LOG(INFO) << "total_time_wav " << (long)(total_length * 1000) << " ms";
+ LOG(INFO) << "total_time_comput " << total_time / 1000 << " ms";
+ LOG(INFO) << "total_rtf " << (double)total_time/ (total_length*1000000);
+ LOG(INFO) << "speedup " << 1.0/((double)total_time/ (total_length*1000000));
+
+ FunTpassUninit(tpass_hanlde);
+ return 0;
+}
+
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp
new file mode 100644
index 0000000..2faf56b
--- /dev/null
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-2pass.cpp
@@ -0,0 +1,217 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#ifndef _WIN32
+#include <sys/time.h>
+#else
+#include <win_func.h>
+#endif
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <map>
+#include <glog/logging.h>
+#include "funasrruntime.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+#include "audio.h"
+
+using namespace std;
+
+bool is_target_file(const std::string& filename, const std::string target) {
+    // True when the text after the last '.' matches the expected extension.
+    const std::size_t dot_pos = filename.rfind('.');
+    if (dot_pos == std::string::npos) {
+        return false;  // no extension present at all
+    }
+    return filename.compare(dot_pos + 1, std::string::npos, target) == 0;
+}
+
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path) // copy one parsed CLI option into the model-path map and log it
+{
+    model_path.insert({key, value_arg.getValue()});  // NOTE: inserted unconditionally (defaults too) — later map.at()/operator[] lookups (e.g. ASR_MODE default "2pass") rely on the key existing
+    LOG(INFO)<< key << " : " << value_arg.getValue();
+}
+
+int main(int argc, char** argv)  // 2-pass ASR demo: streams a wav/pcm through the online/offline/2pass pipelines
+{
+    google::InitGoogleLogging(argv[0]);
+    FLAGS_logtostderr = true;
+
+    TCLAP::CmdLine cmd("funasr-onnx-2pass", ' ', "1.0");
+    TCLAP::ValueArg<std::string> offline_model_dir("", OFFLINE_MODEL_DIR, "the asr offline model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
+    TCLAP::ValueArg<std::string> online_model_dir("", ONLINE_MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
+    TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+    TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad online model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
+    TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "true", "string");
+    TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc online model path, which contains model.onnx, punc.yaml", false, "", "string");
+    TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "true", "string");
+    TCLAP::ValueArg<std::string> asr_mode("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
+    TCLAP::ValueArg<std::int32_t> onnx_thread("", "onnx-inter-thread", "onnxruntime SetIntraOpNumThreads", false, 1, "int32_t");
+
+    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
+
+    cmd.add(offline_model_dir);
+    cmd.add(online_model_dir);
+    cmd.add(quantize);
+    cmd.add(vad_dir);
+    cmd.add(vad_quant);
+    cmd.add(punc_dir);
+    cmd.add(punc_quant);
+    cmd.add(wav_path);
+    cmd.add(asr_mode);
+    cmd.add(onnx_thread);
+    cmd.parse(argc, argv);
+
+    std::map<std::string, std::string> model_path;
+    GetValue(offline_model_dir, OFFLINE_MODEL_DIR, model_path);
+    GetValue(online_model_dir, ONLINE_MODEL_DIR, model_path);
+    GetValue(quantize, QUANTIZE, model_path);
+    GetValue(vad_dir, VAD_DIR, model_path);
+    GetValue(vad_quant, VAD_QUANT, model_path);
+    GetValue(punc_dir, PUNC_DIR, model_path);
+    GetValue(punc_quant, PUNC_QUANT, model_path);
+    GetValue(wav_path, WAV_PATH, model_path);
+    GetValue(asr_mode, ASR_MODE, model_path);
+
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    int thread_num = onnx_thread.getValue();
+    int asr_mode_ = -1;  // numeric ASR_TYPE: 0=offline, 1=online, 2=2pass
+    if(model_path[ASR_MODE] == "offline"){
+        asr_mode_ = 0;
+    }else if(model_path[ASR_MODE] == "online"){
+        asr_mode_ = 1;
+    }else if(model_path[ASR_MODE] == "2pass"){
+        asr_mode_ = 2;
+    }else{
+        LOG(ERROR) << "Wrong asr-mode : " << model_path[ASR_MODE];
+        exit(-1);
+    }
+    FUNASR_HANDLE tpass_handle=FunTpassInit(model_path, thread_num);
+
+    if (!tpass_handle)
+    {
+        LOG(ERROR) << "FunTpassInit init failed";
+        exit(-1);
+    }
+
+    gettimeofday(&end, NULL);
+    long seconds = (end.tv_sec - start.tv_sec);
+    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
+
+    // read wav_path
+    vector<string> wav_list;
+    vector<string> wav_ids;
+    string default_id = "wav_default_id";
+    string wav_path_ = model_path.at(WAV_PATH);
+
+    if(is_target_file(wav_path_, "scp")){
+        ifstream in(wav_path_);
+        if (!in.is_open()) {
+            LOG(ERROR) << "Failed to open file: " << wav_path_;  // fixed: was model_path.at(WAV_SCP) — that key is never inserted, so .at() threw std::out_of_range
+            return 0;
+        }
+        string line;
+        while(getline(in, line))
+        {
+            istringstream iss(line);
+            string column1, column2;
+            iss >> column1 >> column2;  // kaldi style: "wav_id wav_path"
+            wav_list.emplace_back(column2);
+            wav_ids.emplace_back(column1);
+        }
+        in.close();
+    }else{
+        wav_list.emplace_back(wav_path_);
+        wav_ids.emplace_back(default_id);
+    }
+
+    // init online features
+    std::vector<int> chunk_size = {5,10,5};  // streaming chunk configuration passed to the online model
+    FUNASR_HANDLE tpass_online_handle=FunTpassOnlineInit(tpass_handle, chunk_size);
+    float snippet_time = 0.0f;
+    long taking_micros = 0;
+    for (size_t i = 0; i < wav_list.size(); i++) {  // size_t: avoid signed/unsigned comparison
+        auto& wav_file = wav_list[i];
+        auto& wav_id = wav_ids[i];
+
+        int32_t sampling_rate_ = 16000;
+        funasr::Audio audio(1);
+        if(is_target_file(wav_file.c_str(), "wav")){
+            if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
+                LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+        }else if(is_target_file(wav_file.c_str(), "pcm")){
+            if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
+                LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+        }else{
+            if (!audio.FfmpegLoad(wav_file.c_str(), true)){
+                LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+        }
+        char* speech_buff = audio.GetSpeechChar();
+        int buff_len = audio.GetSpeechLen()*2;  // byte count of 16-bit samples
+
+        int step = 1600*2;  // 1600 samples * 2 bytes => 100 ms per chunk @ 16 kHz
+        bool is_final = false;
+
+        string online_res="";
+        string tpass_res="";
+        std::vector<std::vector<string>> punc_cache(2);  // incremental punctuation caches carried across chunks
+        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
+            if (sample_offset + step >= buff_len - 1) {
+                step = buff_len - sample_offset;  // last (possibly short) chunk
+                is_final = true;
+            } else {
+                is_final = false;
+            }
+            gettimeofday(&start, NULL);
+            FUNASR_RESULT result = FunTpassInferBuffer(tpass_handle, tpass_online_handle, speech_buff+sample_offset, step, punc_cache, is_final, sampling_rate_, "pcm", (ASR_TYPE)asr_mode_);
+            gettimeofday(&end, NULL);
+            seconds = (end.tv_sec - start.tv_sec);
+            taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+
+            if (result)
+            {
+                string online_msg = FunASRGetResult(result, 0);
+                online_res += online_msg;
+                if(online_msg != ""){
+                    LOG(INFO)<< wav_id <<" : "<<online_msg;
+                }
+                string tpass_msg = FunASRGetTpassResult(result, 0);
+                tpass_res += tpass_msg;
+                if(tpass_msg != ""){
+                    LOG(INFO)<< wav_id <<" offline results : "<<tpass_msg;
+                }
+                snippet_time += FunASRGetRetSnippetTime(result);
+                FunASRFreeResult(result);
+            }
+        }
+        if(asr_mode_==2){
+            LOG(INFO) << wav_id << " Final online results "<<" : "<<online_res;
+        }
+        if(asr_mode_==1){
+            LOG(INFO) << wav_id << " Final online results "<<" : "<<tpass_res;
+        }
+        if(asr_mode_==0 || asr_mode_==2){
+            LOG(INFO) << wav_id << " Final offline results " <<" : "<<tpass_res;
+        }
+    }
+
+    LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
+    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
+    LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
+    FunTpassOnlineUninit(tpass_online_handle);  // free the per-stream handle before the shared models
+    FunTpassUninit(tpass_handle);
+    return 0;
+}
+
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
index ee05d75..85d6f03 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline-rtf.cpp
@@ -40,6 +40,9 @@
for (size_t i = 0; i < 1; i++)
{
FUNASR_RESULT result=FunOfflineInfer(asr_handle, wav_list[0].c_str(), RASR_NONE, NULL, 16000);
+ if(result){
+ FunASRFreeResult(result);
+ }
}
while (true) {
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-online-asr.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-asr.cpp
new file mode 100644
index 0000000..de0893f
--- /dev/null
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-asr.cpp
@@ -0,0 +1,174 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#ifndef _WIN32
+#include <sys/time.h>
+#else
+#include <win_func.h>
+#endif
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <map>
+#include <vector>
+#include <glog/logging.h>
+#include "funasrruntime.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+#include "audio.h"
+
+using namespace std;
+
+bool is_target_file(const std::string& filename, const std::string target) {
+    // True when the text after the last '.' matches the expected extension.
+    const std::size_t dot_pos = filename.rfind('.');
+    if (dot_pos == std::string::npos) {
+        return false;  // no extension present at all
+    }
+    return filename.compare(dot_pos + 1, std::string::npos, target) == 0;
+}
+
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
+{
+    // Record the option only when it was given explicitly on the command line.
+    if (!value_arg.isSet()) return;
+    const std::string val = value_arg.getValue();
+    model_path.insert({key, val});
+    LOG(INFO) << key << " : " << val;
+}
+
+int main(int argc, char *argv[])  // streaming (online) ASR demo: feeds a wav/pcm in 1.2 s chunks
+{
+    google::InitGoogleLogging(argv[0]);
+    FLAGS_logtostderr = true;
+
+    TCLAP::CmdLine cmd("funasr-onnx-online-asr", ' ', "1.0");
+    TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the asr online model path, which contains encoder.onnx, decoder.onnx, config.yaml, am.mvn", true, "", "string");
+    TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+
+    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
+
+    cmd.add(model_dir);
+    cmd.add(quantize);
+    cmd.add(wav_path);
+    cmd.parse(argc, argv);
+
+    std::map<std::string, std::string> model_path;
+    GetValue(model_dir, MODEL_DIR, model_path);
+    GetValue(quantize, QUANTIZE, model_path);
+    GetValue(wav_path, WAV_PATH, model_path);
+
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    int thread_num = 1;
+    FUNASR_HANDLE asr_handle=FunASRInit(model_path, thread_num, ASR_ONLINE);
+
+    if (!asr_handle)
+    {
+        LOG(ERROR) << "FunASR init failed";
+        exit(-1);
+    }
+
+    gettimeofday(&end, NULL);
+    long seconds = (end.tv_sec - start.tv_sec);
+    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
+
+    // read wav_path
+    vector<string> wav_list;
+    vector<string> wav_ids;
+    string default_id = "wav_default_id";
+    string wav_path_ = model_path.at(WAV_PATH);
+    if(is_target_file(wav_path_, "scp")){
+        ifstream in(wav_path_);
+        if (!in.is_open()) {
+            LOG(ERROR) << "Failed to open file: " << wav_path_;  // fixed: was model_path.at(WAV_SCP) — that key is never inserted, so .at() threw std::out_of_range
+            return 0;
+        }
+        string line;
+        while(getline(in, line))
+        {
+            istringstream iss(line);
+            string column1, column2;
+            iss >> column1 >> column2;  // kaldi style: "wav_id wav_path"
+            wav_list.emplace_back(column2);
+            wav_ids.emplace_back(column1);
+        }
+        in.close();
+    }else{
+        wav_list.emplace_back(wav_path_);
+        wav_ids.emplace_back(default_id);
+    }
+
+    // init online features
+    FUNASR_HANDLE online_handle=FunASROnlineInit(asr_handle);
+    float snippet_time = 0.0f;
+    long taking_micros = 0;
+    for (size_t i = 0; i < wav_list.size(); i++) {  // size_t: avoid signed/unsigned comparison
+        auto& wav_file = wav_list[i];
+        auto& wav_id = wav_ids[i];
+
+        int32_t sampling_rate_ = 16000;
+        funasr::Audio audio(1);
+        if(is_target_file(wav_file.c_str(), "wav")){
+            if(!audio.LoadWav2Char(wav_file.c_str(), &sampling_rate_)){
+                LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+        }else if(is_target_file(wav_file.c_str(), "pcm")){
+            if (!audio.LoadPcmwav2Char(wav_file.c_str(), &sampling_rate_)){
+                LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+        }else{
+            if (!audio.FfmpegLoad(wav_file.c_str(), true)){
+                LOG(ERROR)<<"Failed to load "<< wav_file;
+                exit(-1);
+            }
+        }
+        char* speech_buff = audio.GetSpeechChar();
+        int buff_len = audio.GetSpeechLen()*2;  // byte count of 16-bit samples
+
+        int step = 9600*2;  // 9600 samples * 2 bytes => 600 ms per chunk @ 16 kHz
+        bool is_final = false;
+
+        string final_res="";
+        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
+            if (sample_offset + step >= buff_len - 1) {
+                step = buff_len - sample_offset;  // last (possibly short) chunk
+                is_final = true;
+            } else {
+                is_final = false;
+            }
+            gettimeofday(&start, NULL);
+            FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, sampling_rate_);  // fixed: pass the loaded sample rate (was hard-coded 16000)
+            gettimeofday(&end, NULL);
+            seconds = (end.tv_sec - start.tv_sec);
+            taking_micros += ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+
+            if (result)
+            {
+                string msg = FunASRGetResult(result, 0);
+                final_res += msg;
+                LOG(INFO)<< wav_id <<" : "<<msg;
+                snippet_time += FunASRGetRetSnippetTime(result);
+                FunASRFreeResult(result);
+            }
+            else
+            {
+                LOG(ERROR) << ("No return data!\n");
+            }
+        }
+        LOG(INFO)<<"Final results " << wav_id <<" : "<<final_res;
+    }
+
+    LOG(INFO) << "Audio length: " << (double)snippet_time << " s";
+    LOG(INFO) << "Model inference takes: " << (double)taking_micros / 1000000 <<" s";
+    LOG(INFO) << "Model inference RTF: " << (double)taking_micros/ (snippet_time*1000000);
+    FunASRUninit(online_handle);  // fixed: free the per-stream handle before the model it references
+    FunASRUninit(asr_handle);
+    return 0;
+}
+
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-online-rtf.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-rtf.cpp
new file mode 100644
index 0000000..64f5e73
--- /dev/null
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-rtf.cpp
@@ -0,0 +1,278 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#ifndef _WIN32
+#include <sys/time.h>
+#else
+#include <win_func.h>
+#endif
+
+#include <glog/logging.h>
+#include "funasrruntime.h"
+#include "tclap/CmdLine.h"
+#include "com-define.h"
+
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <vector>
+#include <atomic>
+#include <mutex>
+#include <thread>
+#include <map>
+#include "audio.h"
+
+using namespace std;
+
+std::atomic<int> wav_index(0);
+std::mutex mtx;
+
+bool is_target_file(const std::string& filename, const std::string target) {
+    // True when the text after the last '.' matches the expected extension.
+    const std::size_t dot_pos = filename.rfind('.');
+    if (dot_pos == std::string::npos) {
+        return false;  // no extension present at all
+    }
+    return filename.compare(dot_pos + 1, std::string::npos, target) == 0;
+}
+
+void runReg(FUNASR_HANDLE asr_handle, vector<string> wav_list, vector<string> wav_ids,  // benchmark worker: decodes wavs claimed from a shared atomic index
+            float* total_length, long* total_time, int core_id) {
+
+    struct timeval start, end;
+    long seconds = 0;
+    float n_total_length = 0.0f;  // audio seconds decoded by this thread
+    long n_total_time = 0;        // inference microseconds spent by this thread
+
+    // init online features
+    FUNASR_HANDLE online_handle=FunASROnlineInit(asr_handle);
+
+    // warm up
+    for (size_t i = 0; i < 10; i++)  // 10 full untimed decodes of the first wav to warm sessions/caches
+    {
+        int32_t sampling_rate_ = -1;
+        funasr::Audio audio(1);
+        if(is_target_file(wav_list[0].c_str(), "wav")){
+            if(!audio.LoadWav2Char(wav_list[0].c_str(), &sampling_rate_)){
+                LOG(ERROR)<<"Failed to load "<< wav_list[0];
+                exit(-1);
+            }
+        }else if(is_target_file(wav_list[0].c_str(), "pcm")){
+            if (!audio.LoadPcmwav2Char(wav_list[0].c_str(), &sampling_rate_)){
+                LOG(ERROR)<<"Failed to load "<< wav_list[0];
+                exit(-1);
+            }
+        }else{
+            if (!audio.FfmpegLoad(wav_list[0].c_str(), true)){
+                LOG(ERROR)<<"Failed to load "<< wav_list[0];
+                exit(-1);
+            }
+        }
+        char* speech_buff = audio.GetSpeechChar();
+        int buff_len = audio.GetSpeechLen()*2;  // byte count of 16-bit samples
+
+        int step = 9600*2;  // 9600 samples * 2 bytes => 600 ms per chunk @ 16 kHz
+        bool is_final = false;
+
+        string final_res="";
+        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
+            if (sample_offset + step >= buff_len - 1) {
+                step = buff_len - sample_offset;  // last (possibly short) chunk
+                is_final = true;
+            } else {
+                is_final = false;
+            }
+            FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, 16000);
+            if (result)
+            {
+                FunASRFreeResult(result);
+            }
+        }
+    }
+
+    while (true) {
+        // atomically claim the next wav index (work is shared across threads)
+        int i = wav_index.fetch_add(1);
+        if (i >= wav_list.size()) {
+            break;
+        }
+        int32_t sampling_rate_ = -1;
+        funasr::Audio audio(1);
+        if(is_target_file(wav_list[i].c_str(), "wav")){
+            if(!audio.LoadWav2Char(wav_list[i].c_str(), &sampling_rate_)){
+                LOG(ERROR)<<"Failed to load "<< wav_list[i];
+                exit(-1);
+            }
+        }else if(is_target_file(wav_list[i].c_str(), "pcm")){
+            if (!audio.LoadPcmwav2Char(wav_list[i].c_str(), &sampling_rate_)){
+                LOG(ERROR)<<"Failed to load "<< wav_list[i];
+                exit(-1);
+            }
+        }else{
+            if (!audio.FfmpegLoad(wav_list[i].c_str(), true)){
+                LOG(ERROR)<<"Failed to load "<< wav_list[i];
+                exit(-1);
+            }
+        }
+        char* speech_buff = audio.GetSpeechChar();
+        int buff_len = audio.GetSpeechLen()*2;  // byte count of 16-bit samples
+
+        int step = 9600*2;
+        bool is_final = false;
+
+        string final_res="";
+        for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
+            if (sample_offset + step >= buff_len - 1) {
+                step = buff_len - sample_offset;
+                is_final = true;
+            } else {
+                is_final = false;
+            }
+            gettimeofday(&start, NULL);
+            FUNASR_RESULT result = FunASRInferBuffer(online_handle, speech_buff+sample_offset, step, RASR_NONE, NULL, is_final, 16000);
+            gettimeofday(&end, NULL);
+            seconds = (end.tv_sec - start.tv_sec);
+            long taking_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+            n_total_time += taking_micros;
+
+            if (result)
+            {
+                string msg = FunASRGetResult(result, 0);
+                final_res += msg;
+                LOG(INFO) << "Thread: " << this_thread::get_id() << "," << wav_ids[i] << " : " << msg;
+                float snippet_time = FunASRGetRetSnippetTime(result);
+                n_total_length += snippet_time;
+                FunASRFreeResult(result);
+            }
+            else
+            {
+                LOG(ERROR) << ("No return data!\n");
+            }
+        }
+        LOG(INFO) << "Thread: " << this_thread::get_id() << ", Final results " << wav_ids[i] << " : " << final_res;
+
+    }
+    {
+        lock_guard<mutex> guard(mtx);
+        *total_length += n_total_length;  // audio seconds are summed over all threads
+        if(*total_time < n_total_time){   // but compute time takes the slowest thread (max), not the sum
+            *total_time = n_total_time;
+        }
+    }
+    FunASRUninit(online_handle);
+}
+
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key, std::map<std::string, std::string>& model_path)
+{
+    // Record the option only when it was given explicitly on the command line.
+    if (!value_arg.isSet()) return;
+    const std::string val = value_arg.getValue();
+    model_path.insert({key, val});
+    LOG(INFO) << key << " : " << val;
+}
+
+int main(int argc, char *argv[])  // multi-threaded RTF benchmark for the online ASR pipeline
+{
+    google::InitGoogleLogging(argv[0]);
+    FLAGS_logtostderr = true;
+
+    TCLAP::CmdLine cmd("funasr-onnx-online-rtf", ' ', "1.0");
+    TCLAP::ValueArg<std::string> model_dir("", MODEL_DIR, "the model path, which contains model.onnx, config.yaml, am.mvn", true, "", "string");
+    TCLAP::ValueArg<std::string> quantize("", QUANTIZE, "false (Default), load the model of model.onnx in model_dir. If set true, load the model of model_quant.onnx in model_dir", false, "false", "string");
+    TCLAP::ValueArg<std::string> vad_dir("", VAD_DIR, "the vad model path, which contains model.onnx, vad.yaml, vad.mvn", false, "", "string");
+    TCLAP::ValueArg<std::string> vad_quant("", VAD_QUANT, "false (Default), load the model of model.onnx in vad_dir. If set true, load the model of model_quant.onnx in vad_dir", false, "false", "string");
+    TCLAP::ValueArg<std::string> punc_dir("", PUNC_DIR, "the punc model path, which contains model.onnx, punc.yaml", false, "", "string");
+    TCLAP::ValueArg<std::string> punc_quant("", PUNC_QUANT, "false (Default), load the model of model.onnx in punc_dir. If set true, load the model of model_quant.onnx in punc_dir", false, "false", "string");
+
+    TCLAP::ValueArg<std::string> wav_path("", WAV_PATH, "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)", true, "", "string");
+    TCLAP::ValueArg<std::int32_t> thread_num("", THREAD_NUM, "multi-thread num for rtf", true, 0, "int32_t");
+
+    cmd.add(model_dir);
+    cmd.add(quantize);
+    cmd.add(vad_dir);
+    cmd.add(vad_quant);
+    cmd.add(punc_dir);
+    cmd.add(punc_quant);
+    cmd.add(wav_path);
+    cmd.add(thread_num);
+    cmd.parse(argc, argv);
+
+    std::map<std::string, std::string> model_path;
+    GetValue(model_dir, MODEL_DIR, model_path);
+    GetValue(quantize, QUANTIZE, model_path);
+    GetValue(vad_dir, VAD_DIR, model_path);
+    GetValue(vad_quant, VAD_QUANT, model_path);
+    GetValue(punc_dir, PUNC_DIR, model_path);
+    GetValue(punc_quant, PUNC_QUANT, model_path);
+    GetValue(wav_path, WAV_PATH, model_path);
+
+    struct timeval start, end;
+    gettimeofday(&start, NULL);
+    FUNASR_HANDLE asr_handle=FunASRInit(model_path, 1, ASR_ONLINE);  // one shared model; each worker creates its own online stream
+
+    if (!asr_handle)
+    {
+        LOG(ERROR) << "FunASR init failed";
+        exit(-1);
+    }
+
+    gettimeofday(&end, NULL);
+    long seconds = (end.tv_sec - start.tv_sec);
+    long modle_init_micros = ((seconds * 1000000) + end.tv_usec) - (start.tv_usec);
+    LOG(INFO) << "Model initialization takes " << (double)modle_init_micros / 1000000 << " s";
+
+    // read wav_path
+    vector<string> wav_list;
+    vector<string> wav_ids;
+    string default_id = "wav_default_id";
+    string wav_path_ = model_path.at(WAV_PATH);
+    if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
+        wav_list.emplace_back(wav_path_);
+        wav_ids.emplace_back(default_id);
+    }
+    else if(is_target_file(wav_path_, "scp")){
+        ifstream in(wav_path_);
+        if (!in.is_open()) {
+            LOG(ERROR) << "Failed to open file: " << wav_path_;  // fixed: was model_path.at(WAV_SCP) — that key is never inserted, so .at() threw std::out_of_range
+            return 0;
+        }
+        string line;
+        while(getline(in, line))
+        {
+            istringstream iss(line);
+            string column1, column2;
+            iss >> column1 >> column2;  // kaldi style: "wav_id wav_path"
+            wav_list.emplace_back(column2);
+            wav_ids.emplace_back(column1);
+        }
+        in.close();
+    }else{
+        LOG(ERROR)<<"Please check the wav extension!";
+        exit(-1);
+    }
+
+    // multi-thread RTF benchmark
+    float total_length = 0.0f;
+    long total_time = 0;
+    std::vector<std::thread> threads;
+
+    int rtf_threads = thread_num.getValue();
+    for (int i = 0; i < rtf_threads; i++)
+    {
+        threads.emplace_back(thread(runReg, asr_handle, wav_list, wav_ids, &total_length, &total_time, i));
+    }
+
+    for (auto& thread : threads)
+    {
+        thread.join();
+    }
+
+    LOG(INFO) << "total_time_wav " << (long)(total_length * 1000) << " ms";
+    LOG(INFO) << "total_time_comput " << total_time / 1000 << " ms";
+    LOG(INFO) << "total_rtf " << (double)total_time/ (total_length*1000000);
+    LOG(INFO) << "speedup " << 1.0/((double)total_time/ (total_length*1000000));
+
+    FunASRUninit(asr_handle);
+    return 0;
+}
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
index 68e32e5..b36771d 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-online-vad.cpp
@@ -159,7 +159,7 @@
char* speech_buff = audio.GetSpeechChar();
int buff_len = audio.GetSpeechLen()*2;
- int step = 3200;
+ int step = 800*2;
bool is_final = false;
for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
diff --git a/funasr/runtime/onnxruntime/include/audio.h b/funasr/runtime/onnxruntime/include/audio.h
index a1b6312..c8ca876 100644
--- a/funasr/runtime/onnxruntime/include/audio.h
+++ b/funasr/runtime/onnxruntime/include/audio.h
@@ -5,6 +5,7 @@
#include <stdint.h>
#include "vad-model.h"
#include "offline-stream.h"
+#include "com-define.h"
#ifndef WAV_HEADER_SIZE
#define WAV_HEADER_SIZE 44
@@ -17,11 +18,13 @@
private:
int start;
int end;
- int len;
+
public:
AudioFrame();
AudioFrame(int len);
+ AudioFrame(const AudioFrame &other);
+ AudioFrame(int start, int end, bool is_final);
~AudioFrame();
int SetStart(int val);
@@ -29,6 +32,10 @@
int GetStart();
int GetLen();
int Disp();
+ // 2pass
+ bool is_final = false;
+ float* data = nullptr;
+ int len;
};
class Audio {
@@ -38,10 +45,11 @@
char* speech_char=nullptr;
int speech_len;
int speech_align_len;
- int offset;
float align_size;
int data_type;
queue<AudioFrame *> frame_queue;
+ queue<AudioFrame *> asr_online_queue;
+ queue<AudioFrame *> asr_offline_queue;
public:
Audio(int data_type);
@@ -56,17 +64,35 @@
bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
bool LoadPcmwav2Char(const char* filename, int32_t* sampling_rate);
bool LoadOthers2Char(const char* filename);
- bool FfmpegLoad(const char *filename);
+ bool FfmpegLoad(const char *filename, bool copy2char=false);
bool FfmpegLoad(const char* buf, int n_file_len);
- int FetchChunck(float *&dout, int len);
+ int FetchChunck(AudioFrame *&frame);
+ int FetchTpass(AudioFrame *&frame);
int Fetch(float *&dout, int &len, int &flag);
void Padding();
void Split(OfflineStream* offline_streamj);
void Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished=true);
+ void Split(VadModel* vad_obj, int chunk_len, bool input_finished=true, ASR_TYPE asr_mode=ASR_TWO_PASS);
float GetTimeLen();
int GetQueueSize() { return (int)frame_queue.size(); }
char* GetSpeechChar(){return speech_char;}
int GetSpeechLen(){return speech_len;}
+
+ // 2pass
+ vector<float> all_samples;
+ int offset = 0;
+ int speech_start=-1, speech_end=0;
+ int speech_offline_start=-1;
+
+ int seg_sample = MODEL_SAMPLE_RATE/1000;
+ bool LoadPcmwavOnline(const char* buf, int n_file_len, int32_t* sampling_rate);
+ void ResetIndex(){
+ speech_start=-1;
+ speech_end=0;
+ speech_offline_start=-1;
+ offset = 0;
+ all_samples.clear();
+ }
};
} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/include/com-define.h b/funasr/runtime/onnxruntime/include/com-define.h
index 0d3aee0..a1a5e0b 100644
--- a/funasr/runtime/onnxruntime/include/com-define.h
+++ b/funasr/runtime/onnxruntime/include/com-define.h
@@ -13,11 +13,14 @@
// parser option
#define MODEL_DIR "model-dir"
+#define OFFLINE_MODEL_DIR "model-dir"
+#define ONLINE_MODEL_DIR "online-model-dir"
#define VAD_DIR "vad-dir"
#define PUNC_DIR "punc-dir"
#define QUANTIZE "quantize"
#define VAD_QUANT "vad-quant"
#define PUNC_QUANT "punc-quant"
+#define ASR_MODE "mode"
#define WAV_PATH "wav-path"
#define WAV_SCP "wav-scp"
@@ -42,6 +45,11 @@
#define AM_CONFIG_NAME "config.yaml"
#define PUNC_CONFIG_NAME "punc.yaml"
+#define ENCODER_NAME "model.onnx"
+#define QUANT_ENCODER_NAME "model_quant.onnx"
+#define DECODER_NAME "decoder.onnx"
+#define QUANT_DECODER_NAME "decoder_quant.onnx"
+
// vad
#ifndef VAD_SILENCE_DURATION
#define VAD_SILENCE_DURATION 800
@@ -63,6 +71,19 @@
#define VAD_LFR_N 1
#endif
+// asr
+#ifndef PARA_LFR_M
+#define PARA_LFR_M 7
+#endif
+
+#ifndef PARA_LFR_N
+#define PARA_LFR_N 6
+#endif
+
+#ifndef ONLINE_STEP
+#define ONLINE_STEP 9600
+#endif
+
// punc
#define UNK_CHAR "<unk>"
#define TOKEN_LEN 20
diff --git a/funasr/runtime/onnxruntime/include/funasrruntime.h b/funasr/runtime/onnxruntime/include/funasrruntime.h
index ddb65b9..c1059a6 100644
--- a/funasr/runtime/onnxruntime/include/funasrruntime.h
+++ b/funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -47,20 +47,28 @@
}FUNASR_MODEL_TYPE;
typedef enum {
+ ASR_OFFLINE=0,
+ ASR_ONLINE=1,
+ ASR_TWO_PASS=2,
+}ASR_TYPE;
+
+typedef enum {
PUNC_OFFLINE=0,
PUNC_ONLINE=1,
}PUNC_TYPE;
typedef void (* QM_CALLBACK)(int cur_step, int n_total); // n_total: total steps; cur_step: Current Step.
-
+
// ASR
-_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
+_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type=ASR_OFFLINE);
+_FUNASRAPI FUNASR_HANDLE FunASROnlineInit(FUNASR_HANDLE asr_handle, std::vector<int> chunk_size={5,10,5});
// buffer
-_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000, std::string wav_format="pcm");
+_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm");
// file, support wav & pcm
_FUNASRAPI FUNASR_RESULT FunASRInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
_FUNASRAPI const char* FunASRGetResult(FUNASR_RESULT result,int n_index);
+_FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index);
_FUNASRAPI const int FunASRGetRetNumber(FUNASR_RESULT result);
_FUNASRAPI void FunASRFreeResult(FUNASR_RESULT result);
_FUNASRAPI void FunASRUninit(FUNASR_HANDLE handle);
@@ -94,6 +102,14 @@
_FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
_FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle);
+//2passStream
+_FUNASRAPI FUNASR_HANDLE FunTpassInit(std::map<std::string, std::string>& model_path, int thread_num);
+_FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size={5,10,5});
+// buffer
+_FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm", ASR_TYPE mode=ASR_TWO_PASS);
+_FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle);
+_FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle);
+
#ifdef __cplusplus
}
diff --git a/funasr/runtime/onnxruntime/include/model.h b/funasr/runtime/onnxruntime/include/model.h
index 44bd022..ecd8aaf 100644
--- a/funasr/runtime/onnxruntime/include/model.h
+++ b/funasr/runtime/onnxruntime/include/model.h
@@ -4,17 +4,21 @@
#include <string>
#include <map>
+#include "funasrruntime.h"
namespace funasr {
class Model {
public:
virtual ~Model(){};
virtual void Reset() = 0;
- virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num)=0;
- virtual std::string ForwardChunk(float *din, int len, int flag) = 0;
- virtual std::string Forward(float *din, int len, int flag) = 0;
+ virtual void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
+ virtual void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
+ virtual void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){};
+ virtual std::string Forward(float *din, int len, bool input_finished){return "";};
virtual std::string Rescoring() = 0;
};
-Model *CreateModel(std::map<std::string, std::string>& model_path,int thread_num=1);
+Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num=1, ASR_TYPE type=ASR_OFFLINE);
+Model *CreateModel(void* asr_handle, std::vector<int> chunk_size);
+
} // namespace funasr
#endif
diff --git a/funasr/runtime/onnxruntime/include/offline-stream.h b/funasr/runtime/onnxruntime/include/offline-stream.h
index a9ce88e..bc19bf2 100644
--- a/funasr/runtime/onnxruntime/include/offline-stream.h
+++ b/funasr/runtime/onnxruntime/include/offline-stream.h
@@ -14,9 +14,9 @@
OfflineStream(std::map<std::string, std::string>& model_path, int thread_num);
~OfflineStream(){};
- std::unique_ptr<VadModel> vad_handle;
- std::unique_ptr<Model> asr_handle;
- std::unique_ptr<PuncModel> punc_handle;
+ std::unique_ptr<VadModel> vad_handle= nullptr;
+ std::unique_ptr<Model> asr_handle= nullptr;
+ std::unique_ptr<PuncModel> punc_handle= nullptr;
bool UseVad(){return use_vad;};
bool UsePunc(){return use_punc;};
diff --git a/funasr/runtime/onnxruntime/include/tpass-online-stream.h b/funasr/runtime/onnxruntime/include/tpass-online-stream.h
new file mode 100644
index 0000000..e092880
--- /dev/null
+++ b/funasr/runtime/onnxruntime/include/tpass-online-stream.h
@@ -0,0 +1,20 @@
+#ifndef TPASS_ONLINE_STREAM_H
+#define TPASS_ONLINE_STREAM_H
+
+#include <memory>
+#include "tpass-stream.h"
+#include "model.h"
+#include "vad-model.h"
+
+namespace funasr {
+class TpassOnlineStream {
+ public:
+ TpassOnlineStream(TpassStream* tpass_stream, std::vector<int> chunk_size);
+ ~TpassOnlineStream(){};
+
+ std::unique_ptr<VadModel> vad_online_handle = nullptr;
+ std::unique_ptr<Model> asr_online_handle = nullptr;
+};
+TpassOnlineStream* CreateTpassOnlineStream(void* tpass_stream, std::vector<int> chunk_size);
+} // namespace funasr
+#endif
diff --git a/funasr/runtime/onnxruntime/include/tpass-stream.h b/funasr/runtime/onnxruntime/include/tpass-stream.h
new file mode 100644
index 0000000..f9a5385
--- /dev/null
+++ b/funasr/runtime/onnxruntime/include/tpass-stream.h
@@ -0,0 +1,31 @@
+#ifndef TPASS_STREAM_H
+#define TPASS_STREAM_H
+
+#include <memory>
+#include <string>
+#include <map>
+#include "model.h"
+#include "punc-model.h"
+#include "vad-model.h"
+
+namespace funasr {
+class TpassStream {
+ public:
+ TpassStream(std::map<std::string, std::string>& model_path, int thread_num);
+ ~TpassStream(){};
+
+ // std::unique_ptr<VadModel> vad_handle = nullptr;
+ std::unique_ptr<VadModel> vad_handle = nullptr;
+ std::unique_ptr<Model> asr_handle = nullptr;
+ std::unique_ptr<PuncModel> punc_online_handle = nullptr;
+ bool UseVad(){return use_vad;};
+ bool UsePunc(){return use_punc;};
+
+ private:
+ bool use_vad=false;
+ bool use_punc=false;
+};
+
+TpassStream *CreateTpassStream(std::map<std::string, std::string>& model_path, int thread_num=1);
+} // namespace funasr
+#endif
diff --git a/funasr/runtime/onnxruntime/src/audio.cpp b/funasr/runtime/onnxruntime/src/audio.cpp
index 85633b7..2ba9c30 100644
--- a/funasr/runtime/onnxruntime/src/audio.cpp
+++ b/funasr/runtime/onnxruntime/src/audio.cpp
@@ -132,40 +132,54 @@
};
};
-AudioFrame::AudioFrame(){};
+AudioFrame::AudioFrame(){}
AudioFrame::AudioFrame(int len) : len(len)
{
start = 0;
-};
-AudioFrame::~AudioFrame(){};
+}
+AudioFrame::AudioFrame(const AudioFrame &other)
+{
+ start = other.start;
+ end = other.end;
+ len = other.len;
+ is_final = other.is_final;
+}
+AudioFrame::AudioFrame(int start, int end, bool is_final):start(start),end(end),is_final(is_final){
+ len = end - start;
+}
+AudioFrame::~AudioFrame(){
+ if(data != NULL){
+ free(data);
+ }
+}
int AudioFrame::SetStart(int val)
{
start = val < 0 ? 0 : val;
return start;
-};
+}
int AudioFrame::SetEnd(int val)
{
end = val;
len = end - start;
return end;
-};
+}
int AudioFrame::GetStart()
{
return start;
-};
+}
int AudioFrame::GetLen()
{
return len;
-};
+}
int AudioFrame::Disp()
{
LOG(ERROR) << "Not imp!!!!";
return 0;
-};
+}
Audio::Audio(int data_type) : data_type(data_type)
{
@@ -230,7 +244,7 @@
copy(samples.begin(), samples.end(), speech_data);
}
-bool Audio::FfmpegLoad(const char *filename){
+bool Audio::FfmpegLoad(const char *filename, bool copy2char){
// from file
AVFormatContext* formatContext = avformat_alloc_context();
if (avformat_open_input(&formatContext, filename, NULL, NULL) != 0) {
@@ -353,8 +367,17 @@
if (speech_buff != NULL) {
free(speech_buff);
}
+ if (speech_char != NULL) {
+ free(speech_char);
+ }
offset = 0;
+ if(copy2char){
+ speech_char = (char *)malloc(resampled_buffers.size());
+ memset(speech_char, 0, resampled_buffers.size());
+ memcpy((void*)speech_char, (const void*)resampled_buffers.data(), resampled_buffers.size());
+ }
+
speech_len = (resampled_buffers.size()) / 2;
speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
if (speech_buff)
@@ -762,6 +785,55 @@
return false;
}
+bool Audio::LoadPcmwavOnline(const char* buf, int n_buf_len, int32_t* sampling_rate)
+{
+ if (speech_data != NULL) {
+ free(speech_data);
+ }
+ if (speech_buff != NULL) {
+ free(speech_buff);
+ }
+ if (speech_char != NULL) {
+ free(speech_char);
+ }
+
+ speech_len = n_buf_len / 2;
+ speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
+ if (speech_buff)
+ {
+ memset(speech_buff, 0, sizeof(int16_t) * speech_len);
+ memcpy((void*)speech_buff, (const void*)buf, speech_len * sizeof(int16_t));
+
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
+
+ float scale = 1;
+ if (data_type == 1) {
+ scale = 32768;
+ }
+
+ for (int32_t i = 0; i != speech_len; ++i) {
+ speech_data[i] = (float)speech_buff[i] / scale;
+ }
+
+ //resample
+ if(*sampling_rate != MODEL_SAMPLE_RATE){
+ WavResample(*sampling_rate, speech_data, speech_len);
+ }
+
+ for (int32_t i = 0; i != speech_len; ++i) {
+ all_samples.emplace_back(speech_data[i]);
+ }
+
+ AudioFrame* frame = new AudioFrame(speech_len);
+ frame_queue.push(frame);
+ return true;
+
+ }
+ else
+ return false;
+}
+
bool Audio::LoadPcmwav(const char* filename, int32_t* sampling_rate)
{
if (speech_data != NULL) {
@@ -870,24 +942,25 @@
return true;
}
-int Audio::FetchChunck(float *&dout, int len)
+int Audio::FetchTpass(AudioFrame *&frame)
{
- if (offset >= speech_align_len) {
- dout = NULL;
- return S_ERR;
- } else if (offset == speech_align_len - len) {
- dout = speech_data + offset;
- offset = speech_align_len;
- // 涓存椂瑙e喅
- AudioFrame *frame = frame_queue.front();
- frame_queue.pop();
- delete frame;
-
- return S_END;
+ if (asr_offline_queue.size() > 0) {
+ frame = asr_offline_queue.front();
+ asr_offline_queue.pop();
+ return 1;
} else {
- dout = speech_data + offset;
- offset += len;
- return S_MIDDLE;
+ return 0;
+ }
+}
+
+int Audio::FetchChunck(AudioFrame *&frame)
+{
+ if (asr_online_queue.size() > 0) {
+ frame = asr_online_queue.front();
+ asr_online_queue.pop();
+ return 1;
+ } else {
+ return 0;
}
}
@@ -956,7 +1029,6 @@
std::vector<float> pcm_data(speech_data, speech_data+sp_len);
vector<std::vector<int>> vad_segments = (offline_stream->vad_handle)->Infer(pcm_data);
- int seg_sample = MODEL_SAMPLE_RATE/1000;
for(vector<int> segment:vad_segments)
{
frame = new AudioFrame();
@@ -968,7 +1040,6 @@
frame = NULL;
}
}
-
void Audio::Split(VadModel* vad_obj, vector<std::vector<int>>& vad_segments, bool input_finished)
{
@@ -984,4 +1055,161 @@
vad_segments = vad_obj->Infer(pcm_data, input_finished);
}
+// 2pass
+void Audio::Split(VadModel* vad_obj, int chunk_len, bool input_finished, ASR_TYPE asr_mode)
+{
+ AudioFrame *frame;
+
+ frame = frame_queue.front();
+ frame_queue.pop();
+ int sp_len = frame->GetLen();
+ delete frame;
+ frame = NULL;
+
+ std::vector<float> pcm_data(speech_data, speech_data+sp_len);
+ vector<std::vector<int>> vad_segments = vad_obj->Infer(pcm_data, input_finished);
+
+ speech_end += sp_len/seg_sample;
+ if(vad_segments.size() == 0){
+ if(speech_start != -1){
+ int start = speech_start*seg_sample;
+ int end = speech_end*seg_sample;
+ int buff_len = end-start;
+ int step = chunk_len;
+
+ if(asr_mode != ASR_OFFLINE){
+ if(buff_len >= step){
+ frame = new AudioFrame(step);
+ frame->data = (float*)malloc(sizeof(float) * step);
+ memcpy(frame->data, all_samples.data()+start-offset, step*sizeof(float));
+ asr_online_queue.push(frame);
+ frame = NULL;
+ speech_start += step/seg_sample;
+ }
+ }
+ }
+ }else{
+ for(auto vad_segment: vad_segments){
+ int speech_start_i=-1, speech_end_i=-1;
+ if(vad_segment[0] != -1){
+ speech_start_i = vad_segment[0];
+ }
+ if(vad_segment[1] != -1){
+ speech_end_i = vad_segment[1];
+ }
+
+ // [1, 100]
+ if(speech_start_i != -1 && speech_end_i != -1){
+ int start = speech_start_i*seg_sample;
+ int end = speech_end_i*seg_sample;
+
+ if(asr_mode != ASR_OFFLINE){
+ frame = new AudioFrame(end-start);
+ frame->is_final = true;
+ frame->data = (float*)malloc(sizeof(float) * (end-start));
+ memcpy(frame->data, all_samples.data()+start-offset, (end-start)*sizeof(float));
+ asr_online_queue.push(frame);
+ frame = NULL;
+ }
+
+ if(asr_mode != ASR_ONLINE){
+ frame = new AudioFrame(end-start);
+ frame->is_final = true;
+ frame->data = (float*)malloc(sizeof(float) * (end-start));
+ memcpy(frame->data, all_samples.data()+start-offset, (end-start)*sizeof(float));
+ asr_offline_queue.push(frame);
+ frame = NULL;
+ }
+
+ speech_start = -1;
+ speech_offline_start = -1;
+ // [70, -1]
+ }else if(speech_start_i != -1){
+ speech_start = speech_start_i;
+ speech_offline_start = speech_start_i;
+
+ int start = speech_start*seg_sample;
+ int end = speech_end*seg_sample;
+ int buff_len = end-start;
+ int step = chunk_len;
+
+ if(asr_mode != ASR_OFFLINE){
+ if(buff_len >= step){
+ frame = new AudioFrame(step);
+ frame->data = (float*)malloc(sizeof(float) * step);
+ memcpy(frame->data, all_samples.data()+start-offset, step*sizeof(float));
+ asr_online_queue.push(frame);
+ frame = NULL;
+ speech_start += step/seg_sample;
+ }
+ }
+
+ }else if(speech_end_i != -1){ // [-1,100]
+ if(speech_start == -1 or speech_offline_start == -1){
+ LOG(ERROR) <<"Vad start is null while vad end is available." ;
+ exit(-1);
+ }
+
+ int start = speech_start*seg_sample;
+ int offline_start = speech_offline_start*seg_sample;
+ int end = speech_end_i*seg_sample;
+ int buff_len = end-start;
+ int step = chunk_len;
+
+ if(asr_mode != ASR_ONLINE){
+ frame = new AudioFrame(end-offline_start);
+ frame->is_final = true;
+ frame->data = (float*)malloc(sizeof(float) * (end-offline_start));
+ memcpy(frame->data, all_samples.data()+offline_start-offset, (end-offline_start)*sizeof(float));
+ asr_offline_queue.push(frame);
+ frame = NULL;
+ }
+
+ if(asr_mode != ASR_OFFLINE){
+ if(buff_len > 0){
+ for (int sample_offset = 0; sample_offset < buff_len; sample_offset += std::min(step, buff_len - sample_offset)) {
+ bool is_final = false;
+ if (sample_offset + step >= buff_len - 1) {
+ step = buff_len - sample_offset;
+ is_final = true;
+ }
+ frame = new AudioFrame(step);
+ frame->is_final = is_final;
+ frame->data = (float*)malloc(sizeof(float) * step);
+ memcpy(frame->data, all_samples.data()+start-offset+sample_offset, step*sizeof(float));
+ asr_online_queue.push(frame);
+ frame = NULL;
+ }
+ }else{
+ frame = new AudioFrame(0);
+ frame->is_final = true;
+ asr_online_queue.push(frame);
+ frame = NULL;
+ }
+ }
+ speech_start = -1;
+ speech_offline_start = -1;
+ }
+ }
+ }
+
+ // erase all_samples
+ int vector_cache = MODEL_SAMPLE_RATE*2;
+ if(speech_offline_start == -1){
+ if(all_samples.size() > vector_cache){
+ int erase_num = all_samples.size() - vector_cache;
+ all_samples.erase(all_samples.begin(), all_samples.begin()+erase_num);
+ offset += erase_num;
+ }
+ }else{
+ int offline_start = speech_offline_start*seg_sample;
+ if(offline_start-offset > vector_cache){
+ int erase_num = offline_start-offset - vector_cache;
+ all_samples.erase(all_samples.begin(), all_samples.begin()+erase_num);
+ offset += erase_num;
+ }
+ }
+
+}
+
} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/commonfunc.h b/funasr/runtime/onnxruntime/src/commonfunc.h
index b74c1c1..8734d6d 100644
--- a/funasr/runtime/onnxruntime/src/commonfunc.h
+++ b/funasr/runtime/onnxruntime/src/commonfunc.h
@@ -5,7 +5,8 @@
typedef struct
{
std::string msg;
- float snippet_time;
+ std::string tpass_msg;
+ float snippet_time;
}FUNASR_RECOG_RESULT;
typedef struct
diff --git a/funasr/runtime/onnxruntime/src/ct-transformer-online.cpp b/funasr/runtime/onnxruntime/src/ct-transformer-online.cpp
index 191cda8..14601a5 100644
--- a/funasr/runtime/onnxruntime/src/ct-transformer-online.cpp
+++ b/funasr/runtime/onnxruntime/src/ct-transformer-online.cpp
@@ -181,11 +181,12 @@
text_lengths_dim.size()); //, ONNX_TENSOR_ELEMENT_DATA_TYPE_INT32);
//vad_mask
- vector<float> arVadMask,arSubMask;
+ // vector<float> arVadMask,arSubMask;
+ vector<float> arVadMask;
int nTextLength = input_data.size();
VadMask(nTextLength, nCacheSize, arVadMask);
- Triangle(nTextLength, arSubMask);
+ // Triangle(nTextLength, arSubMask);
std::array<int64_t, 4> VadMask_Dim{ 1,1, nTextLength ,nTextLength };
Ort::Value onnx_vad_mask = Ort::Value::CreateTensor<float>(
m_memoryInfo,
@@ -198,8 +199,8 @@
std::array<int64_t, 4> SubMask_Dim{ 1,1, nTextLength ,nTextLength };
Ort::Value onnx_sub_mask = Ort::Value::CreateTensor<float>(
m_memoryInfo,
- arSubMask.data(),
- arSubMask.size() ,
+ arVadMask.data(),
+ arVadMask.size(),
SubMask_Dim.data(),
SubMask_Dim.size()); // , ONNX_TENSOR_ELEMENT_DATA_TYPE_FLOAT);
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp b/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp
index 0346916..e16a1fc 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad-online.cpp
@@ -55,7 +55,7 @@
int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
int lfr_splice_frame_idxs = OnlineLfrCmvn(vad_feats, input_finished);
- int reserve_frame_idx = lfr_splice_frame_idxs - minus_frame;
+ int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
reserve_waveforms_.clear();
reserve_waveforms_.insert(reserve_waveforms_.begin(),
waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
@@ -86,7 +86,7 @@
int FsmnVadOnline::OnlineLfrCmvn(vector<vector<float>> &vad_feats, bool input_finished) {
vector<vector<float>> out_feats;
int T = vad_feats.size();
- int T_lrf = ceil((T - (lfr_m - 1) / 2) / lfr_n);
+ int T_lrf = ceil((T - (lfr_m - 1) / 2) / (float)lfr_n);
int lfr_splice_frame_idxs = T_lrf;
vector<float> p;
for (int i = 0; i < T_lrf; i++) {
@@ -175,6 +175,9 @@
vad_silence_duration_ = vad_silence_duration;
vad_max_len_ = vad_max_len;
vad_speech_noise_thres_ = vad_speech_noise_thres;
+
+ // 2pass
+ audio_handle = make_unique<Audio>(1);
}
FsmnVadOnline::~FsmnVadOnline() {
diff --git a/funasr/runtime/onnxruntime/src/fsmn-vad-online.h b/funasr/runtime/onnxruntime/src/fsmn-vad-online.h
index 4d429b6..9191304 100644
--- a/funasr/runtime/onnxruntime/src/fsmn-vad-online.h
+++ b/funasr/runtime/onnxruntime/src/fsmn-vad-online.h
@@ -21,6 +21,8 @@
std::vector<std::vector<int>> Infer(std::vector<float> &waves, bool input_finished);
void ExtractFeats(float sample_rate, vector<vector<float>> &vad_feats, vector<float> &waves, bool input_finished);
void Reset();
+ // 2pass
+ std::unique_ptr<Audio> audio_handle = nullptr;
private:
E2EVadModel vad_scorer = E2EVadModel();
diff --git a/funasr/runtime/onnxruntime/src/funasrruntime.cpp b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
index a1829fd..2e6a079 100644
--- a/funasr/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -5,9 +5,15 @@
#endif
// APIs for Init
- _FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num)
+ _FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type)
{
- funasr::Model* mm = funasr::CreateModel(model_path, thread_num);
+ funasr::Model* mm = funasr::CreateModel(model_path, thread_num, type);
+ return mm;
+ }
+
+	_FUNASRAPI FUNASR_HANDLE FunASROnlineInit(FUNASR_HANDLE asr_handle, std::vector<int> chunk_size)
+	{
+		funasr::Model* mm = funasr::CreateModel(asr_handle, chunk_size);
return mm;
}
@@ -35,8 +41,19 @@
return mm;
}
+ _FUNASRAPI FUNASR_HANDLE FunTpassInit(std::map<std::string, std::string>& model_path, int thread_num)
+ {
+ funasr::TpassStream* mm = funasr::CreateTpassStream(model_path, thread_num);
+ return mm;
+ }
+
+ _FUNASRAPI FUNASR_HANDLE FunTpassOnlineInit(FUNASR_HANDLE tpass_handle, std::vector<int> chunk_size)
+ {
+ return funasr::CreateTpassOnlineStream(tpass_handle, chunk_size);
+ }
+
// APIs for ASR Infer
- _FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate, std::string wav_format)
+ _FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate, std::string wav_format)
{
funasr::Model* recog_obj = (funasr::Model*)handle;
if (!recog_obj)
@@ -57,12 +74,12 @@
funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
p_result->snippet_time = audio.GetTimeLen();
if(p_result->snippet_time == 0){
- return p_result;
- }
+ return p_result;
+ }
int n_step = 0;
int n_total = audio.GetQueueSize();
while (audio.Fetch(buff, len, flag) > 0) {
- string msg = recog_obj->Forward(buff, len, flag);
+ string msg = recog_obj->Forward(buff, len, input_finished);
p_result->msg += msg;
n_step++;
if (fn_callback)
@@ -102,7 +119,7 @@
return p_result;
}
while (audio.Fetch(buff, len, flag) > 0) {
- string msg = recog_obj->Forward(buff, len, flag);
+ string msg = recog_obj->Forward(buff, len, true);
p_result->msg += msg;
n_step++;
if (fn_callback)
@@ -230,7 +247,7 @@
int n_step = 0;
int n_total = audio.GetQueueSize();
while (audio.Fetch(buff, len, flag) > 0) {
- string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
+ string msg = (offline_stream->asr_handle)->Forward(buff, len, true);
p_result->msg += msg;
n_step++;
if (fn_callback)
@@ -277,7 +294,7 @@
int n_step = 0;
int n_total = audio.GetQueueSize();
while (audio.Fetch(buff, len, flag) > 0) {
- string msg = (offline_stream->asr_handle)->Forward(buff, len, flag);
+ string msg = (offline_stream->asr_handle)->Forward(buff, len, true);
p_result->msg+= msg;
n_step++;
if (fn_callback)
@@ -288,6 +305,91 @@
p_result->msg = punc_res;
}
+ return p_result;
+ }
+
+ // APIs for 2pass-stream Infer
+ _FUNASRAPI FUNASR_RESULT FunTpassInferBuffer(FUNASR_HANDLE handle, FUNASR_HANDLE online_handle, const char* sz_buf, int n_len, std::vector<std::vector<std::string>> &punc_cache, bool input_finished, int sampling_rate, std::string wav_format, ASR_TYPE mode)
+ {
+ funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
+ funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)online_handle;
+ if (!tpass_stream || !tpass_online_stream)
+ return nullptr;
+
+ funasr::VadModel* vad_online_handle = (tpass_online_stream->vad_online_handle).get();
+ if (!vad_online_handle)
+ return nullptr;
+
+ funasr::Audio* audio = ((funasr::FsmnVadOnline*)vad_online_handle)->audio_handle.get();
+
+ funasr::Model* asr_online_handle = (tpass_online_stream->asr_online_handle).get();
+ if (!asr_online_handle)
+ return nullptr;
+ int chunk_len = ((funasr::ParaformerOnline*)asr_online_handle)->chunk_len;
+
+ funasr::Model* asr_handle = (tpass_stream->asr_handle).get();
+ if (!asr_handle)
+ return nullptr;
+
+ funasr::PuncModel* punc_online_handle = (tpass_stream->punc_online_handle).get();
+ if (!punc_online_handle)
+ return nullptr;
+
+ if(wav_format == "pcm" || wav_format == "PCM"){
+ if (!audio->LoadPcmwavOnline(sz_buf, n_len, &sampling_rate))
+ return nullptr;
+ }else{
+ // if (!audio->FfmpegLoad(sz_buf, n_len))
+ // return nullptr;
+ LOG(ERROR) <<"Wrong wav_format: " << wav_format ;
+ exit(-1);
+ }
+
+ funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
+ p_result->snippet_time = audio->GetTimeLen();
+ if(p_result->snippet_time == 0){
+ return p_result;
+ }
+
+ audio->Split(vad_online_handle, chunk_len, input_finished, mode);
+
+ funasr::AudioFrame* frame = NULL;
+ while(audio->FetchChunck(frame) > 0){
+ string msg = asr_online_handle->Forward(frame->data, frame->len, frame->is_final);
+ if(mode == ASR_ONLINE){
+ ((funasr::ParaformerOnline*)asr_online_handle)->online_res += msg;
+ if(frame->is_final){
+ string online_msg = ((funasr::ParaformerOnline*)asr_online_handle)->online_res;
+ string msg_punc = punc_online_handle->AddPunc(online_msg.c_str(), punc_cache[0]);
+ p_result->tpass_msg = msg_punc;
+ ((funasr::ParaformerOnline*)asr_online_handle)->online_res = "";
+ p_result->msg += msg;
+ }else{
+ p_result->msg += msg;
+ }
+ }else if(mode == ASR_TWO_PASS){
+ p_result->msg += msg;
+ }
+ if(frame != NULL){
+ delete frame;
+ frame = NULL;
+ }
+ }
+
+ while(audio->FetchTpass(frame) > 0){
+ string msg = asr_handle->Forward(frame->data, frame->len, frame->is_final);
+ string msg_punc = punc_online_handle->AddPunc(msg.c_str(), punc_cache[1]);
+ p_result->tpass_msg = msg_punc;
+ if(frame != NULL){
+ delete frame;
+ frame = NULL;
+ }
+ }
+
+ if(input_finished){
+ audio->ResetIndex();
+ }
+
return p_result;
}
@@ -324,6 +426,15 @@
return nullptr;
return p_result->msg.c_str();
+ }
+
+ _FUNASRAPI const char* FunASRGetTpassResult(FUNASR_RESULT result,int n_index)
+ {
+ funasr::FUNASR_RECOG_RESULT * p_result = (funasr::FUNASR_RECOG_RESULT*)result;
+ if(!p_result)
+ return nullptr;
+
+ return p_result->tpass_msg.c_str();
}
_FUNASRAPI const char* CTTransformerGetResult(FUNASR_RESULT result,int n_index)
@@ -414,6 +525,26 @@
delete offline_stream;
}
+ _FUNASRAPI void FunTpassUninit(FUNASR_HANDLE handle)
+ {
+ funasr::TpassStream* tpass_stream = (funasr::TpassStream*)handle;
+
+ if (!tpass_stream)
+ return;
+
+ delete tpass_stream;
+ }
+
+ _FUNASRAPI void FunTpassOnlineUninit(FUNASR_HANDLE handle)
+ {
+ funasr::TpassOnlineStream* tpass_online_stream = (funasr::TpassOnlineStream*)handle;
+
+ if (!tpass_online_stream)
+ return;
+
+ delete tpass_online_stream;
+ }
+
#ifdef __cplusplus
}
diff --git a/funasr/runtime/onnxruntime/src/model.cpp b/funasr/runtime/onnxruntime/src/model.cpp
index 6badde6..646f260 100644
--- a/funasr/runtime/onnxruntime/src/model.cpp
+++ b/funasr/runtime/onnxruntime/src/model.cpp
@@ -1,22 +1,55 @@
#include "precomp.h"
namespace funasr {
-Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num)
+Model *CreateModel(std::map<std::string, std::string>& model_path, int thread_num, ASR_TYPE type)
{
- string am_model_path;
- string am_cmvn_path;
- string am_config_path;
+ // offline
+ if(type == ASR_OFFLINE){
+ string am_model_path;
+ string am_cmvn_path;
+ string am_config_path;
- am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
- if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
- am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
+ am_model_path = PathAppend(model_path.at(MODEL_DIR), MODEL_NAME);
+ if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
+ am_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_MODEL_NAME);
+ }
+ am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
+ am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
+
+ Model *mm;
+ mm = new Paraformer();
+ mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
+ return mm;
+ }else if(type == ASR_ONLINE){
+ // online
+ string en_model_path;
+ string de_model_path;
+ string am_cmvn_path;
+ string am_config_path;
+
+ en_model_path = PathAppend(model_path.at(MODEL_DIR), ENCODER_NAME);
+ de_model_path = PathAppend(model_path.at(MODEL_DIR), DECODER_NAME);
+ if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){
+ en_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_ENCODER_NAME);
+ de_model_path = PathAppend(model_path.at(MODEL_DIR), QUANT_DECODER_NAME);
+ }
+ am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
+ am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
+
+ Model *mm;
+ mm = new Paraformer();
+ mm->InitAsr(en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num);
+ return mm;
+ }else{
+ LOG(ERROR)<<"Wrong ASR_TYPE : " << type;
+ exit(-1);
}
- am_cmvn_path = PathAppend(model_path.at(MODEL_DIR), AM_CMVN_NAME);
- am_config_path = PathAppend(model_path.at(MODEL_DIR), AM_CONFIG_NAME);
+}
- Model *mm;
- mm = new Paraformer();
- mm->InitAsr(am_model_path, am_cmvn_path, am_config_path, thread_num);
+Model *CreateModel(void* asr_handle, std::vector<int> chunk_size)
+{
+ Model* mm;
+ mm = new ParaformerOnline((Paraformer*)asr_handle, chunk_size);
return mm;
}
diff --git a/funasr/runtime/onnxruntime/src/paraformer-online.cpp b/funasr/runtime/onnxruntime/src/paraformer-online.cpp
new file mode 100644
index 0000000..1787f02
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/paraformer-online.cpp
@@ -0,0 +1,551 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+
+#include "precomp.h"
+
+using namespace std;
+
+namespace funasr {
+
+ParaformerOnline::ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size)
+:para_handle_(std::move(para_handle)),chunk_size(chunk_size),session_options_{}{
+ InitOnline(
+ para_handle_->fbank_opts_,
+ para_handle_->encoder_session_,
+ para_handle_->decoder_session_,
+ para_handle_->en_szInputNames_,
+ para_handle_->en_szOutputNames_,
+ para_handle_->de_szInputNames_,
+ para_handle_->de_szOutputNames_,
+ para_handle_->means_list_,
+ para_handle_->vars_list_);
+ InitCache();
+}
+
+void ParaformerOnline::InitOnline(
+ knf::FbankOptions &fbank_opts,
+ std::shared_ptr<Ort::Session> &encoder_session,
+ std::shared_ptr<Ort::Session> &decoder_session,
+ vector<const char*> &en_szInputNames,
+ vector<const char*> &en_szOutputNames,
+ vector<const char*> &de_szInputNames,
+ vector<const char*> &de_szOutputNames,
+ vector<float> &means_list,
+ vector<float> &vars_list){
+ fbank_opts_ = fbank_opts;
+ encoder_session_ = encoder_session;
+ decoder_session_ = decoder_session;
+ en_szInputNames_ = en_szInputNames;
+ en_szOutputNames_ = en_szOutputNames;
+ de_szInputNames_ = de_szInputNames;
+ de_szOutputNames_ = de_szOutputNames;
+ means_list_ = means_list;
+ vars_list_ = vars_list;
+
+ frame_length = para_handle_->frame_length;
+ frame_shift = para_handle_->frame_shift;
+ n_mels = para_handle_->n_mels;
+ lfr_m = para_handle_->lfr_m;
+ lfr_n = para_handle_->lfr_n;
+ encoder_size = para_handle_->encoder_size;
+ fsmn_layers = para_handle_->fsmn_layers;
+ fsmn_lorder = para_handle_->fsmn_lorder;
+ fsmn_dims = para_handle_->fsmn_dims;
+ cif_threshold = para_handle_->cif_threshold;
+ tail_alphas = para_handle_->tail_alphas;
+
+ // other vars
+ sqrt_factor = std::sqrt(encoder_size);
+ for(int i=0; i<fsmn_lorder*fsmn_dims; i++){
+ fsmn_init_cache_.emplace_back(0);
+ }
+ chunk_len = chunk_size[1]*frame_shift*lfr_n*MODEL_SAMPLE_RATE/1000;
+}
+
+void ParaformerOnline::FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
+ std::vector<float> &waves) {
+ knf::OnlineFbank fbank(fbank_opts_);
+ // cache merge
+ waves.insert(waves.begin(), input_cache_.begin(), input_cache_.end());
+ int frame_number = ComputeFrameNum(waves.size(), frame_sample_length_, frame_shift_sample_length_);
+ // Send the audio after the last frame shift position to the cache
+ input_cache_.clear();
+ input_cache_.insert(input_cache_.begin(), waves.begin() + frame_number * frame_shift_sample_length_, waves.end());
+ if (frame_number == 0) {
+ return;
+ }
+ // Delete audio that haven't undergone fbank processing
+ waves.erase(waves.begin() + (frame_number - 1) * frame_shift_sample_length_ + frame_sample_length_, waves.end());
+
+ std::vector<float> buf(waves.size());
+ for (int32_t i = 0; i != waves.size(); ++i) {
+ buf[i] = waves[i] * 32768;
+ }
+ fbank.AcceptWaveform(sample_rate, buf.data(), buf.size());
+ int32_t frames = fbank.NumFramesReady();
+ for (int32_t i = 0; i != frames; ++i) {
+ const float *frame = fbank.GetFrame(i);
+ vector<float> frame_vector(frame, frame + fbank_opts_.mel_opts.num_bins);
+ wav_feats.emplace_back(frame_vector);
+ }
+}
+
+void ParaformerOnline::ExtractFeats(float sample_rate, vector<std::vector<float>> &wav_feats,
+ vector<float> &waves, bool input_finished) {
+ FbankKaldi(sample_rate, wav_feats, waves);
+ // cache deal & online lfr,cmvn
+ if (wav_feats.size() > 0) {
+ if (!reserve_waveforms_.empty()) {
+ waves.insert(waves.begin(), reserve_waveforms_.begin(), reserve_waveforms_.end());
+ }
+ if (lfr_splice_cache_.empty()) {
+ for (int i = 0; i < (lfr_m - 1) / 2; i++) {
+ lfr_splice_cache_.emplace_back(wav_feats[0]);
+ }
+ }
+ if (wav_feats.size() + lfr_splice_cache_.size() >= lfr_m) {
+ wav_feats.insert(wav_feats.begin(), lfr_splice_cache_.begin(), lfr_splice_cache_.end());
+ int frame_from_waves = (waves.size() - frame_sample_length_) / frame_shift_sample_length_ + 1;
+ int minus_frame = reserve_waveforms_.empty() ? (lfr_m - 1) / 2 : 0;
+ int lfr_splice_frame_idxs = OnlineLfrCmvn(wav_feats, input_finished);
+ int reserve_frame_idx = std::abs(lfr_splice_frame_idxs - minus_frame);
+ reserve_waveforms_.clear();
+ reserve_waveforms_.insert(reserve_waveforms_.begin(),
+ waves.begin() + reserve_frame_idx * frame_shift_sample_length_,
+ waves.begin() + frame_from_waves * frame_shift_sample_length_);
+ int sample_length = (frame_from_waves - 1) * frame_shift_sample_length_ + frame_sample_length_;
+ waves.erase(waves.begin() + sample_length, waves.end());
+ } else {
+ reserve_waveforms_.clear();
+ reserve_waveforms_.insert(reserve_waveforms_.begin(),
+ waves.begin() + frame_sample_length_ - frame_shift_sample_length_, waves.end());
+ lfr_splice_cache_.insert(lfr_splice_cache_.end(), wav_feats.begin(), wav_feats.end());
+ }
+ } else {
+ if (input_finished) {
+ if (!reserve_waveforms_.empty()) {
+ waves = reserve_waveforms_;
+ }
+ wav_feats = lfr_splice_cache_;
+ OnlineLfrCmvn(wav_feats, input_finished);
+ }
+ }
+ if(input_finished){
+ ResetCache();
+ }
+}
+
+int ParaformerOnline::OnlineLfrCmvn(vector<vector<float>> &wav_feats, bool input_finished) {
+ vector<vector<float>> out_feats;
+ int T = wav_feats.size();
+ int T_lrf = ceil((T - (lfr_m - 1) / 2) / (float)lfr_n);
+ int lfr_splice_frame_idxs = T_lrf;
+ vector<float> p;
+ for (int i = 0; i < T_lrf; i++) {
+ if (lfr_m <= T - i * lfr_n) {
+ for (int j = 0; j < lfr_m; j++) {
+ p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end());
+ }
+ out_feats.emplace_back(p);
+ p.clear();
+ } else {
+ if (input_finished) {
+ int num_padding = lfr_m - (T - i * lfr_n);
+ for (int j = 0; j < (wav_feats.size() - i * lfr_n); j++) {
+ p.insert(p.end(), wav_feats[i * lfr_n + j].begin(), wav_feats[i * lfr_n + j].end());
+ }
+ for (int j = 0; j < num_padding; j++) {
+ p.insert(p.end(), wav_feats[wav_feats.size() - 1].begin(), wav_feats[wav_feats.size() - 1].end());
+ }
+ out_feats.emplace_back(p);
+ } else {
+ lfr_splice_frame_idxs = i;
+ break;
+ }
+ }
+ }
+ lfr_splice_frame_idxs = std::min(T - 1, lfr_splice_frame_idxs * lfr_n);
+ lfr_splice_cache_.clear();
+ lfr_splice_cache_.insert(lfr_splice_cache_.begin(), wav_feats.begin() + lfr_splice_frame_idxs, wav_feats.end());
+
+ // Apply cmvn
+ for (auto &out_feat: out_feats) {
+ for (int j = 0; j < means_list_.size(); j++) {
+ out_feat[j] = (out_feat[j] + means_list_[j]) * vars_list_[j];
+ }
+ }
+ wav_feats = out_feats;
+ return lfr_splice_frame_idxs;
+}
+
+void ParaformerOnline::GetPosEmb(std::vector<std::vector<float>> &wav_feats, int timesteps, int feat_dim)
+{
+ int start_idx = start_idx_cache_;
+ start_idx_cache_ += timesteps;
+ int mm = start_idx_cache_;
+
+ int i;
+ float scale = -0.0330119726594128;
+
+ std::vector<float> tmp(mm*feat_dim);
+
+ for (i = 0; i < feat_dim/2; i++) {
+ float tmptime = exp(i * scale);
+ int j;
+ for (j = 0; j < mm; j++) {
+ int sin_idx = j * feat_dim + i;
+ int cos_idx = j * feat_dim + i + feat_dim/2;
+ float coe = tmptime * (j + 1);
+ tmp[sin_idx] = sin(coe);
+ tmp[cos_idx] = cos(coe);
+ }
+ }
+
+ for (i = start_idx; i < start_idx + timesteps; i++) {
+ for (int j = 0; j < feat_dim; j++) {
+ wav_feats[i-start_idx][j] += tmp[i*feat_dim+j];
+ }
+ }
+}
+
+void ParaformerOnline::CifSearch(std::vector<std::vector<float>> hidden, std::vector<float> alphas, bool is_final, std::vector<std::vector<float>>& list_frame)
+{
+    // Continuous Integrate-and-Fire (CIF): integrate the per-frame weights in
+    // `alphas`; each time the running sum crosses cif_threshold, "fire" one
+    // weighted-sum embedding of `hidden` into `list_frame`.
+    // `hidden` and `alphas` are taken by value on purpose: the carry-over
+    // cached from the previous chunk is prepended to local copies.
+    // NOTE(review): the `is_final` parameter is unused; the member
+    // is_last_chunk gates the tail-frame flush instead — confirm intended.
+    try{
+        int hidden_size = 0;
+        if(hidden.size() > 0){
+            hidden_size = hidden[0].size();
+        }
+        // Zero the left (chunk_size[0]) and right (beyond the sum of all but
+        // the last chunk_size entry) overlap regions so only the "current"
+        // frames of this chunk can fire.
+        int i,j;
+        int chunk_size_pre = chunk_size[0];
+        for (i = 0; i < chunk_size_pre; i++)
+            alphas[i] = 0.0;
+
+        int chunk_size_suf = std::accumulate(chunk_size.begin(), chunk_size.end()-1, 0);
+        for (i = chunk_size_suf; i < alphas.size(); i++){
+            alphas[i] = 0.0;
+        }
+
+        // Prepend the partial integration carried over from the last chunk.
+        if(hidden_cache_.size()>0){
+            hidden.insert(hidden.begin(), hidden_cache_.begin(), hidden_cache_.end());
+            alphas.insert(alphas.begin(), alphas_cache_.begin(), alphas_cache_.end());
+            hidden_cache_.clear();
+            alphas_cache_.clear();
+        }
+
+        // On the last chunk, append a zero frame with tail_alphas weight so
+        // any remaining integration is flushed as a final token.
+        if (is_last_chunk) {
+            std::vector<float> tail_hidden(hidden_size, 0);
+            hidden.emplace_back(tail_hidden);
+            alphas.emplace_back(tail_alphas);
+        }
+
+        float intergrate = 0.0;
+        int len_time = alphas.size();
+        std::vector<float> frames(hidden_size, 0);
+        std::vector<float> list_fire;   // per-frame integration trace (currently unused downstream)
+
+        for (i = 0; i < len_time; i++) {
+            float alpha = alphas[i];
+            if (alpha + intergrate < cif_threshold) {
+                // Not enough weight yet: keep integrating this frame.
+                intergrate += alpha;
+                list_fire.emplace_back(intergrate);
+                for (j = 0; j < hidden_size; j++) {
+                    frames[j] += alpha * hidden[i][j];
+                }
+            } else {
+                // Threshold crossed: top the current embedding up with just
+                // enough weight, emit it, then start the next embedding with
+                // this frame's leftover weight.
+                for (j = 0; j < hidden_size; j++) {
+                    frames[j] += (cif_threshold - intergrate) * hidden[i][j];
+                }
+                std::vector<float> frames_cp(frames);
+                list_frame.emplace_back(frames_cp);
+                intergrate += alpha;
+                list_fire.emplace_back(intergrate);
+                intergrate -= cif_threshold;
+                for (j = 0; j < hidden_size; j++) {
+                    frames[j] = intergrate * hidden[i][j];
+                }
+            }
+        }
+
+        // Cache the leftover integration for the next chunk; the cached
+        // hidden frame is renormalized by the leftover weight.
+        alphas_cache_.emplace_back(intergrate);
+        if (intergrate > 0.0) {
+            std::vector<float> hidden_cache(hidden_size, 0);
+            for (i = 0; i < hidden_size; i++) {
+                hidden_cache[i] = frames[i] / intergrate;
+            }
+            hidden_cache_.emplace_back(hidden_cache);
+        } else {
+            std::vector<float> frames_cp(frames);
+            hidden_cache_.emplace_back(frames_cp);
+        }
+    }catch (std::exception const &e)
+    {
+        LOG(ERROR)<<e.what();
+    }
+}
+
+void ParaformerOnline::InitCache(){
+    // Reset all streaming decode state to the start-of-utterance condition.
+    start_idx_cache_ = 0;      // absolute position for positional embeddings
+    is_first_chunk = true;
+    is_last_chunk = false;
+    hidden_cache_.clear();
+    alphas_cache_.clear();
+    feats_cache_.clear();
+    decoder_onnx.clear();
+
+    // cif cache: one all-zero hidden frame carrying zero weight.
+    std::vector<float> hidden_cache(encoder_size, 0);
+    hidden_cache_.emplace_back(hidden_cache);
+    alphas_cache_.emplace_back(0);
+
+    // feats: zero left+right context frames for the first overlap chunk.
+    std::vector<float> feat_cache(feat_dims, 0);
+    for(int i=0; i<(chunk_size[0]+chunk_size[2]); i++){
+        feats_cache_.emplace_back(feat_cache);
+    }
+
+    // fsmn cache: one zero tensor per decoder FSMN layer.
+#ifdef _WIN_X86
+    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+#else
+    Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+#endif
+    const int64_t fsmn_shape_[3] = {1, fsmn_dims, fsmn_lorder};
+    // NOTE(review): every layer's tensor wraps the SAME fsmn_init_cache_
+    // buffer (no copy); fine as long as they are read-only inputs — confirm
+    // the decoder never writes through these.
+    for(int l=0; l<fsmn_layers; l++){
+        Ort::Value onnx_fsmn_cache = Ort::Value::CreateTensor<float>(
+            m_memoryInfo,
+            fsmn_init_cache_.data(),
+            fsmn_init_cache_.size(),
+            fsmn_shape_,
+            3);
+        decoder_onnx.emplace_back(std::move(onnx_fsmn_cache));
+    }
+};
+
+void ParaformerOnline::Reset()
+{
+    // Re-arm all decoding caches for a new utterance.
+    InitCache();
+}
+
+void ParaformerOnline::ResetCache() {
+    // Drop the feature-extraction caches (waveform remainder kept by fbank,
+    // samples after the last frame shift, and the LFR splice buffer).
+    reserve_waveforms_.clear();
+    input_cache_.clear();
+    lfr_splice_cache_.clear();
+}
+
+void ParaformerOnline::AddOverlapChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished){
+    // Prepend the cached left-context frames, then refresh feats_cache_ with
+    // the tail of the current chunk so the next call overlaps with this one.
+    wav_feats.insert(wav_feats.begin(), feats_cache_.begin(), feats_cache_.end());
+    if(input_finished){
+        feats_cache_.clear();
+        feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0], wav_feats.end());
+        if(!is_last_chunk){
+            // Pad the chunk with zero frames up to the full overlap length.
+            int padding_length = std::accumulate(chunk_size.begin(), chunk_size.end(), 0) - wav_feats.size();
+            std::vector<float> zero_frame(feat_dims, 0);
+            for(int i=0; i<padding_length; i++){
+                // Fix: the original built this padding frame but then called
+                // emplace_back(feat_dims), leaving `tmp` unused and relying on
+                // vector<float>(feat_dims) happening to be all zeros. Push the
+                // explicit zero frame instead.
+                wav_feats.push_back(zero_frame);
+            }
+        }
+    }else{
+        feats_cache_.clear();
+        // Keep chunk_size[0] (left) + chunk_size[2] (right) frames of context.
+        feats_cache_.insert(feats_cache_.begin(), wav_feats.end()-chunk_size[0]-chunk_size[2], wav_feats.end());
+    }
+}
+
+string ParaformerOnline::ForwardChunk(std::vector<std::vector<float>> &chunk_feats, bool input_finished)
+{
+    // Run one overlap chunk through encoder -> CIF -> decoder and return the
+    // greedy-search text. Returns "" when CIF fires no token this chunk or on
+    // any ONNX runtime error (logged).
+    string result;
+    try{
+        int32_t num_frames = chunk_feats.size();
+
+    #ifdef _WIN_X86
+        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtDeviceAllocator, OrtMemTypeCPU);
+    #else
+        Ort::MemoryInfo m_memoryInfo = Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
+    #endif
+        // Flatten [num_frames, feat_dims] into one contiguous input tensor.
+        const int64_t input_shape_[3] = {1, num_frames, feat_dims};
+        std::vector<float> wav_feats;
+        for (const auto &chunk_feat: chunk_feats) {
+            wav_feats.insert(wav_feats.end(), chunk_feat.begin(), chunk_feat.end());
+        }
+        Ort::Value onnx_feats = Ort::Value::CreateTensor<float>(
+            m_memoryInfo,
+            wav_feats.data(),
+            wav_feats.size(),
+            input_shape_,
+            3);
+
+        const int64_t paraformer_length_shape[1] = {1};
+        std::vector<int32_t> paraformer_length;
+        paraformer_length.emplace_back(num_frames);
+        Ort::Value onnx_feats_len = Ort::Value::CreateTensor<int32_t>(
+            m_memoryInfo, paraformer_length.data(), paraformer_length.size(), paraformer_length_shape, 1);
+
+        std::vector<Ort::Value> input_onnx;
+        input_onnx.emplace_back(std::move(onnx_feats));
+        input_onnx.emplace_back(std::move(onnx_feats_len));
+
+        auto encoder_tensor = encoder_session_->Run(Ort::RunOptions{nullptr}, en_szInputNames_.data(), input_onnx.data(), input_onnx.size(), en_szOutputNames_.data(), en_szOutputNames_.size());
+
+        // Encoder output 0: hidden states, copied into enc_vec[T][dim].
+        std::vector<int64_t> enc_shape = encoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
+        float* enc_data = encoder_tensor[0].GetTensorMutableData<float>();
+        std::vector<std::vector<float>> enc_vec(enc_shape[1], std::vector<float>(enc_shape[2]));
+        for (int i = 0; i < enc_shape[1]; i++) {
+            for (int j = 0; j < enc_shape[2]; j++) {
+                enc_vec[i][j] = enc_data[i * enc_shape[2] + j];
+            }
+        }
+
+        // Encoder output 2: per-frame CIF weights (alphas).
+        std::vector<int64_t> alpha_shape = encoder_tensor[2].GetTensorTypeAndShapeInfo().GetShape();
+        float* alpha_data = encoder_tensor[2].GetTensorMutableData<float>();
+        std::vector<float> alpha_vec(alpha_shape[1]);
+        for (int i = 0; i < alpha_shape[1]; i++) {
+            alpha_vec[i] = alpha_data[i];
+        }
+
+        std::vector<std::vector<float>> list_frame;
+        CifSearch(enc_vec, alpha_vec, input_finished, list_frame);
+
+        // Only run the decoder when CIF fired at least one token. On entry
+        // decoder_onnx holds the per-layer FSMN caches; the four fixed inputs
+        // are inserted in front of them in the decoder's expected order.
+        if(list_frame.size()>0){
+            // enc
+            decoder_onnx.insert(decoder_onnx.begin(), std::move(encoder_tensor[0]));
+            // enc_lens
+            decoder_onnx.insert(decoder_onnx.begin()+1, std::move(encoder_tensor[1]));
+
+            // acoustic_embeds
+            const int64_t emb_shape_[3] = {1, (int64_t)list_frame.size(), (int64_t)list_frame[0].size()};
+            std::vector<float> emb_input;
+            for (const auto &list_frame_: list_frame) {
+                emb_input.insert(emb_input.end(), list_frame_.begin(), list_frame_.end());
+            }
+            Ort::Value onnx_emb = Ort::Value::CreateTensor<float>(
+                m_memoryInfo,
+                emb_input.data(),
+                emb_input.size(),
+                emb_shape_,
+                3);
+            decoder_onnx.insert(decoder_onnx.begin()+2, std::move(onnx_emb));
+
+            // acoustic_embeds_len
+            const int64_t emb_length_shape[1] = {1};
+            std::vector<int32_t> emb_length;
+            emb_length.emplace_back(list_frame.size());
+            Ort::Value onnx_emb_len = Ort::Value::CreateTensor<int32_t>(
+                m_memoryInfo, emb_length.data(), emb_length.size(), emb_length_shape, 1);
+            decoder_onnx.insert(decoder_onnx.begin()+3, std::move(onnx_emb_len));
+
+            auto decoder_tensor = decoder_session_->Run(Ort::RunOptions{nullptr}, de_szInputNames_.data(), decoder_onnx.data(), decoder_onnx.size(), de_szOutputNames_.data(), de_szOutputNames_.size());
+            // fsmn cache: keep decoder outputs [2, 2+fsmn_layers) as the
+            // layer caches for the next chunk.
+            try{
+                decoder_onnx.clear();
+            }catch (std::exception const &e)
+            {
+                LOG(ERROR)<<e.what();
+                return result;
+            }
+            for(int l=0;l<fsmn_layers;l++){
+                decoder_onnx.emplace_back(std::move(decoder_tensor[2+l]));
+            }
+
+            // Decoder output 0: logits; shape[2] is presumably the vocabulary
+            // size passed to GreedySearch — confirm against the exported model.
+            std::vector<int64_t> decoder_shape = decoder_tensor[0].GetTensorTypeAndShapeInfo().GetShape();
+            float* float_data = decoder_tensor[0].GetTensorMutableData<float>();
+            result = para_handle_->GreedySearch(float_data, list_frame.size(), decoder_shape[2]);
+        }
+    }catch (std::exception const &e)
+    {
+        LOG(ERROR)<<e.what();
+        return result;
+    }
+    return result;
+}
+
+string ParaformerOnline::Forward(float* din, int len, bool input_finished)
+{
+    // Streaming entry point: feed `len` samples, return the text decoded for
+    // this chunk. When input_finished is true all caches are flushed and the
+    // stream state is reset for the next utterance.
+    std::vector<std::vector<float>> wav_feats;
+    std::vector<float> waves(din, din+len);
+
+    string result="";
+    try{
+        // Final call with fewer than 16*60 samples (60 ms at 16 kHz —
+        // presumably too short to produce new frames; confirm): just flush
+        // the cached context features as the last chunk.
+        if(len <16*60 && input_finished && !is_first_chunk){
+            is_last_chunk = true;
+            wav_feats = feats_cache_;
+            result = ForwardChunk(wav_feats, is_last_chunk);
+            // reset
+            ResetCache();
+            Reset();
+            return result;
+        }
+        if(is_first_chunk){
+            is_first_chunk = false;
+        }
+        ExtractFeats(MODEL_SAMPLE_RATE, wav_feats, waves, input_finished);
+        if(wav_feats.size() == 0){
+            return result;
+        }
+
+        // Scale features by sqrt_factor before adding position embeddings
+        // (presumably the transformer sqrt(d_model) input scaling — confirm).
+        for (auto& row : wav_feats) {
+            for (auto& val : row) {
+                val *= sqrt_factor;
+            }
+        }
+
+        GetPosEmb(wav_feats, wav_feats.size(), wav_feats[0].size());
+        if(input_finished){
+            if(wav_feats.size()+chunk_size[2] <= chunk_size[1]){
+                // Everything fits into a single final chunk.
+                is_last_chunk = true;
+                AddOverlapChunk(wav_feats, input_finished);
+            }else{
+                // Too many frames for one final chunk: run a normal chunk
+                // first with all frames...
+                std::vector<std::vector<float>> first_chunk;
+                first_chunk.insert(first_chunk.begin(), wav_feats.begin(), wav_feats.end());
+                AddOverlapChunk(first_chunk, input_finished);
+                string str_first_chunk = ForwardChunk(first_chunk, is_last_chunk);
+
+                // ...then flush the overflow tail as the last chunk.
+                is_last_chunk = true;
+                std::vector<std::vector<float>> last_chunk;
+                last_chunk.insert(last_chunk.begin(), wav_feats.end()-(wav_feats.size()+chunk_size[2]-chunk_size[1]), wav_feats.end());
+                AddOverlapChunk(last_chunk, input_finished);
+                string str_last_chunk = ForwardChunk(last_chunk, is_last_chunk);
+
+                result = str_first_chunk+str_last_chunk;
+                // reset
+                ResetCache();
+                Reset();
+                return result;
+            }
+        }else{
+            AddOverlapChunk(wav_feats, input_finished);
+        }
+
+        result = ForwardChunk(wav_feats, is_last_chunk);
+        if(input_finished){
+            // reset
+            ResetCache();
+            Reset();
+        }
+    }catch (std::exception const &e)
+    {
+        LOG(ERROR)<<e.what();
+        return result;
+    }
+
+    return result;
+}
+
+ParaformerOnline::~ParaformerOnline()
+{
+    // Nothing to release here: the ONNX sessions are shared_ptr copies and
+    // para_handle_ is a non-owning raw pointer (not deleted).
+}
+
+string ParaformerOnline::Rescoring()
+{
+    // Second-pass rescoring is not implemented for the online model;
+    // always returns an empty string.
+    LOG(ERROR)<<"Not Imp!!!!!!";
+    return "";
+}
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/paraformer-online.h b/funasr/runtime/onnxruntime/src/paraformer-online.h
new file mode 100644
index 0000000..d0265f2
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/paraformer-online.h
@@ -0,0 +1,111 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights Reserved.
+ * MIT License (https://opensource.org/licenses/MIT)
+*/
+#pragma once
+
+#include "precomp.h"
+
+namespace funasr {
+
+ class ParaformerOnline : public Model {
+ /**
+ * Author: Speech Lab of DAMO Academy, Alibaba Group
+ * ParaformerOnline: Fast and Accurate Parallel Transformer for Non-autoregressive End-to-End Speech Recognition
+ * https://arxiv.org/pdf/2206.08317.pdf
+ */
+ private:
+
+ void FbankKaldi(float sample_rate, std::vector<std::vector<float>> &wav_feats,
+ std::vector<float> &waves);
+ int OnlineLfrCmvn(vector<vector<float>> &wav_feats, bool input_finished);
+ void GetPosEmb(std::vector<std::vector<float>> &wav_feats, int timesteps, int feat_dim);
+ void CifSearch(std::vector<std::vector<float>> hidden, std::vector<float> alphas, bool is_final, std::vector<std::vector<float>> &list_frame);
+
+ static int ComputeFrameNum(int sample_length, int frame_sample_length, int frame_shift_sample_length) {
+ int frame_num = static_cast<int>((sample_length - frame_sample_length) / frame_shift_sample_length + 1);
+ if (frame_num >= 1 && sample_length >= frame_sample_length)
+ return frame_num;
+ else
+ return 0;
+ }
+ void InitOnline(
+ knf::FbankOptions &fbank_opts,
+ std::shared_ptr<Ort::Session> &encoder_session,
+ std::shared_ptr<Ort::Session> &decoder_session,
+ vector<const char*> &en_szInputNames,
+ vector<const char*> &en_szOutputNames,
+ vector<const char*> &de_szInputNames,
+ vector<const char*> &de_szOutputNames,
+ vector<float> &means_list,
+ vector<float> &vars_list);
+
+ Paraformer* para_handle_ = nullptr;
+ // from para_handle_
+ knf::FbankOptions fbank_opts_;
+ std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
+ std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
+ Ort::SessionOptions session_options_;
+ vector<const char*> en_szInputNames_;
+ vector<const char*> en_szOutputNames_;
+ vector<const char*> de_szInputNames_;
+ vector<const char*> de_szOutputNames_;
+ vector<float> means_list_;
+ vector<float> vars_list_;
+ // configs from para_handle_
+ int frame_length = 25;
+ int frame_shift = 10;
+ int n_mels = 80;
+ int lfr_m = PARA_LFR_M;
+ int lfr_n = PARA_LFR_N;
+ int encoder_size = 512;
+ int fsmn_layers = 16;
+ int fsmn_lorder = 10;
+ int fsmn_dims = 512;
+ float cif_threshold = 1.0;
+ float tail_alphas = 0.45;
+
+ // configs
+ int feat_dims = lfr_m*n_mels;
+ std::vector<int> chunk_size = {5,10,5};
+ int frame_sample_length_ = MODEL_SAMPLE_RATE / 1000 * frame_length;
+ int frame_shift_sample_length_ = MODEL_SAMPLE_RATE / 1000 * frame_shift;
+
+ // The reserved waveforms by fbank
+ std::vector<float> reserve_waveforms_;
+ // waveforms reserved after last shift position
+ std::vector<float> input_cache_;
+ // lfr reserved cache
+ std::vector<std::vector<float>> lfr_splice_cache_;
+ // position index cache
+ int start_idx_cache_ = 0;
+ // cif alpha
+ std::vector<float> alphas_cache_;
+ std::vector<std::vector<float>> hidden_cache_;
+ std::vector<std::vector<float>> feats_cache_;
+ // fsmn init caches
+ std::vector<float> fsmn_init_cache_;
+ std::vector<Ort::Value> decoder_onnx;
+
+ bool is_first_chunk = true;
+ bool is_last_chunk = false;
+ double sqrt_factor;
+
+ public:
+ ParaformerOnline(Paraformer* para_handle, std::vector<int> chunk_size);
+ ~ParaformerOnline();
+ void Reset();
+ void ResetCache();
+ void InitCache();
+ void ExtractFeats(float sample_rate, vector<vector<float>> &wav_feats, vector<float> &waves, bool input_finished);
+ void AddOverlapChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
+
+ string ForwardChunk(std::vector<std::vector<float>> &wav_feats, bool input_finished);
+ string Forward(float* din, int len, bool input_finished);
+ string Rescoring();
+ // 2pass
+ std::string online_res;
+ int chunk_len;
+ };
+
+} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/paraformer.cpp b/funasr/runtime/onnxruntime/src/paraformer.cpp
index b605fff..ef2a182 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.cpp
+++ b/funasr/runtime/onnxruntime/src/paraformer.cpp
@@ -10,29 +10,30 @@
namespace funasr {
Paraformer::Paraformer()
-:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options{}{
+:env_(ORT_LOGGING_LEVEL_ERROR, "paraformer"),session_options_{}{
}
+// offline
void Paraformer::InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
// knf options
- fbank_opts.frame_opts.dither = 0;
- fbank_opts.mel_opts.num_bins = 80;
- fbank_opts.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
- fbank_opts.frame_opts.window_type = "hamming";
- fbank_opts.frame_opts.frame_shift_ms = 10;
- fbank_opts.frame_opts.frame_length_ms = 25;
- fbank_opts.energy_floor = 0;
- fbank_opts.mel_opts.debug_mel = false;
+ fbank_opts_.frame_opts.dither = 0;
+ fbank_opts_.mel_opts.num_bins = n_mels;
+ fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
+ fbank_opts_.frame_opts.window_type = window_type;
+ fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
+ fbank_opts_.frame_opts.frame_length_ms = frame_length;
+ fbank_opts_.energy_floor = 0;
+ fbank_opts_.mel_opts.debug_mel = false;
// fbank_ = std::make_unique<knf::OnlineFbank>(fbank_opts);
- // session_options.SetInterOpNumThreads(1);
- session_options.SetIntraOpNumThreads(thread_num);
- session_options.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
+ // session_options_.SetInterOpNumThreads(1);
+ session_options_.SetIntraOpNumThreads(thread_num);
+ session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
// DisableCpuMemArena can improve performance
- session_options.DisableCpuMemArena();
+ session_options_.DisableCpuMemArena();
try {
- m_session = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options);
+ m_session_ = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options_);
LOG(INFO) << "Successfully load model from " << am_model;
} catch (std::exception const &e) {
LOG(ERROR) << "Error when load am onnx model: " << e.what();
@@ -40,14 +41,14 @@
}
string strName;
- GetInputName(m_session.get(), strName);
+ GetInputName(m_session_.get(), strName);
m_strInputNames.push_back(strName.c_str());
- GetInputName(m_session.get(), strName,1);
+ GetInputName(m_session_.get(), strName,1);
m_strInputNames.push_back(strName);
- GetOutputName(m_session.get(), strName);
+ GetOutputName(m_session_.get(), strName);
m_strOutputNames.push_back(strName);
- GetOutputName(m_session.get(), strName,1);
+ GetOutputName(m_session_.get(), strName,1);
m_strOutputNames.push_back(strName);
for (auto& item : m_strInputNames)
@@ -56,6 +57,152 @@
m_szOutputNames.push_back(item.c_str());
vocab = new Vocab(am_config.c_str());
LoadCmvn(am_cmvn.c_str());
+}
+
+// online
+void Paraformer::InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+    // Streaming init: read hyper-parameters from the model config, set up
+    // fbank options, load the encoder and decoder ONNX sessions, and collect
+    // their input/output tensor names.
+    LoadOnlineConfigFromYaml(am_config.c_str());
+    // knf options
+    fbank_opts_.frame_opts.dither = 0;
+    fbank_opts_.mel_opts.num_bins = n_mels;
+    fbank_opts_.frame_opts.samp_freq = MODEL_SAMPLE_RATE;
+    fbank_opts_.frame_opts.window_type = window_type;
+    fbank_opts_.frame_opts.frame_shift_ms = frame_shift;
+    fbank_opts_.frame_opts.frame_length_ms = frame_length;
+    fbank_opts_.energy_floor = 0;
+    fbank_opts_.mel_opts.debug_mel = false;
+
+    // session_options_.SetInterOpNumThreads(1);
+    session_options_.SetIntraOpNumThreads(thread_num);
+    session_options_.SetGraphOptimizationLevel(ORT_ENABLE_ALL);
+    // DisableCpuMemArena can improve performance
+    session_options_.DisableCpuMemArena();
+
+    try {
+        encoder_session_ = std::make_unique<Ort::Session>(env_, en_model.c_str(), session_options_);
+        LOG(INFO) << "Successfully load model from " << en_model;
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load am encoder model: " << e.what();
+        exit(0);
+    }
+
+    try {
+        decoder_session_ = std::make_unique<Ort::Session>(env_, de_model.c_str(), session_options_);
+        LOG(INFO) << "Successfully load model from " << de_model;
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load am decoder model: " << e.what();
+        exit(0);
+    }
+
+    // encoder: 2 inputs, 3 outputs (output 2 is read as CIF alphas in
+    // ParaformerOnline::ForwardChunk).
+    string strName;
+    GetInputName(encoder_session_.get(), strName);
+    en_strInputNames.push_back(strName.c_str());
+    GetInputName(encoder_session_.get(), strName,1);
+    en_strInputNames.push_back(strName);
+
+    GetOutputName(encoder_session_.get(), strName);
+    en_strOutputNames.push_back(strName);
+    GetOutputName(encoder_session_.get(), strName,1);
+    en_strOutputNames.push_back(strName);
+    GetOutputName(encoder_session_.get(), strName,2);
+    en_strOutputNames.push_back(strName);
+
+    // Raw name pointers are taken only after all pushes, so the string
+    // vectors no longer reallocate.
+    for (auto& item : en_strInputNames)
+        en_szInputNames_.push_back(item.c_str());
+    for (auto& item : en_strOutputNames)
+        en_szOutputNames_.push_back(item.c_str());
+
+    // decoder: 4 fixed tensors plus one FSMN cache per layer on each side.
+    int de_input_len = 4 + fsmn_layers;
+    int de_out_len = 2 + fsmn_layers;
+    for(int i=0;i<de_input_len; i++){
+        GetInputName(decoder_session_.get(), strName, i);
+        de_strInputNames.push_back(strName.c_str());
+    }
+
+    for(int i=0;i<de_out_len; i++){
+        GetOutputName(decoder_session_.get(), strName,i);
+        de_strOutputNames.push_back(strName);
+    }
+
+    for (auto& item : de_strInputNames)
+        de_szInputNames_.push_back(item.c_str());
+    for (auto& item : de_strOutputNames)
+        de_szOutputNames_.push_back(item.c_str());
+
+    vocab = new Vocab(am_config.c_str());
+    LoadCmvn(am_cmvn.c_str());
+}
+
+// 2pass
+void Paraformer::InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num){
+    // Two-pass init: the online encoder/decoder pair for streaming plus the
+    // offline model session for the second pass.
+    // online
+    InitAsr(en_model, de_model, am_cmvn, am_config, thread_num);
+
+    // offline
+    try {
+        m_session_ = std::make_unique<Ort::Session>(env_, am_model.c_str(), session_options_);
+        LOG(INFO) << "Successfully load model from " << am_model;
+    } catch (std::exception const &e) {
+        LOG(ERROR) << "Error when load am onnx model: " << e.what();
+        exit(0);
+    }
+
+    // Offline model name discovery: 2 inputs, 2 outputs.
+    string strName;
+    GetInputName(m_session_.get(), strName);
+    m_strInputNames.push_back(strName.c_str());
+    GetInputName(m_session_.get(), strName,1);
+    m_strInputNames.push_back(strName);
+
+    GetOutputName(m_session_.get(), strName);
+    m_strOutputNames.push_back(strName);
+    GetOutputName(m_session_.get(), strName,1);
+    m_strOutputNames.push_back(strName);
+
+    for (auto& item : m_strInputNames)
+        m_szInputNames.push_back(item.c_str());
+    for (auto& item : m_strOutputNames)
+        m_szOutputNames.push_back(item.c_str());
+}
+
+void Paraformer::LoadOnlineConfigFromYaml(const char* filename){
+    // Read the streaming (online) AM hyper-parameters from the model's
+    // config.yaml: frontend/fbank settings, encoder/decoder sizes and the
+    // CIF predictor thresholds. Exits on any missing file or key.
+    YAML::Node config;
+    try{
+        config = YAML::LoadFile(filename);
+    }catch(exception const &e){
+        LOG(ERROR) << "Error loading file, yaml file error or not exist.";
+        exit(-1);
+    }
+
+    try{
+        YAML::Node frontend_conf = config["frontend_conf"];
+        YAML::Node encoder_conf = config["encoder_conf"];
+        YAML::Node decoder_conf = config["decoder_conf"];
+        YAML::Node predictor_conf = config["predictor_conf"];
+
+        this->window_type = frontend_conf["window"].as<string>();
+        this->n_mels = frontend_conf["n_mels"].as<int>();
+        this->frame_length = frontend_conf["frame_length"].as<int>();
+        this->frame_shift = frontend_conf["frame_shift"].as<int>();
+        this->lfr_m = frontend_conf["lfr_m"].as<int>();
+        this->lfr_n = frontend_conf["lfr_n"].as<int>();
+
+        this->encoder_size = encoder_conf["output_size"].as<int>();
+        this->fsmn_dims = encoder_conf["output_size"].as<int>();
+
+        // FSMN look-back order is kernel_size - 1.
+        this->fsmn_layers = decoder_conf["num_blocks"].as<int>();
+        this->fsmn_lorder = decoder_conf["kernel_size"].as<int>()-1;
+
+        this->cif_threshold = predictor_conf["threshold"].as<double>();
+        this->tail_alphas = predictor_conf["tail_threshold"].as<double>();
+
+    }catch(exception const &e){
+        // Fix: this function loads the AM config, not the VAD config; the
+        // original message blamed the wrong file.
+        LOG(ERROR) << "Error when load argument from am config YAML.";
+        exit(-1);
+    }
}
Paraformer::~Paraformer()
@@ -69,7 +216,7 @@
}
vector<float> Paraformer::FbankKaldi(float sample_rate, const float* waves, int len) {
- knf::OnlineFbank fbank_(fbank_opts);
+ knf::OnlineFbank fbank_(fbank_opts_);
std::vector<float> buf(len);
for (int32_t i = 0; i != len; ++i) {
buf[i] = waves[i] * 32768;
@@ -77,7 +224,7 @@
fbank_.AcceptWaveform(sample_rate, buf.data(), buf.size());
//fbank_->InputFinished();
int32_t frames = fbank_.NumFramesReady();
- int32_t feature_dim = fbank_opts.mel_opts.num_bins;
+ int32_t feature_dim = fbank_opts_.mel_opts.num_bins;
vector<float> features(frames * feature_dim);
float *p = features.data();
@@ -108,7 +255,7 @@
vector<string> means_lines{istream_iterator<string>{means_lines_stream}, istream_iterator<string>{}};
if (means_lines[0] == "<LearnRateCoef>") {
for (int j = 3; j < means_lines.size() - 1; j++) {
- means_list.push_back(stof(means_lines[j]));
+ means_list_.push_back(stof(means_lines[j]));
}
continue;
}
@@ -119,7 +266,7 @@
vector<string> vars_lines{istream_iterator<string>{vars_lines_stream}, istream_iterator<string>{}};
if (vars_lines[0] == "<LearnRateCoef>") {
for (int j = 3; j < vars_lines.size() - 1; j++) {
- vars_list.push_back(stof(vars_lines[j])*scale);
+ vars_list_.push_back(stof(vars_lines[j])*scale);
}
continue;
}
@@ -143,11 +290,11 @@
vector<float> Paraformer::ApplyLfr(const std::vector<float> &in)
{
- int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
+ int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
int32_t in_num_frames = in.size() / in_feat_dim;
int32_t out_num_frames =
- (in_num_frames - lfr_window_size) / lfr_window_shift + 1;
- int32_t out_feat_dim = in_feat_dim * lfr_window_size;
+ (in_num_frames - lfr_m) / lfr_n + 1;
+ int32_t out_feat_dim = in_feat_dim * lfr_m;
std::vector<float> out(out_num_frames * out_feat_dim);
@@ -158,7 +305,7 @@
std::copy(p_in, p_in + out_feat_dim, p_out);
p_out += out_feat_dim;
- p_in += lfr_window_shift * in_feat_dim;
+ p_in += lfr_n * in_feat_dim;
}
return out;
@@ -166,29 +313,29 @@
void Paraformer::ApplyCmvn(std::vector<float> *v)
{
- int32_t dim = means_list.size();
+ int32_t dim = means_list_.size();
int32_t num_frames = v->size() / dim;
float *p = v->data();
for (int32_t i = 0; i != num_frames; ++i) {
for (int32_t k = 0; k != dim; ++k) {
- p[k] = (p[k] + means_list[k]) * vars_list[k];
+ p[k] = (p[k] + means_list_[k]) * vars_list_[k];
}
p += dim;
}
}
-string Paraformer::Forward(float* din, int len, int flag)
+string Paraformer::Forward(float* din, int len, bool input_finished)
{
- int32_t in_feat_dim = fbank_opts.mel_opts.num_bins;
+ int32_t in_feat_dim = fbank_opts_.mel_opts.num_bins;
std::vector<float> wav_feats = FbankKaldi(MODEL_SAMPLE_RATE, din, len);
wav_feats = ApplyLfr(wav_feats);
ApplyCmvn(&wav_feats);
- int32_t feat_dim = lfr_window_size*in_feat_dim;
+ int32_t feat_dim = lfr_m*in_feat_dim;
int32_t num_frames = wav_feats.size() / feat_dim;
#ifdef _WIN_X86
@@ -216,7 +363,7 @@
string result;
try {
- auto outputTensor = m_session->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
+ auto outputTensor = m_session_->Run(Ort::RunOptions{nullptr}, m_szInputNames.data(), input_onnx.data(), input_onnx.size(), m_szOutputNames.data(), m_szOutputNames.size());
std::vector<int64_t> outputShape = outputTensor[0].GetTensorTypeAndShapeInfo().GetShape();
int64_t outputCount = std::accumulate(outputShape.begin(), outputShape.end(), 1, std::multiplies<int64_t>());
@@ -230,13 +377,6 @@
}
return result;
-}
-
-string Paraformer::ForwardChunk(float* din, int len, int flag)
-{
-
- LOG(ERROR)<<"Not Imp!!!!!!";
- return "";
}
string Paraformer::Rescoring()
diff --git a/funasr/runtime/onnxruntime/src/paraformer.h b/funasr/runtime/onnxruntime/src/paraformer.h
index 9df0977..16460bf 100644
--- a/funasr/runtime/onnxruntime/src/paraformer.h
+++ b/funasr/runtime/onnxruntime/src/paraformer.h
@@ -15,38 +15,66 @@
* https://arxiv.org/pdf/2206.08317.pdf
*/
private:
- //std::unique_ptr<knf::OnlineFbank> fbank_;
- knf::FbankOptions fbank_opts;
-
Vocab* vocab = nullptr;
- vector<float> means_list;
- vector<float> vars_list;
- const float scale = 22.6274169979695;
- int32_t lfr_window_size = 7;
- int32_t lfr_window_shift = 6;
+ //const float scale = 22.6274169979695;
+ const float scale = 1.0;
+ void LoadOnlineConfigFromYaml(const char* filename);
void LoadCmvn(const char *filename);
vector<float> ApplyLfr(const vector<float> &in);
void ApplyCmvn(vector<float> *v);
- string GreedySearch( float* in, int n_len, int64_t token_nums);
-
- std::shared_ptr<Ort::Session> m_session = nullptr;
- Ort::Env env_;
- Ort::SessionOptions session_options;
-
- vector<string> m_strInputNames, m_strOutputNames;
- vector<const char*> m_szInputNames;
- vector<const char*> m_szOutputNames;
public:
Paraformer();
~Paraformer();
void InitAsr(const std::string &am_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+ // online
+ void InitAsr(const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
+ // 2pass
+ void InitAsr(const std::string &am_model, const std::string &en_model, const std::string &de_model, const std::string &am_cmvn, const std::string &am_config, int thread_num);
void Reset();
vector<float> FbankKaldi(float sample_rate, const float* waves, int len);
- string ForwardChunk(float* din, int len, int flag);
- string Forward(float* din, int len, int flag);
+ string Forward(float* din, int len, bool input_finished=true);
+ string GreedySearch( float* in, int n_len, int64_t token_nums);
string Rescoring();
+
+ knf::FbankOptions fbank_opts_;
+ vector<float> means_list_;
+ vector<float> vars_list_;
+ int lfr_m = PARA_LFR_M;
+ int lfr_n = PARA_LFR_N;
+
+ // paraformer-offline
+ std::shared_ptr<Ort::Session> m_session_ = nullptr;
+ Ort::Env env_;
+ Ort::SessionOptions session_options_;
+
+ vector<string> m_strInputNames, m_strOutputNames;
+ vector<const char*> m_szInputNames;
+ vector<const char*> m_szOutputNames;
+
+ // paraformer-online
+ std::shared_ptr<Ort::Session> encoder_session_ = nullptr;
+ std::shared_ptr<Ort::Session> decoder_session_ = nullptr;
+ vector<string> en_strInputNames, en_strOutputNames;
+ vector<const char*> en_szInputNames_;
+ vector<const char*> en_szOutputNames_;
+ vector<string> de_strInputNames, de_strOutputNames;
+ vector<const char*> de_szInputNames_;
+ vector<const char*> de_szOutputNames_;
+
+ string window_type = "hamming";
+ int frame_length = 25;
+ int frame_shift = 10;
+ int n_mels = 80;
+ int encoder_size = 512;
+ int fsmn_layers = 16;
+ int fsmn_lorder = 10;
+ int fsmn_dims = 512;
+ float cif_threshold = 1.0;
+ float tail_alphas = 0.45;
+
+
};
} // namespace funasr
diff --git a/funasr/runtime/onnxruntime/src/precomp.h b/funasr/runtime/onnxruntime/src/precomp.h
index 26ed2c5..5298b4b 100644
--- a/funasr/runtime/onnxruntime/src/precomp.h
+++ b/funasr/runtime/onnxruntime/src/precomp.h
@@ -33,18 +33,20 @@
#include "model.h"
#include "vad-model.h"
#include "punc-model.h"
-#include "offline-stream.h"
#include "tokenizer.h"
#include "ct-transformer.h"
#include "ct-transformer-online.h"
#include "e2e-vad.h"
#include "fsmn-vad.h"
-#include "fsmn-vad-online.h"
#include "vocab.h"
#include "audio.h"
+#include "fsmn-vad-online.h"
#include "tensor.h"
#include "util.h"
#include "resample.h"
#include "paraformer.h"
+#include "paraformer-online.h"
#include "offline-stream.h"
+#include "tpass-stream.h"
+#include "tpass-online-stream.h"
#include "funasrruntime.h"
diff --git a/funasr/runtime/onnxruntime/src/tpass-online-stream.cpp b/funasr/runtime/onnxruntime/src/tpass-online-stream.cpp
new file mode 100644
index 0000000..d99c871
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/tpass-online-stream.cpp
@@ -0,0 +1,29 @@
+#include "precomp.h"
+#include <unistd.h>
+
+namespace funasr {
+TpassOnlineStream::TpassOnlineStream(TpassStream* tpass_stream, std::vector<int> chunk_size){
+    // Build the per-connection online VAD and ASR wrappers around the shared
+    // offline handles owned by the TpassStream. Both handles are mandatory.
+    TpassStream* tpass_obj = (TpassStream*)tpass_stream;
+    if(tpass_obj->vad_handle){
+        vad_online_handle = make_unique<FsmnVadOnline>((FsmnVad*)(tpass_obj->vad_handle).get());
+    }else{
+        // Fix: this branch guards the VAD handle; the original log message
+        // wrongly said "asr_handle is null".
+        LOG(ERROR)<<"vad_handle is null";
+        exit(-1);
+    }
+
+    if(tpass_obj->asr_handle){
+        asr_online_handle = make_unique<ParaformerOnline>((Paraformer*)(tpass_obj->asr_handle).get(), chunk_size);
+    }else{
+        LOG(ERROR)<<"asr_handle is null";
+        exit(-1);
+    }
+}
+
+TpassOnlineStream* CreateTpassOnlineStream(void* tpass_stream, std::vector<int> chunk_size)
+{
+    // Factory helper: wrap an existing two-pass stream into a heap-allocated
+    // per-connection online stream; the caller takes ownership.
+    return new TpassOnlineStream((TpassStream*)tpass_stream, chunk_size);
+}
+
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/onnxruntime/src/tpass-stream.cpp b/funasr/runtime/onnxruntime/src/tpass-stream.cpp
new file mode 100644
index 0000000..9377286
--- /dev/null
+++ b/funasr/runtime/onnxruntime/src/tpass-stream.cpp
@@ -0,0 +1,87 @@
+#include "precomp.h"
+#include <unistd.h>
+
+namespace funasr {
+TpassStream::TpassStream(std::map<std::string, std::string>& model_path, int thread_num) // loads VAD, 2-pass ASR and PUNC models from the configured directories
+{
+    // VAD model
+    if(model_path.find(VAD_DIR) != model_path.end()){
+        string vad_model_path;
+        string vad_cmvn_path;
+        string vad_config_path;
+
+        vad_model_path = PathAppend(model_path.at(VAD_DIR), MODEL_NAME);
+        if(model_path.find(VAD_QUANT) != model_path.end() && model_path.at(VAD_QUANT) == "true"){ // quantized VAD model requested
+            vad_model_path = PathAppend(model_path.at(VAD_DIR), QUANT_MODEL_NAME);
+        }
+        vad_cmvn_path = PathAppend(model_path.at(VAD_DIR), VAD_CMVN_NAME);
+        vad_config_path = PathAppend(model_path.at(VAD_DIR), VAD_CONFIG_NAME);
+        if (access(vad_model_path.c_str(), F_OK) != 0 ||
+            access(vad_cmvn_path.c_str(), F_OK) != 0 ||
+            access(vad_config_path.c_str(), F_OK) != 0 )
+        {
+            LOG(INFO) << "VAD model file is not exist, skip load vad model."; // VAD is optional: missing files are tolerated
+        }else{
+            vad_handle = make_unique<FsmnVad>();
+            vad_handle->InitVad(vad_model_path, vad_cmvn_path, vad_config_path, thread_num);
+            use_vad = true;
+        }
+    }
+
+    // AM model
+    if(model_path.find(OFFLINE_MODEL_DIR) != model_path.end() && model_path.find(ONLINE_MODEL_DIR) != model_path.end()){ // 2-pass decoding needs both offline and online ASR dirs
+        // 2pass
+        string am_model_path;
+        string en_model_path;
+        string de_model_path;
+        string am_cmvn_path;
+        string am_config_path;
+
+        am_model_path = PathAppend(model_path.at(OFFLINE_MODEL_DIR), MODEL_NAME); // offline (second-pass) model
+        en_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), ENCODER_NAME); // online encoder
+        de_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), DECODER_NAME); // online decoder
+        if(model_path.find(QUANTIZE) != model_path.end() && model_path.at(QUANTIZE) == "true"){ // one QUANTIZE flag covers both offline and online ASR models
+            am_model_path = PathAppend(model_path.at(OFFLINE_MODEL_DIR), QUANT_MODEL_NAME);
+            en_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), QUANT_ENCODER_NAME);
+            de_model_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), QUANT_DECODER_NAME);
+        }
+        am_cmvn_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CMVN_NAME);
+        am_config_path = PathAppend(model_path.at(ONLINE_MODEL_DIR), AM_CONFIG_NAME);
+
+        asr_handle = make_unique<Paraformer>();
+        asr_handle->InitAsr(am_model_path, en_model_path, de_model_path, am_cmvn_path, am_config_path, thread_num); // NOTE(review): unlike VAD/PUNC, no access() existence check here -- confirm InitAsr fails cleanly on missing files
+    }else{
+        LOG(ERROR) <<"Can not find offline-model-dir or online-model-dir"; // ASR is mandatory; abort if either dir is missing
+        exit(-1);
+    }
+
+    // PUNC model
+    if(model_path.find(PUNC_DIR) != model_path.end()){
+        string punc_model_path;
+        string punc_config_path;
+
+        punc_model_path = PathAppend(model_path.at(PUNC_DIR), MODEL_NAME);
+        if(model_path.find(PUNC_QUANT) != model_path.end() && model_path.at(PUNC_QUANT) == "true"){ // quantized PUNC model requested
+            punc_model_path = PathAppend(model_path.at(PUNC_DIR), QUANT_MODEL_NAME);
+        }
+        punc_config_path = PathAppend(model_path.at(PUNC_DIR), PUNC_CONFIG_NAME);
+
+        if (access(punc_model_path.c_str(), F_OK) != 0 ||
+            access(punc_config_path.c_str(), F_OK) != 0 )
+        {
+            LOG(INFO) << "PUNC model file is not exist, skip load punc model."; // PUNC is optional: missing files are tolerated
+        }else{
+            punc_online_handle = make_unique<CTTransformerOnline>();
+            punc_online_handle->InitPunc(punc_model_path, punc_config_path, thread_num);
+            use_punc = true;
+        }
+    }
+}
+
+TpassStream *CreateTpassStream(std::map<std::string, std::string>& model_path, int thread_num)
+{
+    // Factory helper: heap-allocates a TpassStream; caller takes ownership.
+    TpassStream* stream = new TpassStream(model_path, thread_num);
+    return stream;
+}
+} // namespace funasr
\ No newline at end of file
diff --git a/funasr/runtime/python/grpc/Readme.md b/funasr/runtime/python/grpc/Readme.md
index 832b87e..13723f2 100644
--- a/funasr/runtime/python/grpc/Readme.md
+++ b/funasr/runtime/python/grpc/Readme.md
@@ -1,73 +1,27 @@
-# Service with grpc-python
-We can send streaming audio data to server in real-time with grpc client every 10 ms e.g., and get transcribed text when stop speaking.
-The audio data is in streaming, the asr inference process is in offline.
+# GRPC python Client for 2pass decoding
+The client can send either streaming or full audio data to the server, and receives transcribed text whenever the server responds (depending on the decoding mode).
-## For the Server
+In the demo client, audio_chunk_duration is set to 1000 ms and send_interval to 100 ms.
-### Prepare server environment
-Install the modelscope and funasr
-
+### 1. Install the requirements
```shell
-pip install -U modelscope funasr
-# For the users in China, you could install with the command:
-# pip install -U modelscope funasr -i https://mirror.sjtu.edu.cn/pypi/web/simple
-git clone https://github.com/alibaba/FunASR.git && cd FunASR
+git clone https://github.com/alibaba/FunASR.git && cd FunASR/funasr/runtime/python/grpc
+pip install -r requirements.txt
```
-Install the requirements
-
-```shell
-cd funasr/runtime/python/grpc
-pip install -r requirements_server.txt
-```
-
-
-### Generate protobuf file
-Run on server, the two generated pb files are both used for server and client
-
+### 2. Generate protobuf file
```shell
# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
# regenerate it only when you make changes to ./proto/paraformer.proto file.
-python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
+python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
```
-### Start grpc server
-
-```
-# Start server.
-python grpc_main_server.py --port 10095 --backend pipeline
-```
-
-
-## For the client
-
-### Install the requirements
-
-```shell
-git clone https://github.com/alibaba/FunASR.git && cd FunASR
-cd funasr/runtime/python/grpc
-pip install -r requirements_client.txt
-```
-
-### Generate protobuf file
-Run on server, the two generated pb files are both used for server and client
-
-```shell
-# paraformer_pb2.py and paraformer_pb2_grpc.py are already generated,
-# regenerate it only when you make changes to ./proto/paraformer.proto file.
-python -m grpc_tools.protoc --proto_path=./proto -I ./proto --python_out=. --grpc_python_out=./ ./proto/paraformer.proto
-```
-
-### Start grpc client
+### 3. Start grpc client
```
# Start client.
-python grpc_main_client_mic.py --host 127.0.0.1 --port 10095
+python grpc_main_client.py --host 127.0.0.1 --port 10100 --wav_path /path/to/your_test_wav.wav
```
-
-## Workflow in desgin
-
-<div align="left"><img src="proto/workflow.png" width="400"/>
-
## Acknowledge
-1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
\ No newline at end of file
+1. This project is maintained by [FunASR community](https://github.com/alibaba-damo-academy/FunASR).
+2. We acknowledge burkliu (刘柏基, liubaiji@xverse.cn) for contributing the grpc service.
diff --git a/funasr/runtime/python/grpc/grpc_client.py b/funasr/runtime/python/grpc/grpc_client.py
deleted file mode 100644
index 8f0bcd9..0000000
--- a/funasr/runtime/python/grpc/grpc_client.py
+++ /dev/null
@@ -1,17 +0,0 @@
-import queue
-import paraformer_pb2
-
-def transcribe_audio_bytes(stub, chunk, user='zksz', language='zh-CN', speaking = True, isEnd = False):
- req = paraformer_pb2.Request()
- if chunk is not None:
- req.audio_data = chunk
- req.user = user
- req.language = language
- req.speaking = speaking
- req.isEnd = isEnd
- my_queue = queue.SimpleQueue()
- my_queue.put(req)
- return stub.Recognize(iter(my_queue.get, None))
-
-
-
diff --git a/funasr/runtime/python/grpc/grpc_main_client.py b/funasr/runtime/python/grpc/grpc_main_client.py
index b6491df..92888bd 100644
--- a/funasr/runtime/python/grpc/grpc_main_client.py
+++ b/funasr/runtime/python/grpc/grpc_main_client.py
@@ -1,62 +1,78 @@
-import grpc
-import json
-import time
-import asyncio
-import soundfile as sf
+import logging
import argparse
+import soundfile as sf
+import time
-from grpc_client import transcribe_audio_bytes
-from paraformer_pb2_grpc import ASRStub
+import grpc
+import paraformer_pb2_grpc
+from paraformer_pb2 import Request, WavFormat, DecodeMode
-# send the audio data once
-async def grpc_rec(wav_scp, grpc_uri, asr_user, language):
- with grpc.insecure_channel(grpc_uri) as channel:
- stub = ASRStub(channel)
- for line in wav_scp:
- wav_file = line.split()[1]
- wav, _ = sf.read(wav_file, dtype='int16')
-
- b = time.time()
- response = transcribe_audio_bytes(stub, wav.tobytes(), user=asr_user, language=language, speaking=False, isEnd=False)
- resp = response.next()
- text = ''
- if 'decoding' == resp.action:
- resp = response.next()
- if 'finish' == resp.action:
- text = json.loads(resp.sentence)['text']
- response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking=False, isEnd=True)
- res= {'text': text, 'time': time.time() - b}
- print(res)
+class GrpcClient:
+ def __init__(self, wav_path, uri, mode):
+ self.wav, self.sampling_rate = sf.read(wav_path, dtype='int16')
+ self.wav_format = WavFormat.pcm
+ self.audio_chunk_duration = 1000 # ms
+ self.audio_chunk_size = int(self.sampling_rate * self.audio_chunk_duration / 1000)
+ self.send_interval = 100 # ms
+ self.mode = mode
-async def test(args):
- wav_scp = open(args.wav_scp, "r").readlines()
- uri = '{}:{}'.format(args.host, args.port)
- res = await grpc_rec(wav_scp, uri, args.user_allowed, language = 'zh-CN')
+ # connect to grpc server
+ channel = grpc.insecure_channel(uri)
+ self.stub = paraformer_pb2_grpc.ASRStub(channel)
+
+ # start request
+ for respond in self.stub.Recognize(self.request_iterator()):
+ logging.info("[receive] mode {}, text {}, is final {}".format(
+ DecodeMode.Name(respond.mode), respond.text, respond.is_final))
+
+ def request_iterator(self, mode = DecodeMode.two_pass):
+ is_first_pack = True
+ is_final = False
+ for start in range(0, len(self.wav), self.audio_chunk_size):
+ request = Request()
+ audio_chunk = self.wav[start : start + self.audio_chunk_size]
+
+ if is_first_pack:
+ is_first_pack = False
+ request.sampling_rate = self.sampling_rate
+ request.mode = self.mode
+ request.wav_format = self.wav_format
+ if request.mode == DecodeMode.two_pass or request.mode == DecodeMode.online:
+ request.chunk_size.extend([5, 10, 5])
+
+ if start + self.audio_chunk_size >= len(self.wav):
+ is_final = True
+ request.is_final = is_final
+ request.audio_data = audio_chunk.tobytes()
+ logging.info("[request] audio_data len {}, is final {}".format(
+ len(request.audio_data), request.is_final)) # int16 = 2bytes
+ time.sleep(self.send_interval / 1000)
+ yield request
if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument("--host",
- type=str,
- default="127.0.0.1",
- required=False,
- help="grpc server host ip")
- parser.add_argument("--port",
- type=int,
- default=10108,
- required=False,
- help="grpc server port")
- parser.add_argument("--user_allowed",
- type=str,
- default="project1_user1",
- help="allowed user for grpc client")
- parser.add_argument("--sample_rate",
- type=int,
- default=16000,
- help="audio sample_rate from client")
- parser.add_argument("--wav_scp",
- type=str,
- required=True,
- help="audio wav scp")
- args = parser.parse_args()
-
- asyncio.run(test(args))
+ logging.basicConfig(filename="", format="%(asctime)s %(message)s", level=logging.INFO)
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--host",
+ type=str,
+ default="127.0.0.1",
+ required=False,
+ help="grpc server host ip")
+ parser.add_argument("--port",
+ type=int,
+ default=10100,
+ required=False,
+ help="grpc server port")
+ parser.add_argument("--wav_path",
+ type=str,
+ required=True,
+ help="audio wav path")
+ args = parser.parse_args()
+
+ for mode in [DecodeMode.offline, DecodeMode.online, DecodeMode.two_pass]:
+ mode_name = DecodeMode.Name(mode)
+ logging.info("[request] start requesting with mode {}".format(mode_name))
+
+ st = time.time()
+ uri = '{}:{}'.format(args.host, args.port)
+ client = GrpcClient(args.wav_path, uri, mode)
+ logging.info("mode {}, time pass: {}".format(mode_name, time.time() - st))
diff --git a/funasr/runtime/python/grpc/grpc_main_client_mic.py b/funasr/runtime/python/grpc/grpc_main_client_mic.py
deleted file mode 100644
index acbe90b..0000000
--- a/funasr/runtime/python/grpc/grpc_main_client_mic.py
+++ /dev/null
@@ -1,112 +0,0 @@
-import pyaudio
-import grpc
-import json
-import webrtcvad
-import time
-import asyncio
-import argparse
-
-from grpc_client import transcribe_audio_bytes
-from paraformer_pb2_grpc import ASRStub
-
-async def deal_chunk(sig_mic):
- global stub,SPEAKING,asr_user,language,sample_rate
- if vad.is_speech(sig_mic, sample_rate): #speaking
- SPEAKING = True
- response = transcribe_audio_bytes(stub, sig_mic, user=asr_user, language=language, speaking = True, isEnd = False) #speaking, send audio to server.
- else: #silence
- begin_time = 0
- if SPEAKING: #means we have some audio recorded, send recognize order to server.
- SPEAKING = False
- begin_time = int(round(time.time() * 1000))
- response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking = False, isEnd = False) #speak end, call server for recognize one sentence
- resp = response.next()
- if "decoding" == resp.action:
- resp = response.next() #TODO, blocking operation may leads to miss some audio clips. C++ multi-threading is preferred.
- if "finish" == resp.action:
- end_time = int(round(time.time() * 1000))
- print (json.loads(resp.sentence))
- print ("delay in ms: %d " % (end_time - begin_time))
- else:
- pass
-
-
-async def record(host,port,sample_rate,mic_chunk,record_seconds,asr_user,language):
- with grpc.insecure_channel('{}:{}'.format(host, port)) as channel:
- global stub
- stub = ASRStub(channel)
- for i in range(0, int(sample_rate / mic_chunk * record_seconds)):
-
- sig_mic = stream.read(mic_chunk,exception_on_overflow = False)
- await asyncio.create_task(deal_chunk(sig_mic))
-
- #end grpc
- response = transcribe_audio_bytes(stub, None, user=asr_user, language=language, speaking = False, isEnd = True)
- print (response.next().action)
-
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument("--host",
- type=str,
- default="127.0.0.1",
- required=True,
- help="grpc server host ip")
-
- parser.add_argument("--port",
- type=int,
- default=10095,
- required=True,
- help="grpc server port")
-
- parser.add_argument("--user_allowed",
- type=str,
- default="project1_user1",
- help="allowed user for grpc client")
-
- parser.add_argument("--sample_rate",
- type=int,
- default=16000,
- help="audio sample_rate from client")
-
- parser.add_argument("--mic_chunk",
- type=int,
- default=160,
- help="chunk size for mic")
-
- parser.add_argument("--record_seconds",
- type=int,
- default=120,
- help="run specified seconds then exit ")
-
- args = parser.parse_args()
-
-
- SPEAKING = False
- asr_user = args.user_allowed
- sample_rate = args.sample_rate
- language = 'zh-CN'
-
-
- vad = webrtcvad.Vad()
- vad.set_mode(1)
-
- FORMAT = pyaudio.paInt16
- CHANNELS = 1
- p = pyaudio.PyAudio()
-
- stream = p.open(format=FORMAT,
- channels=CHANNELS,
- rate=args.sample_rate,
- input=True,
- frames_per_buffer=args.mic_chunk)
-
- print("* recording")
- asyncio.run(record(args.host,args.port,args.sample_rate,args.mic_chunk,args.record_seconds,args.user_allowed,language))
- stream.stop_stream()
- stream.close()
- p.terminate()
- print("recording stop")
-
-
-
diff --git a/funasr/runtime/python/grpc/grpc_main_server.py b/funasr/runtime/python/grpc/grpc_main_server.py
deleted file mode 100644
index ae386fa..0000000
--- a/funasr/runtime/python/grpc/grpc_main_server.py
+++ /dev/null
@@ -1,68 +0,0 @@
-import grpc
-from concurrent import futures
-import argparse
-
-import paraformer_pb2_grpc
-from grpc_server import ASRServicer
-
-def serve(args):
- server = grpc.server(futures.ThreadPoolExecutor(max_workers=10),
- # interceptors=(AuthInterceptor('Bearer mysecrettoken'),)
- )
- paraformer_pb2_grpc.add_ASRServicer_to_server(
- ASRServicer(args.user_allowed, args.model, args.sample_rate, args.backend, args.onnx_dir, vad_model=args.vad_model, punc_model=args.punc_model), server)
- port = "[::]:" + str(args.port)
- server.add_insecure_port(port)
- server.start()
- print("grpc server started!")
- server.wait_for_termination()
-
-if __name__ == '__main__':
- parser = argparse.ArgumentParser()
- parser.add_argument("--port",
- type=int,
- default=10095,
- required=True,
- help="grpc server port")
-
- parser.add_argument("--user_allowed",
- type=str,
- default="project1_user1|project1_user2|project2_user3",
- help="allowed user for grpc client")
-
- parser.add_argument("--model",
- type=str,
- default="damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
- help="model from modelscope")
- parser.add_argument("--vad_model",
- type=str,
- default="damo/speech_fsmn_vad_zh-cn-16k-common-pytorch",
- help="model from modelscope")
-
- parser.add_argument("--punc_model",
- type=str,
- default="",
- help="model from modelscope")
-
- parser.add_argument("--sample_rate",
- type=int,
- default=16000,
- help="audio sample_rate from client")
-
- parser.add_argument("--backend",
- type=str,
- default="pipeline",
- choices=("pipeline", "onnxruntime"),
- help="backend, optional modelscope pipeline or onnxruntime")
-
- parser.add_argument("--onnx_dir",
- type=str,
- default="/nfs/models/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch",
- help="onnx model dir")
-
-
-
-
- args = parser.parse_args()
-
- serve(args)
diff --git a/funasr/runtime/python/grpc/grpc_server.py b/funasr/runtime/python/grpc/grpc_server.py
deleted file mode 100644
index 4fd4f95..0000000
--- a/funasr/runtime/python/grpc/grpc_server.py
+++ /dev/null
@@ -1,132 +0,0 @@
-from concurrent import futures
-import grpc
-import json
-import time
-
-import paraformer_pb2_grpc
-from paraformer_pb2 import Response
-
-
-class ASRServicer(paraformer_pb2_grpc.ASRServicer):
- def __init__(self, user_allowed, model, sample_rate, backend, onnx_dir, vad_model='', punc_model=''):
- print("ASRServicer init")
- self.backend = backend
- self.init_flag = 0
- self.client_buffers = {}
- self.client_transcription = {}
- self.auth_user = user_allowed.split("|")
- if self.backend == "pipeline":
- try:
- from modelscope.pipelines import pipeline
- from modelscope.utils.constant import Tasks
- except ImportError:
- raise ImportError(f"Please install modelscope")
- self.inference_16k_pipeline = pipeline(task=Tasks.auto_speech_recognition, model=model, vad_model=vad_model, punc_model=punc_model)
- elif self.backend == "onnxruntime":
- try:
- from funasr_onnx import Paraformer
- except ImportError:
- raise ImportError(f"Please install onnxruntime environment")
- self.inference_16k_pipeline = Paraformer(model_dir=onnx_dir)
- self.sample_rate = sample_rate
-
- def clear_states(self, user):
- self.clear_buffers(user)
- self.clear_transcriptions(user)
-
- def clear_buffers(self, user):
- if user in self.client_buffers:
- del self.client_buffers[user]
-
- def clear_transcriptions(self, user):
- if user in self.client_transcription:
- del self.client_transcription[user]
-
- def disconnect(self, user):
- self.clear_states(user)
- print("Disconnecting user: %s" % str(user))
-
- def Recognize(self, request_iterator, context):
-
-
- for req in request_iterator:
- if req.user not in self.auth_user:
- result = {}
- result["success"] = False
- result["detail"] = "Not Authorized user: %s " % req.user
- result["text"] = ""
- yield Response(sentence=json.dumps(result), user=req.user, action="terminate", language=req.language)
- elif req.isEnd: #end grpc
- print("asr end")
- self.disconnect(req.user)
- result = {}
- result["success"] = True
- result["detail"] = "asr end"
- result["text"] = ""
- yield Response(sentence=json.dumps(result), user=req.user, action="terminate",language=req.language)
- elif req.speaking: #continue speaking
- if req.audio_data is not None and len(req.audio_data) > 0:
- if req.user in self.client_buffers:
- self.client_buffers[req.user] += req.audio_data #append audio
- else:
- self.client_buffers[req.user] = req.audio_data
- result = {}
- result["success"] = True
- result["detail"] = "speaking"
- result["text"] = ""
- yield Response(sentence=json.dumps(result), user=req.user, action="speaking", language=req.language)
- elif not req.speaking: #silence
- if req.user not in self.client_buffers:
- result = {}
- result["success"] = True
- result["detail"] = "waiting_for_more_voice"
- result["text"] = ""
- yield Response(sentence=json.dumps(result), user=req.user, action="waiting", language=req.language)
- else:
- begin_time = int(round(time.time() * 1000))
- tmp_data = self.client_buffers[req.user]
- self.clear_states(req.user)
- result = {}
- result["success"] = True
- result["detail"] = "decoding data: %d bytes" % len(tmp_data)
- result["text"] = ""
- yield Response(sentence=json.dumps(result), user=req.user, action="decoding", language=req.language)
- if len(tmp_data) < 9600: #min input_len for asr model , 300ms
- end_time = int(round(time.time() * 1000))
- delay_str = str(end_time - begin_time)
- result = {}
- result["success"] = True
- result["detail"] = "waiting_for_more_voice"
- result["server_delay_ms"] = delay_str
- result["text"] = ""
- print ("user: %s , delay(ms): %s, info: %s " % (req.user, delay_str, "waiting_for_more_voice"))
- yield Response(sentence=json.dumps(result), user=req.user, action="waiting", language=req.language)
- else:
- if self.backend == "pipeline":
- asr_result = self.inference_16k_pipeline(audio_in=tmp_data, audio_fs = self.sample_rate)
- if "text" in asr_result:
- asr_result = asr_result['text']
- else:
- asr_result = ""
- elif self.backend == "onnxruntime":
- from funasr_onnx.utils.frontend import load_bytes
- array = load_bytes(tmp_data)
- asr_result = self.inference_16k_pipeline(array)[0]
- end_time = int(round(time.time() * 1000))
- delay_str = str(end_time - begin_time)
- print ("user: %s , delay(ms): %s, text: %s " % (req.user, delay_str, asr_result))
- result = {}
- result["success"] = True
- result["detail"] = "finish_sentence"
- result["server_delay_ms"] = delay_str
- result["text"] = asr_result
- yield Response(sentence=json.dumps(result), user=req.user, action="finish", language=req.language)
- else:
- result = {}
- result["success"] = False
- result["detail"] = "error, no condition matched! Unknown reason."
- result["text"] = ""
- self.disconnect(req.user)
- yield Response(sentence=json.dumps(result), user=req.user, action="terminate", language=req.language)
-
-
diff --git a/funasr/runtime/python/grpc/paraformer_pb2.py b/funasr/runtime/python/grpc/paraformer_pb2.py
deleted file mode 100644
index 05e05ff..0000000
--- a/funasr/runtime/python/grpc/paraformer_pb2.py
+++ /dev/null
@@ -1,30 +0,0 @@
-# -*- coding: utf-8 -*-
-# Generated by the protocol buffer compiler. DO NOT EDIT!
-# source: paraformer.proto
-"""Generated protocol buffer code."""
-from google.protobuf.internal import builder as _builder
-from google.protobuf import descriptor as _descriptor
-from google.protobuf import descriptor_pool as _descriptor_pool
-from google.protobuf import symbol_database as _symbol_database
-# @@protoc_insertion_point(imports)
-
-_sym_db = _symbol_database.Default()
-
-
-
-
-DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x10paraformer.proto\x12\nparaformer\"^\n\x07Request\x12\x12\n\naudio_data\x18\x01 \x01(\x0c\x12\x0c\n\x04user\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x10\n\x08speaking\x18\x04 \x01(\x08\x12\r\n\x05isEnd\x18\x05 \x01(\x08\"L\n\x08Response\x12\x10\n\x08sentence\x18\x01 \x01(\t\x12\x0c\n\x04user\x18\x02 \x01(\t\x12\x10\n\x08language\x18\x03 \x01(\t\x12\x0e\n\x06\x61\x63tion\x18\x04 \x01(\t2C\n\x03\x41SR\x12<\n\tRecognize\x12\x13.paraformer.Request\x1a\x14.paraformer.Response\"\x00(\x01\x30\x01\x42\x16\n\x07\x65x.grpc\xa2\x02\nparaformerb\x06proto3')
-
-_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, globals())
-_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'paraformer_pb2', globals())
-if _descriptor._USE_C_DESCRIPTORS == False:
-
- DESCRIPTOR._options = None
- DESCRIPTOR._serialized_options = b'\n\007ex.grpc\242\002\nparaformer'
- _REQUEST._serialized_start=32
- _REQUEST._serialized_end=126
- _RESPONSE._serialized_start=128
- _RESPONSE._serialized_end=204
- _ASR._serialized_start=206
- _ASR._serialized_end=273
-# @@protoc_insertion_point(module_scope)
diff --git a/funasr/runtime/python/grpc/paraformer_pb2_grpc.py b/funasr/runtime/python/grpc/paraformer_pb2_grpc.py
deleted file mode 100644
index 035563e..0000000
--- a/funasr/runtime/python/grpc/paraformer_pb2_grpc.py
+++ /dev/null
@@ -1,66 +0,0 @@
-# Generated by the gRPC Python protocol compiler plugin. DO NOT EDIT!
-"""Client and server classes corresponding to protobuf-defined services."""
-import grpc
-
-import paraformer_pb2 as paraformer__pb2
-
-
-class ASRStub(object):
- """Missing associated documentation comment in .proto file."""
-
- def __init__(self, channel):
- """Constructor.
-
- Args:
- channel: A grpc.Channel.
- """
- self.Recognize = channel.stream_stream(
- '/paraformer.ASR/Recognize',
- request_serializer=paraformer__pb2.Request.SerializeToString,
- response_deserializer=paraformer__pb2.Response.FromString,
- )
-
-
-class ASRServicer(object):
- """Missing associated documentation comment in .proto file."""
-
- def Recognize(self, request_iterator, context):
- """Missing associated documentation comment in .proto file."""
- context.set_code(grpc.StatusCode.UNIMPLEMENTED)
- context.set_details('Method not implemented!')
- raise NotImplementedError('Method not implemented!')
-
-
-def add_ASRServicer_to_server(servicer, server):
- rpc_method_handlers = {
- 'Recognize': grpc.stream_stream_rpc_method_handler(
- servicer.Recognize,
- request_deserializer=paraformer__pb2.Request.FromString,
- response_serializer=paraformer__pb2.Response.SerializeToString,
- ),
- }
- generic_handler = grpc.method_handlers_generic_handler(
- 'paraformer.ASR', rpc_method_handlers)
- server.add_generic_rpc_handlers((generic_handler,))
-
-
- # This class is part of an EXPERIMENTAL API.
-class ASR(object):
- """Missing associated documentation comment in .proto file."""
-
- @staticmethod
- def Recognize(request_iterator,
- target,
- options=(),
- channel_credentials=None,
- call_credentials=None,
- insecure=False,
- compression=None,
- wait_for_ready=None,
- timeout=None,
- metadata=None):
- return grpc.experimental.stream_stream(request_iterator, target, '/paraformer.ASR/Recognize',
- paraformer__pb2.Request.SerializeToString,
- paraformer__pb2.Response.FromString,
- options, channel_credentials,
- insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
diff --git a/funasr/runtime/python/grpc/proto/paraformer.proto b/funasr/runtime/python/grpc/proto/paraformer.proto
index 6c336a8..85e8534 100644
--- a/funasr/runtime/python/grpc/proto/paraformer.proto
+++ b/funasr/runtime/python/grpc/proto/paraformer.proto
@@ -1,3 +1,8 @@
+// Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+// Reserved. MIT License (https://opensource.org/licenses/MIT)
+//
+// 2023 by burkliu(刘柏基) liubaiji@xverse.cn
+
syntax = "proto3";
option objc_class_prefix = "paraformer";
@@ -8,17 +13,27 @@
rpc Recognize (stream Request) returns (stream Response) {}
}
+enum WavFormat {
+ pcm = 0;
+}
+
+enum DecodeMode {
+ offline = 0;
+ online = 1;
+ two_pass = 2;
+}
+
message Request {
- bytes audio_data = 1;
- string user = 2;
- string language = 3;
- bool speaking = 4;
- bool isEnd = 5;
+ DecodeMode mode = 1;
+ WavFormat wav_format = 2;
+ int32 sampling_rate = 3;
+ repeated int32 chunk_size = 4;
+ bool is_final = 5;
+ bytes audio_data = 6;
}
message Response {
- string sentence = 1;
- string user = 2;
- string language = 3;
- string action = 4;
+ DecodeMode mode = 1;
+ string text = 2;
+ bool is_final = 3;
}
diff --git a/funasr/runtime/python/grpc/requirements_server.txt b/funasr/runtime/python/grpc/requirements.txt
similarity index 94%
rename from funasr/runtime/python/grpc/requirements_server.txt
rename to funasr/runtime/python/grpc/requirements.txt
index a6646e7..ee677c6 100644
--- a/funasr/runtime/python/grpc/requirements_server.txt
+++ b/funasr/runtime/python/grpc/requirements.txt
@@ -1,2 +1,2 @@
grpcio
-grpcio-tools
+grpcio-tools
\ No newline at end of file
diff --git a/funasr/runtime/python/grpc/requirements_client.txt b/funasr/runtime/python/grpc/requirements_client.txt
deleted file mode 100644
index 4daa02c..0000000
--- a/funasr/runtime/python/grpc/requirements_client.txt
+++ /dev/null
@@ -1,4 +0,0 @@
-pyaudio
-webrtcvad
-grpcio
-grpcio-tools
\ No newline at end of file
diff --git a/funasr/runtime/python/onnxruntime/demo_paraformer_online.py b/funasr/runtime/python/onnxruntime/demo_paraformer_online.py
new file mode 100644
index 0000000..b5c9371
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/demo_paraformer_online.py
@@ -0,0 +1,30 @@
+import soundfile
+from funasr_onnx.paraformer_online_bin import Paraformer
+from pathlib import Path
+
+model_dir = "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online"
+wav_path = '{}/.cache/modelscope/hub/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online/example/asr_example.wav'.format(Path.home())
+
+chunk_size = [5, 10, 5]
+model = Paraformer(model_dir, batch_size=1, quantize=True, chunk_size=chunk_size, intra_op_num_threads=4) # only support batch_size = 1
+
+##online asr
+speech, sample_rate = soundfile.read(wav_path)
+speech_length = speech.shape[0]
+sample_offset = 0
+step = chunk_size[1] * 960
+param_dict = {'cache': dict()}
+final_result = ""
+for sample_offset in range(0, speech_length, min(step, speech_length - sample_offset)):
+ if sample_offset + step >= speech_length - 1:
+ step = speech_length - sample_offset
+ is_final = True
+ else:
+ is_final = False
+ param_dict['is_final'] = is_final
+ rec_result = model(audio_in=speech[sample_offset: sample_offset + step],
+ param_dict=param_dict)
+ if len(rec_result) > 0:
+ final_result += rec_result[0]["preds"][0]
+ print(rec_result)
+print(final_result)
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
new file mode 100644
index 0000000..1e0611e
--- /dev/null
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/paraformer_online_bin.py
@@ -0,0 +1,309 @@
+# -*- encoding: utf-8 -*-
+
+import os.path
+from pathlib import Path
+from typing import List, Union, Tuple
+
+import copy
+import librosa
+import numpy as np
+
+from .utils.utils import (CharTokenizer, Hypothesis, ONNXRuntimeError,
+ OrtInferSession, TokenIDConverter, get_logger,
+ read_yaml)
+from .utils.postprocess_utils import sentence_postprocess
+from .utils.frontend import WavFrontendOnline, SinusoidalPositionEncoderOnline
+
+logging = get_logger()
+
+
class Paraformer():
    """
    Streaming (online) Paraformer ASR inference on ONNX Runtime.

    Runs the exported encoder ('model.onnx') and decoder ('decoder.onnx')
    chunk by chunk, carrying fbank / CIF / FSMN states across calls in a
    per-stream ``cache`` dict supplied by the caller via ``param_dict``.
    Only batch_size == 1 is supported by the caching logic.
    """

    def __init__(self, model_dir: Union[str, Path] = None,
                 batch_size: int = 1,
                 chunk_size: List = [5, 10, 5],
                 device_id: Union[str, int] = "-1",
                 quantize: bool = False,
                 intra_op_num_threads: int = 4,
                 cache_dir: str = None
                 ):
        """
        :param model_dir: ModelScope model name, or a local directory that
            contains model(.onnx)/decoder(.onnx), config.yaml and am.mvn.
        :param batch_size: kept for interface compatibility; only 1 works.
        :param chunk_size: [left-context, chunk, right-context] sizes in
            encoder frames, e.g. [5, 10, 5].
        :param device_id: onnxruntime device id, "-1" for CPU.
        :param quantize: load the *_quant.onnx model variants when True.
        :param intra_op_num_threads: onnxruntime intra-op thread count.
        :param cache_dir: ModelScope download cache directory.
        :raises ValueError: if model_dir is neither a local path nor a
            downloadable ModelScope model name.
        """
        if not Path(model_dir).exists():
            from modelscope.hub.snapshot_download import snapshot_download
            try:
                model_dir = snapshot_download(model_dir, cache_dir=cache_dir)
            except Exception as exc:
                # BUGFIX: the original code raised a plain str, which itself
                # raises TypeError at runtime; raise a proper exception and
                # keep the original message.
                raise ValueError("model_dir must be model_name in modelscope or local path downloaded from modelscope, but is {}".format(model_dir)) from exc

        encoder_model_file = os.path.join(model_dir, 'model.onnx')
        decoder_model_file = os.path.join(model_dir, 'decoder.onnx')
        if quantize:
            encoder_model_file = os.path.join(model_dir, 'model_quant.onnx')
            decoder_model_file = os.path.join(model_dir, 'decoder_quant.onnx')
        if not os.path.exists(encoder_model_file) or not os.path.exists(decoder_model_file):
            # Export the onnx files from the torch checkpoint on first use.
            print(".onnx is not exist, begin to export onnx")
            from funasr.export.export_model import ModelExport
            export_model = ModelExport(
                cache_dir=cache_dir,
                onnx=True,
                device="cpu",
                quant=quantize,
            )
            export_model.export(model_dir)

        config_file = os.path.join(model_dir, 'config.yaml')
        cmvn_file = os.path.join(model_dir, 'am.mvn')
        config = read_yaml(config_file)

        self.converter = TokenIDConverter(config['token_list'])
        self.tokenizer = CharTokenizer()
        self.frontend = WavFrontendOnline(
            cmvn_file=cmvn_file,
            **config['frontend_conf']
        )
        self.pe = SinusoidalPositionEncoderOnline()
        self.ort_encoder_infer = OrtInferSession(encoder_model_file, device_id,
                                                 intra_op_num_threads=intra_op_num_threads)
        self.ort_decoder_infer = OrtInferSession(decoder_model_file, device_id,
                                                 intra_op_num_threads=intra_op_num_threads)
        self.batch_size = batch_size
        self.chunk_size = chunk_size
        self.encoder_output_size = config["encoder_conf"]["output_size"]
        # Geometry of the FSMN decoder caches carried between chunks.
        self.fsmn_layer = config["decoder_conf"]["num_blocks"]
        self.fsmn_lorder = config["decoder_conf"]["kernel_size"] - 1
        self.fsmn_dims = config["encoder_conf"]["output_size"]
        self.feats_dims = config["frontend_conf"]["n_mels"] * config["frontend_conf"]["lfr_m"]
        # CIF predictor thresholds (fire threshold and tail padding weight).
        self.cif_threshold = config["predictor_conf"]["threshold"]
        self.tail_threshold = config["predictor_conf"]["tail_threshold"]

    def prepare_cache(self, cache: dict = None, batch_size=1):
        """Initialize (in place) an empty per-stream cache and return it.

        A non-empty cache is returned untouched. BUGFIX: uses a None
        sentinel instead of a mutable ``{}`` default so repeated no-arg
        calls cannot share (and leak) state between streams.
        """
        if cache is None:
            cache = {}
        if len(cache) > 0:
            return cache
        cache["start_idx"] = 0
        cache["cif_hidden"] = np.zeros((batch_size, 1, self.encoder_output_size)).astype(np.float32)
        cache["cif_alphas"] = np.zeros((batch_size, 1)).astype(np.float32)
        cache["chunk_size"] = self.chunk_size
        cache["last_chunk"] = False
        cache["feats"] = np.zeros((batch_size, self.chunk_size[0] + self.chunk_size[2], self.feats_dims)).astype(np.float32)
        cache["decoder_fsmn"] = []
        for i in range(self.fsmn_layer):
            fsmn_cache = np.zeros((batch_size, self.fsmn_dims, self.fsmn_lorder)).astype(np.float32)
            cache["decoder_fsmn"].append(fsmn_cache)
        return cache

    def add_overlap_chunk(self, feats: np.ndarray, cache: dict = None):
        """Prepend the cached context frames to ``feats`` and refresh the
        cached context for the next chunk.

        On the final (but not last) chunk the result is zero-padded up to the
        full left+chunk+right length expected by the encoder.
        """
        if cache is None:
            cache = {}
        if len(cache) == 0:
            return feats
        # process last chunk
        overlap_feats = np.concatenate((cache["feats"], feats), axis=1)
        if cache["is_final"]:
            cache["feats"] = overlap_feats[:, -self.chunk_size[0]:, :]
            if not cache["last_chunk"]:
                padding_length = sum(self.chunk_size) - overlap_feats.shape[1]
                overlap_feats = np.pad(overlap_feats, ((0, 0), (0, padding_length), (0, 0)))
        else:
            cache["feats"] = overlap_feats[:, -(self.chunk_size[0] + self.chunk_size[2]):, :]
        return overlap_feats

    def __call__(self, audio_in: np.ndarray, **kwargs):
        """Feed one audio chunk; returns a (possibly empty) list of partial
        recognition results. ``kwargs['param_dict']`` carries ``is_final``
        and the per-stream ``cache``."""
        waveforms = np.expand_dims(audio_in, axis=0)
        param_dict = kwargs.get('param_dict', dict())
        is_final = param_dict.get('is_final', False)
        cache = param_dict.get('cache', dict())
        asr_res = []

        # Tiny final chunk (< 60 frames' worth of samples — presumably less
        # than one fbank frame): just flush the cached features.
        if waveforms.shape[1] < 16 * 60 and is_final and len(cache) > 0:
            cache["last_chunk"] = True
            feats = cache["feats"]
            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)
            return asr_res

        feats, feats_len = self.extract_feat(waveforms, is_final)
        if feats.shape[1] != 0:
            # Scale features as the encoder expects (sqrt(d_model)).
            feats *= self.encoder_output_size ** 0.5
            cache = self.prepare_cache(cache)
            cache["is_final"] = is_final

            # fbank -> position encoding -> overlap chunk
            feats = self.pe.forward(feats, cache["start_idx"])
            cache["start_idx"] += feats.shape[1]
            if is_final:
                if feats.shape[1] + self.chunk_size[2] <= self.chunk_size[1]:
                    # Everything fits in one last chunk.
                    cache["last_chunk"] = True
                    feats = self.add_overlap_chunk(feats, cache)
                else:
                    # Split the tail into two chunks: first chunk
                    feats_chunk1 = self.add_overlap_chunk(feats[:, :self.chunk_size[1], :], cache)
                    feats_len = np.array([feats_chunk1.shape[1]]).astype(np.int32)
                    asr_res_chunk1 = self.infer(feats_chunk1, feats_len, cache)

                    # last chunk
                    cache["last_chunk"] = True
                    feats_chunk2 = self.add_overlap_chunk(feats[:, -(feats.shape[1] + self.chunk_size[2] - self.chunk_size[1]):, :], cache)
                    feats_len = np.array([feats_chunk2.shape[1]]).astype(np.int32)
                    asr_res_chunk2 = self.infer(feats_chunk2, feats_len, cache)

                    # Merge the two partial results into a single entry.
                    asr_res_chunk = asr_res_chunk1 + asr_res_chunk2
                    res = {}
                    for pred in asr_res_chunk:
                        for key, value in pred.items():
                            if key in res:
                                res[key][0] += value[0]
                                res[key][1].extend(value[1])
                            else:
                                res[key] = [value[0], value[1]]
                    return [res]
            else:
                feats = self.add_overlap_chunk(feats, cache)

            feats_len = np.array([feats.shape[1]]).astype(np.int32)
            asr_res = self.infer(feats, feats_len, cache)

        return asr_res

    def infer(self, feats: np.ndarray, feats_len: np.ndarray, cache):
        """Run encoder -> CIF predictor -> decoder on one prepared chunk and
        return a list of ``{'preds': ...}`` results (empty if CIF fired no
        tokens). Updates the FSMN decoder caches in ``cache``."""
        # encoder forward
        enc_input = [feats, feats_len]
        enc, enc_lens, cif_alphas = self.ort_encoder_infer(enc_input)

        # predictor forward
        acoustic_embeds, acoustic_embeds_len = self.cif_search(enc, cif_alphas, cache)

        # decoder forward
        asr_res = []
        if acoustic_embeds.shape[1] > 0:
            dec_input = [enc, enc_lens, acoustic_embeds, acoustic_embeds_len]
            dec_input.extend(cache["decoder_fsmn"])
            dec_output = self.ort_decoder_infer(dec_input)
            logits, sample_ids, cache["decoder_fsmn"] = dec_output[0], dec_output[1], dec_output[2:]
            # Keep only the last `lorder` frames of each FSMN cache.
            cache["decoder_fsmn"] = [item[:, :, -self.fsmn_lorder:] for item in cache["decoder_fsmn"]]

            preds = self.decode(logits, acoustic_embeds_len)
            for pred in preds:
                pred = sentence_postprocess(pred)
                asr_res.append({'preds': pred})

        return asr_res

    def load_data(self,
                  wav_content: Union[str, np.ndarray, List[str]], fs: int = None) -> List:
        """Normalize the input (array, path, or list of paths) to a list of
        waveform arrays, loading files with librosa at sample rate ``fs``."""
        def load_wav(path: str) -> np.ndarray:
            waveform, _ = librosa.load(path, sr=fs)
            return waveform

        if isinstance(wav_content, np.ndarray):
            return [wav_content]

        if isinstance(wav_content, str):
            return [load_wav(wav_content)]

        if isinstance(wav_content, list):
            return [load_wav(path) for path in wav_content]

        raise TypeError(
            f'The type of {wav_content} is not in [str, np.ndarray, list]')

    def extract_feat(self,
                     waveforms: np.ndarray, is_final: bool = False
                     ) -> Tuple[np.ndarray, np.ndarray]:
        """Compute streaming fbank features; ``is_final`` lets the frontend
        flush its internal buffer."""
        waveforms_lens = np.zeros(waveforms.shape[0]).astype(np.int32)
        for idx, waveform in enumerate(waveforms):
            waveforms_lens[idx] = waveform.shape[-1]

        feats, feats_len = self.frontend.extract_fbank(waveforms, waveforms_lens, is_final)
        return feats.astype(np.float32), feats_len.astype(np.int32)

    def decode(self, am_scores: np.ndarray, token_nums: np.ndarray) -> List[str]:
        """Greedy-decode a batch of acoustic score matrices. ``token_nums``
        holds the valid token count per utterance (annotation fixed: it is
        an array, not an int)."""
        return [self.decode_one(am_score, token_num)
                for am_score, token_num in zip(am_scores, token_nums)]

    def decode_one(self,
                   am_score: np.ndarray,
                   valid_token_num: int) -> List[str]:
        """Greedy-decode one utterance and map token ids to token strings."""
        yseq = am_score.argmax(axis=-1)
        score = am_score.max(axis=-1)
        score = np.sum(score, axis=-1)

        # pad with mask tokens to ensure compatibility with sos/eos tokens
        # asr_model.sos:1 asr_model.eos:2
        yseq = np.array([1] + yseq.tolist() + [2])
        hyp = Hypothesis(yseq=yseq, score=score)

        # remove sos/eos and get results
        last_pos = -1
        token_int = hyp.yseq[1:last_pos].tolist()

        # remove blank symbol id, which is assumed to be 0
        token_int = list(filter(lambda x: x not in (0, 2), token_int))

        # Change integer-ids to tokens
        token = self.converter.ids2tokens(token_int)
        token = token[:valid_token_num]
        # texts = sentence_postprocess(token)
        return token

    def cif_search(self, hidden, alphas, cache=None):
        """Continuous Integrate-and-Fire over the encoder output.

        Integrates frame weights ``alphas`` until they cross
        ``cif_threshold``, emitting one acoustic embedding per fire, and
        stores the leftover integration state in ``cache``.
        NOTE(review): the cache bookkeeping (expand_dims over the batch
        stack) is only shape-consistent for batch_size == 1 — confirm before
        enabling batching.
        """
        batch_size, len_time, hidden_size = hidden.shape
        token_length = []
        list_fires = []
        list_frames = []
        cache_alphas = []
        cache_hiddens = []
        # Zero out the left/right context frames so only the current chunk
        # contributes fires.
        alphas[:, :self.chunk_size[0]] = 0.0
        alphas[:, sum(self.chunk_size[:2]):] = 0.0
        if cache is not None and "cif_alphas" in cache and "cif_hidden" in cache:
            # Resume integration from the previous chunk's leftover state.
            hidden = np.concatenate((cache["cif_hidden"], hidden), axis=1)
            alphas = np.concatenate((cache["cif_alphas"], alphas), axis=1)
        if cache is not None and "last_chunk" in cache and cache["last_chunk"]:
            # Append a tail weight so a trailing partial token still fires.
            tail_hidden = np.zeros((batch_size, 1, hidden_size)).astype(np.float32)
            tail_alphas = np.array([[self.tail_threshold]]).astype(np.float32)
            tail_alphas = np.tile(tail_alphas, (batch_size, 1))
            hidden = np.concatenate((hidden, tail_hidden), axis=1)
            alphas = np.concatenate((alphas, tail_alphas), axis=1)

        len_time = alphas.shape[1]
        for b in range(batch_size):
            integrate = 0.0
            frames = np.zeros(hidden_size).astype(np.float32)
            list_frame = []
            list_fire = []
            for t in range(len_time):
                alpha = alphas[b][t]
                if alpha + integrate < self.cif_threshold:
                    integrate += alpha
                    list_fire.append(integrate)
                    frames += alpha * hidden[b][t]
                else:
                    # Fire: emit the accumulated frame and carry the excess
                    # weight into the next token.
                    frames += (self.cif_threshold - integrate) * hidden[b][t]
                    list_frame.append(frames)
                    integrate += alpha
                    list_fire.append(integrate)
                    integrate -= self.cif_threshold
                    frames = integrate * hidden[b][t]

            cache_alphas.append(integrate)
            if integrate > 0.0:
                cache_hiddens.append(frames / integrate)
            else:
                cache_hiddens.append(frames)

            token_length.append(len(list_frame))
            list_fires.append(list_fire)
            list_frames.append(list_frame)

        # Pad every utterance to the longest token count in the batch.
        max_token_len = max(token_length)
        list_ls = []
        for b in range(batch_size):
            pad_frames = np.zeros((max_token_len - token_length[b], hidden_size)).astype(np.float32)
            if token_length[b] == 0:
                list_ls.append(pad_frames)
            else:
                list_ls.append(np.concatenate((list_frames[b], pad_frames), axis=0))

        cache["cif_alphas"] = np.stack(cache_alphas, axis=0)
        cache["cif_alphas"] = np.expand_dims(cache["cif_alphas"], axis=0)
        cache["cif_hidden"] = np.stack(cache_hiddens, axis=0)
        cache["cif_hidden"] = np.expand_dims(cache["cif_hidden"], axis=0)

        return np.stack(list_ls, axis=0).astype(np.float32), np.stack(token_length, axis=0).astype(np.int32)
+
diff --git a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
index ded04b6..295e7b5 100644
--- a/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
+++ b/funasr/runtime/python/onnxruntime/funasr_onnx/utils/frontend.py
@@ -349,6 +349,28 @@
return array
class SinusoidalPositionEncoderOnline():
    """Sinusoidal positional encoding for streaming (chunked) input.

    ``forward`` offsets the position table by ``start_idx`` so successive
    chunks of one stream receive a continuous encoding.
    """

    def encode(self, positions: np.ndarray = None, depth: int = None, dtype: np.dtype = np.float32):
        # First half of the feature dim carries sin terms, second half cos.
        n_batch = positions.shape[0]
        pos = positions.astype(dtype)
        log_increment = np.log(np.array([10000], dtype=dtype)) / (depth / 2 - 1)
        timescales = np.exp(-log_increment * np.arange(depth / 2).astype(dtype))
        timescales = np.reshape(timescales, [n_batch, -1])
        phase = np.reshape(pos, [1, -1, 1]) * np.reshape(timescales, [1, 1, -1])
        table = np.concatenate((np.sin(phase), np.cos(phase)), axis=2)
        return table.astype(dtype)

    def forward(self, x, start_idx=0):
        _, timesteps, input_dim = x.shape
        # Positions are 1-based and cover [1, start_idx + timesteps].
        pos_ids = np.arange(1, timesteps + 1 + start_idx)[None, :]
        pe = self.encode(pos_ids, input_dim, x.dtype)
        # Slice out this chunk's window and add it to the input.
        return x + pe[:, start_idx: start_idx + timesteps]
+
+
def test():
path = "/nfs/zhifu.gzf/export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-pytorch/example/asr_example.wav"
import librosa
diff --git a/funasr/runtime/websocket/CMakeLists.txt b/funasr/runtime/websocket/CMakeLists.txt
index 06ae59b..33c7695 100644
--- a/funasr/runtime/websocket/CMakeLists.txt
+++ b/funasr/runtime/websocket/CMakeLists.txt
@@ -58,7 +58,11 @@
find_package(OpenSSL REQUIRED)
add_executable(funasr-wss-server "funasr-wss-server.cpp" "websocket-server.cpp")
+add_executable(funasr-wss-server-2pass "funasr-wss-server-2pass.cpp" "websocket-server-2pass.cpp")
add_executable(funasr-wss-client "funasr-wss-client.cpp")
+add_executable(funasr-wss-client-2pass "funasr-wss-client-2pass.cpp")
target_link_libraries(funasr-wss-client PUBLIC funasr ssl crypto)
+target_link_libraries(funasr-wss-client-2pass PUBLIC funasr ssl crypto)
target_link_libraries(funasr-wss-server PUBLIC funasr ssl crypto)
+target_link_libraries(funasr-wss-server-2pass PUBLIC funasr ssl crypto)
diff --git a/funasr/runtime/websocket/funasr-wss-client-2pass.cpp b/funasr/runtime/websocket/funasr-wss-client-2pass.cpp
new file mode 100644
index 0000000..91500c3
--- /dev/null
+++ b/funasr/runtime/websocket/funasr-wss-client-2pass.cpp
@@ -0,0 +1,430 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
+// client for websocket, support multiple threads
+// ./funasr-wss-client-2pass --server-ip <string>
+//                     --port <string>
+//                     --wav-path <string>
+//                     [--thread-num <int>]
+//                     [--is-ssl <int>] [--]
+//                     [--version] [-h]
+// example:
+// ./funasr-wss-client-2pass --server-ip 127.0.0.1 --port 10095 --wav-path test.wav --thread-num 1 --is-ssl 1
+
+#define ASIO_STANDALONE 1
+#include <websocketpp/client.hpp>
+#include <websocketpp/common/thread.hpp>
+#include <websocketpp/config/asio_client.hpp>
+#include <iostream>
+#include <fstream>
+#include <sstream>
+#include <atomic>
+#include <thread>
+#include <glog/logging.h>
+
+#include "audio.h"
+#include "nlohmann/json.hpp"
+#include "tclap/CmdLine.h"
+
+/**
+ * Define a semi-cross platform helper method that waits/sleeps for a bit.
+ */
// Semi-cross-platform helper that sleeps for roughly one second. Used to
// poll for connection-state changes and to let pending sends drain before
// the client shuts down.
void WaitABit() {
  #ifdef WIN32
  Sleep(1000);
  #else
  sleep(1);
  #endif
}
+std::atomic<int> wav_index(0);
+
// Returns true when `filename` has exactly the extension `target`
// (the text after the last '.'); false when there is no dot at all.
bool IsTargetFile(const std::string& filename, const std::string target) {
  const std::size_t dot_pos = filename.find_last_of(".");
  if (dot_pos == std::string::npos) {
    return false;
  }
  return filename.compare(dot_pos + 1, std::string::npos, target) == 0;
}
+
+typedef websocketpp::config::asio_client::message_type::ptr message_ptr;
+typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context> context_ptr;
+using websocketpp::lib::bind;
+using websocketpp::lib::placeholders::_1;
+using websocketpp::lib::placeholders::_2;
// TLS init handler for wss:// connections: builds the asio SSL context used
// by the websocketpp endpoint.
// NOTE(review): no peer-certificate verification is configured here, so the
// client accepts any server certificate — confirm this is intended.
context_ptr OnTlsInit(websocketpp::connection_hdl) {
  context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
      asio::ssl::context::sslv23);

  try {
    // Disable legacy SSLv2/SSLv3 and apply the usual compatibility options.
    ctx->set_options(
        asio::ssl::context::default_workarounds | asio::ssl::context::no_sslv2 |
        asio::ssl::context::no_sslv3 | asio::ssl::context::single_dh_use);

  } catch (std::exception& e) {
    LOG(ERROR) << e.what();
  }
  return ctx;
}
+
+// template for tls or not config
+template <typename T>
+class WebsocketClient {
+ public:
+ // typedef websocketpp::client<T> client;
+ // typedef websocketpp::client<websocketpp::config::asio_tls_client>
+ // wss_client;
+ typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
+
+ WebsocketClient(int is_ssl) : m_open(false), m_done(false) {
+ // set up access channels to only log interesting things
+ m_client.clear_access_channels(websocketpp::log::alevel::all);
+ m_client.set_access_channels(websocketpp::log::alevel::connect);
+ m_client.set_access_channels(websocketpp::log::alevel::disconnect);
+ m_client.set_access_channels(websocketpp::log::alevel::app);
+
+ // Initialize the Asio transport policy
+ m_client.init_asio();
+
+ // Bind the handlers we are using
+ using websocketpp::lib::bind;
+ using websocketpp::lib::placeholders::_1;
+ m_client.set_open_handler(bind(&WebsocketClient::on_open, this, _1));
+ m_client.set_close_handler(bind(&WebsocketClient::on_close, this, _1));
+
+ m_client.set_message_handler(
+ [this](websocketpp::connection_hdl hdl, message_ptr msg) {
+ on_message(hdl, msg);
+ });
+
+ m_client.set_fail_handler(bind(&WebsocketClient::on_fail, this, _1));
+ m_client.clear_access_channels(websocketpp::log::alevel::all);
+ }
+
+ void on_message(websocketpp::connection_hdl hdl, message_ptr msg) {
+ const std::string& payload = msg->get_payload();
+ switch (msg->get_opcode()) {
+ case websocketpp::frame::opcode::text:
+ nlohmann::json jsonresult = nlohmann::json::parse(payload);
+ LOG(INFO)<< "Thread: " << this_thread::get_id() <<",on_message = " << payload;
+
+ // if (jsonresult["is_final"] == true){
+ // websocketpp::lib::error_code ec;
+ // m_client.close(m_hdl, websocketpp::close::status::going_away, "", ec);
+ // if (ec){
+ // LOG(ERROR)<< "Error closing connection " << ec.message();
+ // }
+ // }
+ }
+ }
+
+ // This method will block until the connection is complete
+ void run(const std::string& uri, const std::vector<string>& wav_list, const std::vector<string>& wav_ids, std::string asr_mode, std::vector<int> chunk_size) {
+ // Create a new connection to the given URI
+ websocketpp::lib::error_code ec;
+ typename websocketpp::client<T>::connection_ptr con =
+ m_client.get_connection(uri, ec);
+ if (ec) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Get Connection Error: " + ec.message());
+ return;
+ }
+ // Grab a handle for this connection so we can talk to it in a thread
+ // safe manor after the event loop starts.
+ m_hdl = con->get_handle();
+
+ // Queue the connection. No DNS queries or network connections will be
+ // made until the io_service event loop is run.
+ m_client.connect(con);
+
+ // Create a thread to run the ASIO io_service event loop
+ websocketpp::lib::thread asio_thread(&websocketpp::client<T>::run,
+ &m_client);
+ while(true){
+ int i = wav_index.fetch_add(1);
+ if (i >= wav_list.size()) {
+ break;
+ }
+ send_wav_data(wav_list[i], wav_ids[i], asr_mode, chunk_size);
+ }
+ WaitABit();
+
+ asio_thread.join();
+
+ }
+
+ // The open handler will signal that we are ready to start sending data
+ void on_open(websocketpp::connection_hdl) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Connection opened, starting data!");
+
+ scoped_lock guard(m_lock);
+ m_open = true;
+ }
+
+ // The close handler will signal that we should stop sending data
+ void on_close(websocketpp::connection_hdl) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Connection closed, stopping data!");
+
+ scoped_lock guard(m_lock);
+ m_done = true;
+ }
+
+ // The fail handler will signal that we should stop sending data
+ void on_fail(websocketpp::connection_hdl) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Connection failed, stopping data!");
+
+ scoped_lock guard(m_lock);
+ m_done = true;
+ }
+ // send wav to server
+ void send_wav_data(string wav_path, string wav_id, std::string asr_mode, std::vector<int> chunk_vector) {
+ uint64_t count = 0;
+ std::stringstream val;
+
+ funasr::Audio audio(1);
+ int32_t sampling_rate = 16000;
+ std::string wav_format = "pcm";
+ if(IsTargetFile(wav_path.c_str(), "wav")){
+ int32_t sampling_rate = -1;
+ if(!audio.LoadWav(wav_path.c_str(), &sampling_rate))
+ return ;
+ }else if(IsTargetFile(wav_path.c_str(), "pcm")){
+ if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate))
+ return ;
+ }else{
+ wav_format = "others";
+ if (!audio.LoadOthers2Char(wav_path.c_str()))
+ return ;
+ }
+
+ float* buff;
+ int len;
+ int flag = 0;
+ bool wait = false;
+ while (1) {
+ {
+ scoped_lock guard(m_lock);
+ // If the connection has been closed, stop generating data
+ if (m_done) {
+ break;
+ }
+ // If the connection hasn't been opened yet wait a bit and retry
+ if (!m_open) {
+ wait = true;
+ } else {
+ break;
+ }
+ }
+ if (wait) {
+ // LOG(INFO) << "wait.." << m_open;
+ WaitABit();
+ continue;
+ }
+ }
+ websocketpp::lib::error_code ec;
+
+ nlohmann::json jsonbegin;
+ nlohmann::json chunk_size = nlohmann::json::array();
+ chunk_size.push_back(chunk_vector[0]);
+ chunk_size.push_back(chunk_vector[1]);
+ chunk_size.push_back(chunk_vector[2]);
+ jsonbegin["mode"] = asr_mode;
+ jsonbegin["chunk_size"] = chunk_size;
+ jsonbegin["wav_name"] = wav_id;
+ jsonbegin["wav_format"] = wav_format;
+ jsonbegin["is_speaking"] = true;
+ m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
+ ec);
+
+ // fetch wav data use asr engine api
+ if(wav_format == "pcm"){
+ while (audio.Fetch(buff, len, flag) > 0) {
+ short* iArray = new short[len];
+ for (size_t i = 0; i < len; ++i) {
+ iArray[i] = (short)(buff[i]*32768);
+ }
+
+ // send data to server
+ int offset = 0;
+ int block_size = 102400;
+ while(offset < len){
+ int send_block = 0;
+ if (offset + block_size <= len){
+ send_block = block_size;
+ }else{
+ send_block = len - offset;
+ }
+ m_client.send(m_hdl, iArray+offset, send_block * sizeof(short),
+ websocketpp::frame::opcode::binary, ec);
+ offset += send_block;
+ }
+
+ LOG(INFO) << "sended data len=" << len * sizeof(short);
+ // The most likely error that we will get is that the connection is
+ // not in the right state. Usually this means we tried to send a
+ // message to a connection that was closed or in the process of
+ // closing. While many errors here can be easily recovered from,
+ // in this simple example, we'll stop the data loop.
+ if (ec) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Send Error: " + ec.message());
+ break;
+ }
+ delete[] iArray;
+ // WaitABit();
+ }
+ }else{
+ int offset = 0;
+ int block_size = 204800;
+ len = audio.GetSpeechLen();
+ char* others_buff = audio.GetSpeechChar();
+
+ while(offset < len){
+ int send_block = 0;
+ if (offset + block_size <= len){
+ send_block = block_size;
+ }else{
+ send_block = len - offset;
+ }
+ m_client.send(m_hdl, others_buff+offset, send_block,
+ websocketpp::frame::opcode::binary, ec);
+ offset += send_block;
+ }
+
+ LOG(INFO) << "sended data len=" << len;
+ // The most likely error that we will get is that the connection is
+ // not in the right state. Usually this means we tried to send a
+ // message to a connection that was closed or in the process of
+ // closing. While many errors here can be easily recovered from,
+ // in this simple example, we'll stop the data loop.
+ if (ec) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Send Error: " + ec.message());
+ }
+ }
+
+ nlohmann::json jsonresult;
+ jsonresult["is_speaking"] = false;
+ m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
+ ec);
+ // WaitABit();
+ }
+ websocketpp::client<T> m_client;
+
+ private:
+ websocketpp::connection_hdl m_hdl;
+ websocketpp::lib::mutex m_lock;
+ bool m_open;
+ bool m_done;
+ int total_num=0;
+};
+
// Entry point: parses CLI options, reads the wav list (single file or kaldi
// style .scp), then spawns `thread-num` websocket clients that share the
// global wav_index work queue.
int main(int argc, char* argv[]) {

  google::InitGoogleLogging(argv[0]);
  FLAGS_logtostderr = true;

  // NOTE(review): ASR_MODE is a macro presumably expanding to the flag name
  // ("asr-mode") from an included header — confirm where it is defined.
  TCLAP::CmdLine cmd("funasr-wss-client", ' ', "1.0");
  TCLAP::ValueArg<std::string> server_ip_("", "server-ip", "server-ip", true,
                                          "127.0.0.1", "string");
  TCLAP::ValueArg<std::string> port_("", "port", "port", true, "10095", "string");
  TCLAP::ValueArg<std::string> wav_path_("", "wav-path",
      "the input could be: wav_path, e.g.: asr_example.wav; pcm_path, e.g.: asr_example.pcm; wav.scp, kaldi style wav list (wav_id \t wav_path)",
      true, "", "string");
  TCLAP::ValueArg<std::string> asr_mode_("", ASR_MODE, "offline, online, 2pass", false, "2pass", "string");
  TCLAP::ValueArg<std::string> chunk_size_("", "chunk-size", "chunk_size: 5-10-5 or 5-12-5", false, "5-10-5", "string");
  TCLAP::ValueArg<int> thread_num_("", "thread-num", "thread-num",
                                   false, 1, "int");
  TCLAP::ValueArg<int> is_ssl_(
      "", "is-ssl", "is-ssl is 1 means use wss connection, or use ws connection",
      false, 1, "int");

  cmd.add(server_ip_);
  cmd.add(port_);
  cmd.add(wav_path_);
  cmd.add(asr_mode_);
  cmd.add(chunk_size_);
  cmd.add(thread_num_);
  cmd.add(is_ssl_);
  cmd.parse(argc, argv);

  std::string server_ip = server_ip_.getValue();
  std::string port = port_.getValue();
  std::string wav_path = wav_path_.getValue();
  std::string asr_mode = asr_mode_.getValue();
  std::string chunk_size_str = chunk_size_.getValue();
  // get chunk_size: parse the "L-C-R" dash-separated string into three ints.
  std::vector<int> chunk_size;
  std::stringstream ss(chunk_size_str);
  std::string item;
  while (std::getline(ss, item, '-')) {
    try {
      chunk_size.push_back(stoi(item));
    } catch (const invalid_argument&) {
      LOG(ERROR) << "Invalid argument: " << item;
      exit(-1);
    }
  }

  int threads_num = thread_num_.getValue();
  int is_ssl = is_ssl_.getValue();

  std::vector<websocketpp::lib::thread> client_threads;
  // Choose wss:// or ws:// based on --is-ssl.
  std::string uri = "";
  if (is_ssl == 1) {
    uri = "wss://" + server_ip + ":" + port;
  } else {
    uri = "ws://" + server_ip + ":" + port;
  }

  // read wav_path: either one audio file, or a .scp with "wav_id wav_path"
  // per line.
  std::vector<string> wav_list;
  std::vector<string> wav_ids;
  string default_id = "wav_default_id";
  if(IsTargetFile(wav_path, "scp")){
    ifstream in(wav_path);
    if (!in.is_open()) {
      printf("Failed to open scp file");
      return 0;
    }
    string line;
    while(getline(in, line))
    {
      istringstream iss(line);
      string column1, column2;
      iss >> column1 >> column2;
      wav_list.emplace_back(column2);
      wav_ids.emplace_back(column1);
    }
    in.close();
  }else{
    wav_list.emplace_back(wav_path);
    wav_ids.emplace_back(default_id);
  }

  // One websocket client per thread; all threads share the wav list and pop
  // work via the global atomic wav_index.
  // NOTE(review): `size_t i < int threads_num` mixes signedness — harmless
  // for sane values but worth normalizing.
  for (size_t i = 0; i < threads_num; i++) {
    client_threads.emplace_back([uri, wav_list, wav_ids, asr_mode, chunk_size, is_ssl]() {
      if (is_ssl == 1) {
        WebsocketClient<websocketpp::config::asio_tls_client> c(is_ssl);

        c.m_client.set_tls_init_handler(bind(&OnTlsInit, ::_1));

        c.run(uri, wav_list, wav_ids, asr_mode, chunk_size);
      } else {
        WebsocketClient<websocketpp::config::asio_client> c(is_ssl);

        c.run(uri, wav_list, wav_ids, asr_mode, chunk_size);
      }
    });
  }

  // Wait for every client thread to drain the work queue.
  for (auto& t : client_threads) {
    t.join();
  }
}
diff --git a/funasr/runtime/websocket/funasr-wss-server-2pass.cpp b/funasr/runtime/websocket/funasr-wss-server-2pass.cpp
new file mode 100644
index 0000000..99497dc
--- /dev/null
+++ b/funasr/runtime/websocket/funasr-wss-server-2pass.cpp
@@ -0,0 +1,419 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
+// io server
+// Usage:funasr-wss-server-2pass [--model_thread_num <int>] [--decoder_thread_num
+// <int>]
+// [--io_thread_num <int>] [--port <int>] [--listen_ip
+// <string>] [--punc-quant <string>] [--punc-dir <string>]
+// [--vad-quant <string>] [--vad-dir <string>] [--quantize
+// <string>] --model-dir <string> [--] [--version] [-h]
+#include <unistd.h>
+#include "websocket-server-2pass.h"
+
+using namespace std;
+// Copies one parsed command-line option into `model_path` under `key`
+// and logs it, so later init code can look options up by name.
+void GetValue(TCLAP::ValueArg<std::string>& value_arg, string key,
+              std::map<std::string, std::string>& model_path) {
+  model_path.insert({key, value_arg.getValue()});
+  LOG(INFO) << key << " : " << value_arg.getValue();
+}
+// Entry point of the 2pass websocket ASR server.
+//
+// Flow: parse CLI options (model dirs/revisions, quantization flags, network
+// and thread-pool sizes) -> optionally fetch each model from ModelScope via a
+// Python helper -> init the ASR engine -> serve WS or WSS connections using
+// two asio contexts (one for network IO, one for decoding).
+int main(int argc, char* argv[]) {
+  try {
+    google::InitGoogleLogging(argv[0]);
+    FLAGS_logtostderr = true;
+
+    // NOTE(review): program name says "funasr-wss-server" although this
+    // binary is the 2pass variant — cosmetic, shows up in --help/--version.
+    TCLAP::CmdLine cmd("funasr-wss-server", ' ', "1.0");
+    TCLAP::ValueArg<std::string> download_model_dir(
+        "", "download-model-dir",
+        "Download model from Modelscope to download_model_dir", false,
+        "/workspace/models", "string");
+    TCLAP::ValueArg<std::string> offline_model_dir(
+        "", OFFLINE_MODEL_DIR,
+        "default: /workspace/models/offline_asr, the asr model path, which "
+        "contains model_quant.onnx, config.yaml, am.mvn",
+        false, "/workspace/models/offline_asr", "string");
+    TCLAP::ValueArg<std::string> online_model_dir(
+        "", ONLINE_MODEL_DIR,
+        "default: damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx, the asr model path, which "
+        "contains model_quant.onnx, config.yaml, am.mvn",
+        false, "damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online-onnx", "string");
+
+    TCLAP::ValueArg<std::string> offline_model_revision(
+        "", "offline-model-revision", "ASR offline model revision", false,
+        "v1.2.1", "string");
+
+    TCLAP::ValueArg<std::string> online_model_revision(
+        "", "online-model-revision", "ASR online model revision", false,
+        "v1.0.6", "string");
+
+    TCLAP::ValueArg<std::string> quantize(
+        "", QUANTIZE,
+        "true (Default), load the model of model_quant.onnx in model_dir. If "
+        "set "
+        "false, load the model of model.onnx in model_dir",
+        false, "true", "string");
+    TCLAP::ValueArg<std::string> vad_dir(
+        "", VAD_DIR,
+        "default: /workspace/models/vad, the vad model path, which contains "
+        "model_quant.onnx, vad.yaml, vad.mvn",
+        false, "/workspace/models/vad", "string");
+    TCLAP::ValueArg<std::string> vad_revision(
+        "", "vad-revision", "VAD model revision", false, "v1.2.0", "string");
+    TCLAP::ValueArg<std::string> vad_quant(
+        "", VAD_QUANT,
+        "true (Default), load the model of model_quant.onnx in vad_dir. If set "
+        "false, load the model of model.onnx in vad_dir",
+        false, "true", "string");
+    TCLAP::ValueArg<std::string> punc_dir(
+        "", PUNC_DIR,
+        "default: /workspace/models/punc, the punc model path, which contains "
+        "model_quant.onnx, punc.yaml",
+        false, "/workspace/models/punc", "string");
+    TCLAP::ValueArg<std::string> punc_revision(
+        "", "punc-revision", "PUNC model revision", false, "v1.0.2", "string");
+    TCLAP::ValueArg<std::string> punc_quant(
+        "", PUNC_QUANT,
+        "true (Default), load the model of model_quant.onnx in punc_dir. If "
+        "set "
+        "false, load the model of model.onnx in punc_dir",
+        false, "true", "string");
+
+    TCLAP::ValueArg<std::string> listen_ip("", "listen-ip", "listen ip", false,
+                                           "0.0.0.0", "string");
+    TCLAP::ValueArg<int> port("", "port", "port", false, 10095, "int");
+    TCLAP::ValueArg<int> io_thread_num("", "io-thread-num", "io thread num",
+                                       false, 8, "int");
+    TCLAP::ValueArg<int> decoder_thread_num(
+        "", "decoder-thread-num", "decoder thread num", false, 8, "int");
+    TCLAP::ValueArg<int> model_thread_num("", "model-thread-num",
+                                          "model thread num", false, 4, "int");
+
+    // If certfile is non-empty the server runs in WSS (TLS) mode.
+    TCLAP::ValueArg<std::string> certfile(
+        "", "certfile",
+        "default: ../../../ssl_key/server.crt, path of certficate for WSS "
+        "connection. if it is empty, it will be in WS mode.",
+        false, "../../../ssl_key/server.crt", "string");
+    TCLAP::ValueArg<std::string> keyfile(
+        "", "keyfile",
+        "default: ../../../ssl_key/server.key, path of keyfile for WSS "
+        "connection",
+        false, "../../../ssl_key/server.key", "string");
+
+    cmd.add(certfile);
+    cmd.add(keyfile);
+
+    cmd.add(download_model_dir);
+    cmd.add(offline_model_dir);
+    cmd.add(online_model_dir);
+    cmd.add(offline_model_revision);
+    cmd.add(online_model_revision);
+    cmd.add(quantize);
+    cmd.add(vad_dir);
+    cmd.add(vad_revision);
+    cmd.add(vad_quant);
+    cmd.add(punc_dir);
+    cmd.add(punc_revision);
+    cmd.add(punc_quant);
+
+    cmd.add(listen_ip);
+    cmd.add(port);
+    cmd.add(io_thread_num);
+    cmd.add(decoder_thread_num);
+    cmd.add(model_thread_num);
+    cmd.parse(argc, argv);
+
+    // Collect all model-related options into one string map consumed by
+    // initAsr()/FunTpassInit.
+    std::map<std::string, std::string> model_path;
+    GetValue(offline_model_dir, OFFLINE_MODEL_DIR, model_path);
+    GetValue(online_model_dir, ONLINE_MODEL_DIR, model_path);
+    GetValue(quantize, QUANTIZE, model_path);
+    GetValue(vad_dir, VAD_DIR, model_path);
+    GetValue(vad_quant, VAD_QUANT, model_path);
+    GetValue(punc_dir, PUNC_DIR, model_path);
+    GetValue(punc_quant, PUNC_QUANT, model_path);
+
+    GetValue(offline_model_revision, "offline-model-revision", model_path);
+    GetValue(online_model_revision, "online-model-revision", model_path);
+    GetValue(vad_revision, "vad-revision", model_path);
+    GetValue(punc_revision, "punc-revision", model_path);
+
+    // Download model from Modelscope.
+    // Each model (VAD / offline ASR / online ASR / punc) follows the same
+    // pattern: if the given path exists locally it is exported in place,
+    // otherwise it is treated as a ModelScope model id and downloaded into
+    // download_model_dir; then the expected .onnx file is verified.
+    // NOTE(review): the python_cmd_* strings are passed to system() without
+    // quoting — paths containing spaces or shell metacharacters will break
+    // or be interpreted by the shell; verify inputs or quote the arguments.
+    try {
+      std::string s_download_model_dir = download_model_dir.getValue();
+
+      std::string s_vad_path = model_path[VAD_DIR];
+      std::string s_vad_quant = model_path[VAD_QUANT];
+      std::string s_offline_asr_path = model_path[OFFLINE_MODEL_DIR];
+      std::string s_online_asr_path = model_path[ONLINE_MODEL_DIR];
+      std::string s_asr_quant = model_path[QUANTIZE];
+      std::string s_punc_path = model_path[PUNC_DIR];
+      std::string s_punc_quant = model_path[PUNC_QUANT];
+
+      std::string python_cmd =
+          "python -m funasr.utils.runtime_sdk_download_tool --type onnx --quantize True ";
+
+      // --- VAD model ---
+      if (vad_dir.isSet() && !s_vad_path.empty()) {
+        std::string python_cmd_vad;
+        std::string down_vad_path;
+        std::string down_vad_model;
+
+        if (access(s_vad_path.c_str(), F_OK) == 0) {
+          // local
+          python_cmd_vad = python_cmd + " --model-name " + s_vad_path +
+                           " --export-dir ./ " + " --model_revision " +
+                           model_path["vad-revision"];
+          down_vad_path = s_vad_path;
+        } else {
+          // modelscope
+          LOG(INFO) << "Download model: " << s_vad_path
+                    << " from modelscope: ";
+          python_cmd_vad = python_cmd + " --model-name " +
+                           s_vad_path +
+                           " --export-dir " + s_download_model_dir +
+                           " --model_revision " + model_path["vad-revision"];
+          down_vad_path =
+              s_download_model_dir +
+              "/" + s_vad_path;
+        }
+
+        int ret = system(python_cmd_vad.c_str());
+        if (ret != 0) {
+          LOG(INFO) << "Failed to download model from modelscope. If you set local vad model path, you can ignore the errors.";
+        }
+        // Pick quantized vs. plain onnx according to the quant flag.
+        down_vad_model = down_vad_path + "/model_quant.onnx";
+        if (s_vad_quant == "false" || s_vad_quant == "False" ||
+            s_vad_quant == "FALSE") {
+          down_vad_model = down_vad_path + "/model.onnx";
+        }
+
+        if (access(down_vad_model.c_str(), F_OK) != 0) {
+          LOG(ERROR) << down_vad_model << " do not exists.";
+          exit(-1);
+        } else {
+          model_path[VAD_DIR] = down_vad_path;
+          LOG(INFO) << "Set " << VAD_DIR << " : " << model_path[VAD_DIR];
+        }
+      }
+      else {
+        LOG(INFO) << "VAD model is not set, use default.";
+      }
+
+      // --- offline ASR model ---
+      if (offline_model_dir.isSet() && !s_offline_asr_path.empty()) {
+        std::string python_cmd_asr;
+        std::string down_asr_path;
+        std::string down_asr_model;
+
+        if (access(s_offline_asr_path.c_str(), F_OK) == 0) {
+          // local
+          python_cmd_asr = python_cmd + " --model-name " + s_offline_asr_path +
+                           " --export-dir ./ " + " --model_revision " +
+                           model_path["offline-model-revision"];
+          down_asr_path = s_offline_asr_path;
+        } else {
+          // modelscope
+          LOG(INFO) << "Download model: " << s_offline_asr_path
+                    << " from modelscope : ";
+          python_cmd_asr = python_cmd + " --model-name " +
+                           s_offline_asr_path +
+                           " --export-dir " + s_download_model_dir +
+                           " --model_revision " + model_path["offline-model-revision"];
+          down_asr_path
+            = s_download_model_dir + "/" + s_offline_asr_path;
+        }
+
+        int ret = system(python_cmd_asr.c_str());
+        if (ret != 0) {
+          LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path, you can ignore the errors.";
+        }
+        down_asr_model = down_asr_path + "/model_quant.onnx";
+        if (s_asr_quant == "false" || s_asr_quant == "False" ||
+            s_asr_quant == "FALSE") {
+          down_asr_model = down_asr_path + "/model.onnx";
+        }
+
+        if (access(down_asr_model.c_str(), F_OK) != 0) {
+          LOG(ERROR) << down_asr_model << " do not exists.";
+          exit(-1);
+        } else {
+          model_path[OFFLINE_MODEL_DIR] = down_asr_path;
+          LOG(INFO) << "Set " << OFFLINE_MODEL_DIR << " : " << model_path[OFFLINE_MODEL_DIR];
+        }
+      } else {
+        LOG(INFO) << "ASR Offline model is not set, use default.";
+      }
+
+      // --- online ASR model ---
+      // NOTE(review): unlike the other models this branch does not check
+      // online_model_dir.isSet(), so the (non-empty) default ModelScope id is
+      // always resolved/downloaded — presumably intentional for 2pass; confirm.
+      if (!s_online_asr_path.empty()) {
+        std::string python_cmd_asr;
+        std::string down_asr_path;
+        std::string down_asr_model;
+
+        if (access(s_online_asr_path.c_str(), F_OK) == 0) {
+          // local
+          python_cmd_asr = python_cmd + " --model-name " + s_online_asr_path +
+                           " --export-dir ./ " + " --model_revision " +
+                           model_path["online-model-revision"];
+          down_asr_path = s_online_asr_path;
+        } else {
+          // modelscope
+          LOG(INFO) << "Download model: " << s_online_asr_path
+                    << " from modelscope : ";
+          python_cmd_asr = python_cmd + " --model-name " +
+                           s_online_asr_path +
+                           " --export-dir " + s_download_model_dir +
+                           " --model_revision " + model_path["online-model-revision"];
+          down_asr_path
+            = s_download_model_dir + "/" + s_online_asr_path;
+        }
+
+        int ret = system(python_cmd_asr.c_str());
+        if (ret != 0) {
+          LOG(INFO) << "Failed to download model from modelscope. If you set local asr model path, you can ignore the errors.";
+        }
+        down_asr_model = down_asr_path + "/model_quant.onnx";
+        if (s_asr_quant == "false" || s_asr_quant == "False" ||
+            s_asr_quant == "FALSE") {
+          down_asr_model = down_asr_path + "/model.onnx";
+        }
+
+        if (access(down_asr_model.c_str(), F_OK) != 0) {
+          LOG(ERROR) << down_asr_model << " do not exists.";
+          exit(-1);
+        } else {
+          model_path[ONLINE_MODEL_DIR] = down_asr_path;
+          LOG(INFO) << "Set " << ONLINE_MODEL_DIR << " : " << model_path[ONLINE_MODEL_DIR];
+        }
+      } else {
+        LOG(INFO) << "ASR online model is not set, use default.";
+      }
+
+      // --- punctuation model ---
+      if (punc_dir.isSet() && !s_punc_path.empty()) {
+        std::string python_cmd_punc;
+        std::string down_punc_path;
+        std::string down_punc_model;
+
+        if (access(s_punc_path.c_str(), F_OK) == 0) {
+          // local
+          python_cmd_punc = python_cmd + " --model-name " + s_punc_path +
+                            " --export-dir ./ " + " --model_revision " +
+                            model_path["punc-revision"];
+          down_punc_path = s_punc_path;
+        } else {
+          // modelscope
+          LOG(INFO) << "Download model: " << s_punc_path
+                    << " from modelscope : "; python_cmd_punc = python_cmd + " --model-name " +
+                           s_punc_path +
+                           " --export-dir " + s_download_model_dir +
+                           " --model_revision " + model_path["punc-revision"];
+          down_punc_path =
+              s_download_model_dir +
+              "/" + s_punc_path;
+        }
+
+        int ret = system(python_cmd_punc.c_str());
+        if (ret != 0) {
+          LOG(INFO) << "Failed to download model from modelscope. If you set local punc model path, you can ignore the errors.";
+        }
+        down_punc_model = down_punc_path + "/model_quant.onnx";
+        if (s_punc_quant == "false" || s_punc_quant == "False" ||
+            s_punc_quant == "FALSE") {
+          down_punc_model = down_punc_path + "/model.onnx";
+        }
+
+        if (access(down_punc_model.c_str(), F_OK) != 0) {
+          LOG(ERROR) << down_punc_model << " do not exists.";
+          exit(-1);
+        } else {
+          model_path[PUNC_DIR] = down_punc_path;
+          LOG(INFO) << "Set " << PUNC_DIR << " : " << model_path[PUNC_DIR];
+        }
+      } else {
+        LOG(INFO) << "PUNC model is not set, use default.";
+      }
+
+    } catch (std::exception const& e) {
+      LOG(ERROR) << "Error: " << e.what();
+    }
+
+    std::string s_listen_ip = listen_ip.getValue();
+    int s_port = port.getValue();
+    int s_io_thread_num = io_thread_num.getValue();
+    int s_decoder_thread_num = decoder_thread_num.getValue();
+
+    int s_model_thread_num = model_thread_num.getValue();
+
+    asio::io_context io_decoder;  // context for decoding
+    asio::io_context io_server;   // context for server
+
+    std::vector<std::thread> decoder_threads;
+
+    std::string s_certfile = certfile.getValue();
+    std::string s_keyfile = keyfile.getValue();
+
+    // WSS mode whenever a cert path is provided (WS otherwise).
+    bool is_ssl = false;
+    if (!s_certfile.empty()) {
+      is_ssl = true;
+    }
+
+    auto conn_guard = asio::make_work_guard(
+        io_decoder);  // make sure threads can wait in the queue
+    auto server_guard = asio::make_work_guard(
+        io_server);  // make sure threads can wait in the queue
+    // create threads pool
+    for (int32_t i = 0; i < s_decoder_thread_num; ++i) {
+      decoder_threads.emplace_back([&io_decoder]() { io_decoder.run(); });
+    }
+
+    server server_;  // server for websocket
+    wss_server wss_server_;
+    // NOTE(review): websocket_srv is a block-scoped local in both branches,
+    // destroyed before io_server.run() is started below. If its constructor
+    // registers handlers that capture `this`, those handlers dangle — verify
+    // WebSocketServer's lifetime/handler registration (consider declaring it
+    // before the if/else).
+    if (is_ssl) {
+      wss_server_.init_asio(&io_server);  // init asio
+      wss_server_.set_reuse_addr(
+          true);  // reuse address as we create multiple threads
+
+      // list on port for accept
+      wss_server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+      WebSocketServer websocket_srv(
+          io_decoder, is_ssl, nullptr, &wss_server_, s_certfile,
+          s_keyfile);  // websocket server for asr engine
+      websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
+
+    } else {
+      server_.init_asio(&io_server);  // init asio
+      server_.set_reuse_addr(
+          true);  // reuse address as we create multiple threads
+
+      // list on port for accept
+      server_.listen(asio::ip::address::from_string(s_listen_ip), s_port);
+      WebSocketServer websocket_srv(
+          io_decoder, is_ssl, &server_, nullptr, s_certfile,
+          s_keyfile);  // websocket server for asr engine
+      websocket_srv.initAsr(model_path, s_model_thread_num);  // init asr model
+    }
+
+    std::cout << "asr model init finished. listen on port:" << s_port
+              << std::endl;
+
+    // Start the ASIO network io_service run loop
+    std::vector<std::thread> ts;
+    // create threads for io network
+    // NOTE(review): loop index is size_t while s_io_thread_num is int —
+    // signed/unsigned comparison warning; consider int indices.
+    for (size_t i = 0; i < s_io_thread_num; i++) {
+      ts.emplace_back([&io_server]() { io_server.run(); });
+    }
+    // wait for threads
+    for (size_t i = 0; i < s_io_thread_num; i++) {
+      ts[i].join();
+    }
+
+    // wait for threads
+    for (auto& t : decoder_threads) {
+      t.join();
+    }
+
+  } catch (std::exception const& e) {
+    std::cerr << "Error: " << e.what() << std::endl;
+  }
+
+  return 0;
+}
diff --git a/funasr/runtime/websocket/funasr-wss-server.cpp b/funasr/runtime/websocket/funasr-wss-server.cpp
index 5061bba..614d455 100644
--- a/funasr/runtime/websocket/funasr-wss-server.cpp
+++ b/funasr/runtime/websocket/funasr-wss-server.cpp
@@ -79,7 +79,7 @@
TCLAP::ValueArg<int> decoder_thread_num(
"", "decoder-thread-num", "decoder thread num", false, 8, "int");
TCLAP::ValueArg<int> model_thread_num("", "model-thread-num",
- "model thread num", false, 1, "int");
+ "model thread num", false, 4, "int");
TCLAP::ValueArg<std::string> certfile("", "certfile",
"default: ../../../ssl_key/server.crt, path of certficate for WSS connection. if it is empty, it will be in WS mode.",
diff --git a/funasr/runtime/websocket/readme.md b/funasr/runtime/websocket/readme.md
index 12d255c..48b9063 100644
--- a/funasr/runtime/websocket/readme.md
+++ b/funasr/runtime/websocket/readme.md
@@ -116,6 +116,18 @@
--punc-dir ./export/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx
```
+##### Start the 2pass Service
+```shell
+./funasr-wss-server-2pass \
+ --download-model-dir /workspace/models \
+  --offline-model-dir ./export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-onnx \
+  --vad-dir ./export/damo/speech_fsmn_vad_zh-cn-16k-common-onnx \
+  --punc-dir ./export/damo/punc_ct-transformer_zh-cn-common-vocab272727-onnx \
+  --online-model-dir ./export/damo/speech_paraformer-large_asr_nat-zh-cn-16k-common-vocab8404-online \
+ --quantize false
+```
+
+
### Client Usage
diff --git a/funasr/runtime/websocket/websocket-server-2pass.cpp b/funasr/runtime/websocket/websocket-server-2pass.cpp
new file mode 100644
index 0000000..8833b0b
--- /dev/null
+++ b/funasr/runtime/websocket/websocket-server-2pass.cpp
@@ -0,0 +1,369 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
+// websocket server for asr engine
+// take some ideas from https://github.com/k2-fsa/sherpa-onnx
+// online-websocket-server-impl.cc, thanks. The websocket server has two threads
+// pools, one for handle network data and one for asr decoder.
+// This 2pass variant supports offline, online and combined 2pass decoding.
+
+#include "websocket-server-2pass.h"
+
+#include <thread>
+#include <utility>
+#include <vector>
+
+// TLS-init handler for WSS connections: builds an asio::ssl::context with
+// SSLv2/SSLv3 disabled (plus TLSv1 in MOZILLA_MODERN mode) and loads the
+// server certificate chain and private key. Errors are logged, not rethrown,
+// so a bad cert/key yields a context that fails during the handshake instead.
+context_ptr WebSocketServer::on_tls_init(tls_mode mode,
+                                         websocketpp::connection_hdl hdl,
+                                         std::string& s_certfile,
+                                         std::string& s_keyfile) {
+  namespace asio = websocketpp::lib::asio;
+
+  LOG(INFO) << "on_tls_init called with hdl: " << hdl.lock().get();
+  LOG(INFO) << "using TLS mode: "
+            << (mode == MOZILLA_MODERN ? "Mozilla Modern"
+                                       : "Mozilla Intermediate");
+
+  context_ptr ctx = websocketpp::lib::make_shared<asio::ssl::context>(
+      asio::ssl::context::sslv23);
+
+  try {
+    if (mode == MOZILLA_MODERN) {
+      // Modern disables TLSv1
+      ctx->set_options(
+          asio::ssl::context::default_workarounds |
+          asio::ssl::context::no_sslv2 | asio::ssl::context::no_sslv3 |
+          asio::ssl::context::no_tlsv1 | asio::ssl::context::single_dh_use);
+    } else {
+      ctx->set_options(asio::ssl::context::default_workarounds |
+                       asio::ssl::context::no_sslv2 |
+                       asio::ssl::context::no_sslv3 |
+                       asio::ssl::context::single_dh_use);
+    }
+
+    ctx->use_certificate_chain_file(s_certfile);
+    ctx->use_private_key_file(s_keyfile, asio::ssl::context::pem);
+
+  } catch (std::exception& e) {
+    LOG(INFO) << "Exception: " << e.what();
+  }
+  return ctx;
+}
+
+// Converts one engine result into the JSON payload sent to the client.
+//
+// Reads the 1st-pass (online) text and the 2nd-pass (tpass/offline) text from
+// `result`, appends each to its accumulator, and fills jsonresult["text"] /
+// ["mode"] with whichever pass produced output (2nd pass wins if both did).
+// `msg` must be the client's full JSON message object: nlohmann's contains()
+// returns false on any non-object value, so passing a plain string here would
+// silently drop the wav_name echo.
+// NOTE(review): `msg` is taken by value (copies the JSON) and `ec` is unused.
+nlohmann::json handle_result(FUNASR_RESULT result, std::string& online_res,
+                             std::string& tpass_res, nlohmann::json msg) {
+
+  websocketpp::lib::error_code ec;
+  nlohmann::json jsonresult;
+  jsonresult["text"]="";
+
+  std::string tmp_online_msg = FunASRGetResult(result, 0);
+  online_res += tmp_online_msg;
+  if (tmp_online_msg != "") {
+    LOG(INFO) << "online_res :" << tmp_online_msg;
+    jsonresult["text"] = tmp_online_msg;
+    jsonresult["mode"] = "2pass-online";
+  }
+  std::string tmp_tpass_msg = FunASRGetTpassResult(result, 0);
+  tpass_res += tmp_tpass_msg;
+  if (tmp_tpass_msg != "") {
+    LOG(INFO) << "offline results : " << tmp_tpass_msg;
+    jsonresult["text"] = tmp_tpass_msg;
+    jsonresult["mode"] = "2pass-offline";
+  }
+
+  // Echo the wav_name back so the client can match results to requests.
+  if (msg.contains("wav_name")) {
+    jsonresult["wav_name"] = msg["wav_name"];
+  }
+
+  return jsonresult;
+}
+// feed buffer to asr engine for decoder
+void WebSocketServer::do_decoder(
+ std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
+ nlohmann::json& msg, std::vector<std::vector<std::string>>& punc_cache,
+ websocketpp::lib::mutex& thread_lock, bool& is_final,
+ FUNASR_HANDLE& tpass_online_handle, std::string& online_res,
+ std::string& tpass_res) {
+
+ // lock for each connection
+ scoped_lock guard(thread_lock);
+ FUNASR_RESULT Result = nullptr;
+ int asr_mode_ = 2;
+ if (msg.contains("mode")) {
+ std::string modeltype = msg["mode"];
+ if (modeltype == "offline") {
+ asr_mode_ = 0;
+ } else if (modeltype == "online") {
+ asr_mode_ = 1;
+ } else if (modeltype == "2pass") {
+ asr_mode_ = 2;
+ }
+ } else {
+ // default value
+ msg["mode"] = "2pass";
+ asr_mode_ = 2;
+ }
+
+ try {
+ // loop to send chunk_size 800*2 data to asr engine. TODO: chunk_size need get from client
+ while (buffer.size() >= 800 * 2) {
+ std::vector<char> subvector = {buffer.begin(),
+ buffer.begin() + 800 * 2};
+ buffer.erase(buffer.begin(), buffer.begin() + 800 * 2);
+
+ try{
+ Result =
+ FunTpassInferBuffer(tpass_handle, tpass_online_handle,
+ subvector.data(), subvector.size(), punc_cache,
+ false, msg["audio_fs"], msg["wav_format"], (ASR_TYPE)asr_mode_);
+ }catch (std::exception const &e)
+ {
+ LOG(ERROR)<<e.what();
+ }
+ if (Result) {
+ websocketpp::lib::error_code ec;
+ nlohmann::json jsonresult =
+ handle_result(Result, online_res, tpass_res, msg["wav_name"]);
+ jsonresult["is_final"] = false;
+ if(jsonresult["text"] != "") {
+ if (is_ssl) {
+ wss_server_->send(hdl, jsonresult.dump(),
+ websocketpp::frame::opcode::text, ec);
+ } else {
+ server_->send(hdl, jsonresult.dump(),
+ websocketpp::frame::opcode::text, ec);
+ }
+ }
+ FunASRFreeResult(Result);
+ }
+
+ }
+ if(is_final){
+
+ try{
+ Result = FunTpassInferBuffer(tpass_handle, tpass_online_handle,
+ buffer.data(), buffer.size(), punc_cache,
+ is_final, msg["audio_fs"], msg["wav_format"], (ASR_TYPE)asr_mode_);
+ }catch (std::exception const &e)
+ {
+ LOG(ERROR)<<e.what();
+ }
+ for(auto &vec:punc_cache){
+ vec.clear();
+ }
+ if (Result) {
+ websocketpp::lib::error_code ec;
+ nlohmann::json jsonresult =
+ handle_result(Result, online_res, tpass_res, msg["wav_name"]);
+ jsonresult["is_final"] = true;
+ if (is_ssl) {
+ wss_server_->send(hdl, jsonresult.dump(),
+ websocketpp::frame::opcode::text, ec);
+ } else {
+ server_->send(hdl, jsonresult.dump(),
+ websocketpp::frame::opcode::text, ec);
+ }
+ FunASRFreeResult(Result);
+ }
+ }
+
+ } catch (std::exception const& e) {
+ std::cerr << "Error: " << e.what() << std::endl;
+ }
+}
+
+// New-connection handler: allocates the per-connection state (sample buffer,
+// per-connection mutex, default metadata, punc caches) and registers it in
+// data_map keyed by the connection handle. The online engine handle itself is
+// created lazily in on_message once the client sends its chunk_size.
+// NOTE(review): thread_lock is allocated with raw `new` and only the map
+// entry is erased on close — the mutex object appears to be leaked; consider
+// std::unique_ptr/std::shared_ptr.
+void WebSocketServer::on_open(websocketpp::connection_hdl hdl) {
+  scoped_lock guard(m_lock);     // for threads safty
+  check_and_clean_connection();  // remove closed connection
+
+  std::shared_ptr<FUNASR_MESSAGE> data_msg =
+      std::make_shared<FUNASR_MESSAGE>();  // put a new data vector for new
+                                           // connection
+  data_msg->samples = std::make_shared<std::vector<char>>();
+  data_msg->thread_lock = new websocketpp::lib::mutex();
+
+  // Defaults until the client's first text message overrides them.
+  data_msg->msg = nlohmann::json::parse("{}");
+  data_msg->msg["wav_format"] = "pcm";
+  data_msg->msg["audio_fs"] = 16000;
+  data_msg->punc_cache =
+      std::make_shared<std::vector<std::vector<std::string>>>(2);
+  // std::vector<int> chunk_size = {5, 10, 5}; //TODO, need get from client
+  // FUNASR_HANDLE tpass_online_handle =
+  //     FunTpassOnlineInit(tpass_handle, chunk_size);
+  // data_msg->tpass_online_handle = tpass_online_handle;
+  data_map.emplace(hdl, data_msg);
+  LOG(INFO) << "on_open, active connections: " << data_map.size();
+
+}
+
+// Connection-closed handler: takes the per-connection decode lock so any
+// in-flight do_decoder finishes first, frees the online engine handle, and
+// drops the per-connection state from data_map.
+void WebSocketServer::on_close(websocketpp::connection_hdl hdl) {
+  scoped_lock guard(m_lock);
+  std::shared_ptr<FUNASR_MESSAGE> data_msg = nullptr;
+  auto it_data = data_map.find(hdl);
+  if (it_data != data_map.end()) {
+    data_msg = it_data->second;
+  }
+  else
+  {
+    // Unknown handle (already cleaned up) — nothing to release.
+    return;
+  }
+  scoped_lock guard_decoder(*(data_msg->thread_lock)); //wait for do_decoder finished and avoid access freed tpass_online_handle
+  FunTpassOnlineUninit(data_msg->tpass_online_handle);
+  data_map.erase(hdl);  // remove data vector when connection is closed
+  LOG(INFO) << "on_close, active connections: "<< data_map.size();
+}
+
+// remove closed connection
+void WebSocketServer::check_and_clean_connection() {
+ std::vector<websocketpp::connection_hdl> to_remove; // remove list
+ auto iter = data_map.begin();
+ while (iter != data_map.end()) { // loop to find closed connection
+ websocketpp::connection_hdl hdl = iter->first;
+
+ if (is_ssl) {
+ wss_server::connection_ptr con = wss_server_->get_con_from_hdl(hdl);
+ if (con->get_state() != 1) { // session::state::open ==1
+ to_remove.push_back(hdl);
+ }
+ } else {
+ server::connection_ptr con = server_->get_con_from_hdl(hdl);
+ if (con->get_state() != 1) { // session::state::open ==1
+ to_remove.push_back(hdl);
+ }
+ }
+
+ iter++;
+ }
+ for (auto hdl : to_remove) {
+ data_map.erase(hdl);
+ LOG(INFO) << "remove one connection ";
+ }
+}
+// Incoming-frame handler.
+// Text frames carry JSON control messages (wav_name / mode / wav_format /
+// audio_fs / chunk_size / is_speaking|is_finished); the first chunk_size
+// message lazily creates the online engine handle, and an end-of-stream flag
+// posts the remaining samples to the decoder pool with is_final=true.
+// Binary frames carry PCM bytes, which are buffered and posted to the decoder
+// pool in multiples of 800*2 bytes.
+void WebSocketServer::on_message(websocketpp::connection_hdl hdl,
+                                 message_ptr msg) {
+  unique_lock lock(m_lock);
+  // find the sample data vector according to one connection
+
+  std::shared_ptr<FUNASR_MESSAGE> msg_data = nullptr;
+
+  auto it_data = data_map.find(hdl);
+  if (it_data != data_map.end()) {
+    msg_data = it_data->second;
+  } else {
+    // BUGFIX: unknown or already-cleaned connection — bail out instead of
+    // dereferencing the null msg_data below.
+    LOG(INFO) << "on_message: connection not found in data_map";
+    return;
+  }
+
+  std::shared_ptr<std::vector<char>> sample_data_p = msg_data->samples;
+  std::shared_ptr<std::vector<std::vector<std::string>>> punc_cache_p =
+      msg_data->punc_cache;
+  websocketpp::lib::mutex* thread_lock_p = msg_data->thread_lock;
+
+  lock.unlock();
+
+  if (sample_data_p == nullptr) {
+    LOG(INFO) << "error when fetch sample data vector";
+    return;
+  }
+
+  const std::string& payload = msg->get_payload();  // get msg type
+
+  switch (msg->get_opcode()) {
+    case websocketpp::frame::opcode::text: {
+      nlohmann::json jsonresult = nlohmann::json::parse(payload);
+
+      if (jsonresult.contains("wav_name")) {
+        msg_data->msg["wav_name"] = jsonresult["wav_name"];
+      }
+      if (jsonresult.contains("mode")) {
+        msg_data->msg["mode"] = jsonresult["mode"];
+      }
+      if (jsonresult.contains("wav_format")) {
+        msg_data->msg["wav_format"] = jsonresult["wav_format"];
+      }
+      if (jsonresult.contains("audio_fs")) {
+        msg_data->msg["audio_fs"] = jsonresult["audio_fs"];
+      }
+      if (jsonresult.contains("chunk_size")){
+        // First chunk_size message creates the per-connection online handle.
+        if(msg_data->tpass_online_handle == NULL){
+          std::vector<int> chunk_size_vec = jsonresult["chunk_size"].get<std::vector<int>>();
+          FUNASR_HANDLE tpass_online_handle =
+              FunTpassOnlineInit(tpass_handle, chunk_size_vec);
+          msg_data->tpass_online_handle = tpass_online_handle;
+        }
+      }
+      LOG(INFO) << "jsonresult=" << jsonresult << ", msg_data->msg="
+                << msg_data->msg;
+      if (jsonresult["is_speaking"] == false ||
+          jsonresult["is_finished"] == true) {
+        LOG(INFO) << "client done";
+
+        // if it is in final message, post the sample_data to decode
+        asio::post(
+            io_decoder_,
+            std::bind(&WebSocketServer::do_decoder, this,
+                      std::move(*(sample_data_p.get())), std::move(hdl),
+                      std::ref(msg_data->msg), std::ref(*(punc_cache_p.get())),
+                      std::ref(*thread_lock_p), std::move(true),
+                      std::ref(msg_data->tpass_online_handle),
+                      std::ref(msg_data->online_res),
+                      std::ref(msg_data->tpass_res)));
+      }
+      break;
+    }
+    case websocketpp::frame::opcode::binary: {
+      // recived binary data
+      const auto* pcm_data = static_cast<const char*>(payload.data());
+      int32_t num_samples = payload.size();
+
+      if (isonline) {
+        sample_data_p->insert(sample_data_p->end(), pcm_data,
+                              pcm_data + num_samples);
+        int setpsize = 800 * 2;  // TODO, need get from client
+        // if sample_data size > setpsize, we post data to decode
+        if (sample_data_p->size() > setpsize) {
+          int chunksize = floor(sample_data_p->size() / setpsize);
+          // make sure the subvector size is an integer multiple of setpsize
+          std::vector<char> subvector = {
+              sample_data_p->begin(),
+              sample_data_p->begin() + chunksize * setpsize};
+          // keep remain in sample_data
+          sample_data_p->erase(sample_data_p->begin(),
+                               sample_data_p->begin() + chunksize * setpsize);
+          // post to decode
+          asio::post(io_decoder_,
+                     std::bind(&WebSocketServer::do_decoder, this,
+                               std::move(subvector), std::move(hdl),
+                               std::ref(msg_data->msg),
+                               std::ref(*(punc_cache_p.get())),
+                               std::ref(*thread_lock_p), std::move(false),
+                               std::ref(msg_data->tpass_online_handle),
+                               std::ref(msg_data->online_res),
+                               std::ref(msg_data->tpass_res)));
+        }
+      } else {
+        // Offline path: just accumulate until the final control message.
+        sample_data_p->insert(sample_data_p->end(), pcm_data,
+                              pcm_data + num_samples);
+      }
+      break;
+    }
+    default:
+      break;
+  }
+}
+
+// init asr model
+// init asr model
+// Creates the shared 2pass engine handle from the resolved model paths; the
+// per-connection online handles are derived from it later (FunTpassOnlineInit).
+// Exits the process if the engine returns a null handle.
+// NOTE(review): if FunTpassInit throws, the exception is only logged and
+// tpass_handle stays null while the server keeps running — consider exiting
+// here too, as in the null-handle branch.
+void WebSocketServer::initAsr(std::map<std::string, std::string>& model_path,
+                              int thread_num) {
+  try {
+    tpass_handle = FunTpassInit(model_path, thread_num);
+    if (!tpass_handle) {
+      LOG(ERROR) << "FunTpassInit init failed";
+      exit(-1);
+    }
+
+  } catch (const std::exception& e) {
+    LOG(INFO) << e.what();
+  }
+}
diff --git a/funasr/runtime/websocket/websocket-server-2pass.h b/funasr/runtime/websocket/websocket-server-2pass.h
new file mode 100644
index 0000000..21b98ca
--- /dev/null
+++ b/funasr/runtime/websocket/websocket-server-2pass.h
@@ -0,0 +1,148 @@
+/**
+ * Copyright FunASR (https://github.com/alibaba-damo-academy/FunASR). All Rights
+ * Reserved. MIT License (https://opensource.org/licenses/MIT)
+ */
+/* 2022-2023 by zhaomingwork */
+
+// websocket server for asr engine
+// take some ideas from https://github.com/k2-fsa/sherpa-onnx
+// online-websocket-server-impl.cc, thanks. The websocket server has two thread
+// pools: one handles network data and one runs the asr decoder.
+// This 2-pass variant serves both online (streaming) and offline decoding.
+#ifndef WEBSOCKET_SERVER_H_
+#define WEBSOCKET_SERVER_H_
+
+#include <iostream>
+#include <map>
+#include <memory>
+#include <string>
+#include <thread>
+#include <utility>
+#define ASIO_STANDALONE 1 // not boost
+#include <glog/logging.h>
+
+#include <fstream>
+#include <functional>
+#include <websocketpp/common/thread.hpp>
+#include <websocketpp/config/asio.hpp>
+#include <websocketpp/server.hpp>
+
+#include "asio.hpp"
+#include "com-define.h"
+#include "funasrruntime.h"
+#include "nlohmann/json.hpp"
+#include "tclap/CmdLine.h"
+// Plain (ws://) and TLS (wss://) websocketpp server types.
+typedef websocketpp::server<websocketpp::config::asio> server;
+typedef websocketpp::server<websocketpp::config::asio_tls> wss_server;
+typedef server::message_ptr message_ptr;
+using websocketpp::lib::bind;
+using websocketpp::lib::placeholders::_1;
+using websocketpp::lib::placeholders::_2;
+
+// Locking helpers used to protect per-connection and map state.
+typedef websocketpp::lib::lock_guard<websocketpp::lib::mutex> scoped_lock;
+typedef websocketpp::lib::unique_lock<websocketpp::lib::mutex> unique_lock;
+// Shared ownership of an asio SSL context, returned by on_tls_init.
+typedef websocketpp::lib::shared_ptr<websocketpp::lib::asio::ssl::context>
+    context_ptr;
+
+// Recognition result handed back by the decoder.
+typedef struct {
+  std::string msg;      // recognized text (presumably; set by the decoder)
+  float snippet_time;   // duration in seconds of the decoded audio snippet
+} FUNASR_RECOG_RESULT;
+
+// Per-connection state, kept in WebSocketServer::data_map for the lifetime
+// of a websocket connection.
+typedef struct {
+  nlohmann::json msg;  // latest JSON message/config from the client — TODO confirm against on_message
+  std::shared_ptr<std::vector<char>> samples;  // buffered raw PCM bytes
+  std::shared_ptr<std::vector<std::vector<std::string>>> punc_cache;  // punctuation model cache
+  websocketpp::lib::mutex* thread_lock; // lock for each connection
+  FUNASR_HANDLE tpass_online_handle=NULL;  // per-connection online decoder handle
+  std::string online_res = "";  // accumulated 1st-pass (online) result
+  std::string tpass_res = "";   // accumulated 2nd-pass (offline) result
+
+} FUNASR_MESSAGE;
+
+// See https://wiki.mozilla.org/Security/Server_Side_TLS for more details about
+// the TLS modes. The code below demonstrates how to implement both the modern
+// and the intermediate mode.
+enum tls_mode { MOZILLA_INTERMEDIATE = 1, MOZILLA_MODERN = 2 };
+// Two-pass ASR websocket server: websocketpp handles network events, while
+// decoding work is posted to the io_decoder_ thread pool.
+class WebSocketServer {
+ public:
+  // Registers message/open/close handlers on either the plain (ws) or the
+  // TLS (wss) server, depending on is_ssl, and starts accepting connections.
+  WebSocketServer(asio::io_context& io_decoder, bool is_ssl, server* server,
+                  wss_server* wss_server, std::string& s_certfile,
+                  std::string& s_keyfile)
+      : io_decoder_(io_decoder),
+        is_ssl(is_ssl),
+        server_(server),
+        wss_server_(wss_server) {
+    if (is_ssl) {
+      std::cout << "certfile path is " << s_certfile << std::endl;
+      // install the TLS context factory before anything else
+      wss_server->set_tls_init_handler(
+          bind<context_ptr>(&WebSocketServer::on_tls_init, this,
+                            MOZILLA_INTERMEDIATE, ::_1, s_certfile, s_keyfile));
+      // set message handler
+      wss_server_->set_message_handler(
+          [this](websocketpp::connection_hdl hdl, message_ptr msg) {
+            on_message(hdl, msg);
+          });
+      // set open handler
+      wss_server_->set_open_handler(
+          [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
+      // set close handler
+      wss_server_->set_close_handler(
+          [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
+      // begin accept
+      wss_server_->start_accept();
+      // suppress access logging
+      wss_server_->clear_access_channels(websocketpp::log::alevel::all);
+
+    } else {
+      // set message handler
+      server_->set_message_handler(
+          [this](websocketpp::connection_hdl hdl, message_ptr msg) {
+            on_message(hdl, msg);
+          });
+      // set open handler
+      server_->set_open_handler(
+          [this](websocketpp::connection_hdl hdl) { on_open(hdl); });
+      // set close handler
+      server_->set_close_handler(
+          [this](websocketpp::connection_hdl hdl) { on_close(hdl); });
+      // begin accept
+      server_->start_accept();
+      // suppress access logging
+      server_->clear_access_channels(websocketpp::log::alevel::all);
+    }
+  }
+  // Runs one decode step for a chunk of PCM; posted to io_decoder_ from
+  // on_message. is_final marks the last chunk of an utterance.
+  void do_decoder(std::vector<char>& buffer, websocketpp::connection_hdl& hdl,
+                  nlohmann::json& msg,
+                  std::vector<std::vector<std::string>>& punc_cache,
+                  websocketpp::lib::mutex& thread_lock, bool& is_final,
+                  FUNASR_HANDLE& tpass_online_handle, std::string& online_res,
+                  std::string& tpass_res);
+
+  // Loads the models via FunTpassInit; exits the process on failure.
+  void initAsr(std::map<std::string, std::string>& model_path, int thread_num);
+  void on_message(websocketpp::connection_hdl hdl, message_ptr msg);
+  void on_open(websocketpp::connection_hdl hdl);
+  void on_close(websocketpp::connection_hdl hdl);
+  // Builds the SSL context for a new TLS connection (see tls_mode).
+  context_ptr on_tls_init(tls_mode mode, websocketpp::connection_hdl hdl,
+                          std::string& s_certfile, std::string& s_keyfile);
+
+ private:
+  void check_and_clean_connection();
+  asio::io_context& io_decoder_;  // threads for asr decoder
+  // std::ofstream fout;
+  // FUNASR_HANDLE asr_handle; // asr engine handle
+  FUNASR_HANDLE tpass_handle=NULL;  // shared engine handle, set by initAsr
+  // NOTE(review): defaults to online although the original comment claimed
+  // only offline is supported — confirm where this flag is toggled.
+  bool isonline = true;  // online or offline engine
+  bool is_ssl = true;
+  server* server_;          // websocket server
+  wss_server* wss_server_;  // websocket server
+
+  // use map to keep the received samples data from one connection in offline
+  // engine. if for online engine, a data struct is needed(TODO)
+
+  std::map<websocketpp::connection_hdl, std::shared_ptr<FUNASR_MESSAGE>,
+           std::owner_less<websocketpp::connection_hdl>>
+      data_map;
+  websocketpp::lib::mutex m_lock;  // mutex for data_map
+};
+
+#endif // WEBSOCKET_SERVER_H_
--
Gitblit v1.9.1