From 9c0fa3c0a435478fe0b36810cfc5ca273d4593f7 Mon Sep 17 00:00:00 2001
From: nichongjia-2007 <nichongjia@gmail.com>
Date: Tue, 11 Jul 2023 17:36:22 +0800
Subject: [PATCH] Merge branch 'main' of https://github.com/alibaba-damo-academy/FunASR
---
funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp | 11
funasr/runtime/onnxruntime/src/audio.cpp | 363 ++++++++++++++++++++++++++++++
funasr/runtime/onnxruntime/src/CMakeLists.txt | 3
funasr/runtime/onnxruntime/src/funasrruntime.cpp | 46 ++-
egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo_long.py | 52 ++++
funasr/runtime/websocket/CMakeLists.txt | 1
funasr/runtime/websocket/funasr-wss-client.cpp | 74 ++++-
funasr/runtime/websocket/websocket-server.cpp | 14
funasr/runtime/onnxruntime/include/audio.h | 3
funasr/runtime/websocket/readme.md | 11
funasr/runtime/onnxruntime/include/funasrruntime.h | 6
funasr/runtime/onnxruntime/CMakeLists.txt | 1
funasr/bin/asr_inference_launch.py | 118 +++++----
13 files changed, 597 insertions(+), 106 deletions(-)
diff --git a/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo_long.py b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo_long.py
new file mode 100644
index 0000000..c04d985
--- /dev/null
+++ b/egs_modelscope/tp/speech_timestamp_prediction-v1-16k-offline/demo_long.py
@@ -0,0 +1,52 @@
+# if you want to use ASR model besides paraformer-bicif (like contextual paraformer)
+# to get ASR results for long audio as well as timestamp prediction results,
+# try this demo
+from modelscope.pipelines import pipeline
+from modelscope.utils.constant import Tasks
+import os
+import librosa
+import soundfile as sf
+
+param_dict = dict()
+param_dict['hotword'] = "淇¤"
+
+test_wav = '/Users/shixian/Downloads/tpdebug.wav'
+output_dir = './tmp'
+os.system("mkdir -p {}".format(output_dir))
+
+vad_pipeline = pipeline(
+ task=Tasks.voice_activity_detection,
+ model='damo/speech_fsmn_vad_zh-cn-16k-common-pytorch',
+ model_revision=None,
+)
+asr_pipeline = pipeline(
+ task=Tasks.auto_speech_recognition,
+ model="damo/speech_paraformer-large-contextual_asr_nat-zh-cn-16k-common-vocab8404",
+ output_dir=output_dir)
+tp_pipeline = pipeline(
+ task=Tasks.speech_timestamp,
+ model='damo/speech_timestamp_prediction-v1-16k-offline',
+ output_dir=output_dir)
+
+vad_res = vad_pipeline(audio_in=test_wav)
+timestamps = vad_res['text']
+
+samples = librosa.load(test_wav, sr=16000)[0]
+wavseg_scp = "{}/wav.scp".format(output_dir)
+
+with open(wavseg_scp, 'w') as fout:
+ for i, timestamp in enumerate(timestamps):
+ start = int(timestamp[0]/1000*16000)
+ end = int(timestamp[1]/1000*16000)
+ uttid = "wav_{}_{}".format(start, end)
+ wavpath = '{}/wavseg_{}.wav'.format(output_dir, i)
+ _samples = samples[start:end]
+ sf.write(wavpath, _samples, 16000)
+ fout.write("{} {}\n".format(uttid, wavpath))
+print("Wav segment done: {}".format(wavseg_scp))
+
+asr_res = '{}/1best_recog/text'.format(output_dir)
+tp_res = '{}/timestamp_prediction/tp_sync'.format(output_dir)
+rec_result_asr = asr_pipeline(audio_in=wavseg_scp)
+rec_result_tp = tp_pipeline(audio_in=wavseg_scp, text_in=asr_res)
+print("Find your ASR results in {}, and timestamp prediction results in {}.".format(asr_res, tp_res))
\ No newline at end of file
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index de18894..10f8e50 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -1272,27 +1272,27 @@
nbest: int,
num_workers: int,
log_level: Union[int, str],
- data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
+ # data_path_and_name_and_type: Sequence[Tuple[str, str, str]],
asr_train_config: Optional[str],
asr_model_file: Optional[str],
- cmvn_file: Optional[str],
- beam_search_config: Optional[dict],
- lm_train_config: Optional[str],
- lm_file: Optional[str],
- model_tag: Optional[str],
- token_type: Optional[str],
- bpemodel: Optional[str],
- key_file: Optional[str],
- allow_variable_data_keys: bool,
- quantize_asr_model: Optional[bool],
- quantize_modules: Optional[List[str]],
- quantize_dtype: Optional[str],
- streaming: Optional[bool],
- simu_streaming: Optional[bool],
- chunk_size: Optional[int],
- left_context: Optional[int],
- right_context: Optional[int],
- display_partial_hypotheses: bool,
+ cmvn_file: Optional[str] = None,
+ beam_search_config: Optional[dict] = None,
+ lm_train_config: Optional[str] = None,
+ lm_file: Optional[str] = None,
+ model_tag: Optional[str] = None,
+ token_type: Optional[str] = None,
+ bpemodel: Optional[str] = None,
+ key_file: Optional[str] = None,
+ allow_variable_data_keys: bool = False,
+ quantize_asr_model: Optional[bool] = False,
+ quantize_modules: Optional[List[str]] = None,
+ quantize_dtype: Optional[str] = "float16",
+ streaming: Optional[bool] = False,
+ simu_streaming: Optional[bool] = False,
+ chunk_size: Optional[int] = 16,
+ left_context: Optional[int] = 16,
+ right_context: Optional[int] = 0,
+ display_partial_hypotheses: bool = False,
**kwargs,
) -> None:
"""Transducer model inference.
@@ -1327,6 +1327,7 @@
right_context: Number of frames in right context AFTER subsampling.
display_partial_hypotheses: Whether to display partial hypotheses.
"""
+ # assert check_argument_types()
if batch_size > 1:
raise NotImplementedError("batch decoding is not implemented")
@@ -1369,7 +1370,10 @@
left_context=left_context,
right_context=right_context,
)
- speech2text = Speech2TextTransducer(**speech2text_kwargs)
+ speech2text = Speech2TextTransducer.from_pretrained(
+ model_tag=model_tag,
+ **speech2text_kwargs,
+ )
def _forward(data_path_and_name_and_type,
raw_inputs: Union[np.ndarray, torch.Tensor] = None,
@@ -1388,47 +1392,55 @@
key_file=key_file,
num_workers=num_workers,
)
+ asr_result_list = []
+
+ if output_dir is not None:
+ writer = DatadirWriter(output_dir)
+ else:
+ writer = None
# 4 .Start for-loop
- with DatadirWriter(output_dir) as writer:
- for keys, batch in loader:
- assert isinstance(batch, dict), type(batch)
- assert all(isinstance(s, str) for s in keys), keys
+ for keys, batch in loader:
+ assert isinstance(batch, dict), type(batch)
+ assert all(isinstance(s, str) for s in keys), keys
- _bs = len(next(iter(batch.values())))
- assert len(keys) == _bs, f"{len(keys)} != {_bs}"
- batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
- assert len(batch.keys()) == 1
+ _bs = len(next(iter(batch.values())))
+ assert len(keys) == _bs, f"{len(keys)} != {_bs}"
+ batch = {k: v[0] for k, v in batch.items() if not k.endswith("_lengths")}
+ assert len(batch.keys()) == 1
- try:
- if speech2text.streaming:
- speech = batch["speech"]
+ try:
+ if speech2text.streaming:
+ speech = batch["speech"]
- _steps = len(speech) // speech2text._ctx
- _end = 0
- for i in range(_steps):
- _end = (i + 1) * speech2text._ctx
+ _steps = len(speech) // speech2text._ctx
+ _end = 0
+ for i in range(_steps):
+ _end = (i + 1) * speech2text._ctx
- speech2text.streaming_decode(
- speech[i * speech2text._ctx: _end], is_final=False
- )
-
- final_hyps = speech2text.streaming_decode(
- speech[_end: len(speech)], is_final=True
+ speech2text.streaming_decode(
+ speech[i * speech2text._ctx: _end], is_final=False
)
- elif speech2text.simu_streaming:
- final_hyps = speech2text.simu_streaming_decode(**batch)
- else:
- final_hyps = speech2text(**batch)
- results = speech2text.hypotheses_to_results(final_hyps)
- except TooShortUttError as e:
- logging.warning(f"Utterance {keys} {e}")
- hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
- results = [[" ", ["<space>"], [2], hyp]] * nbest
+ final_hyps = speech2text.streaming_decode(
+ speech[_end: len(speech)], is_final=True
+ )
+ elif speech2text.simu_streaming:
+ final_hyps = speech2text.simu_streaming_decode(**batch)
+ else:
+ final_hyps = speech2text(**batch)
- key = keys[0]
- for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ results = speech2text.hypotheses_to_results(final_hyps)
+ except TooShortUttError as e:
+ logging.warning(f"Utterance {keys} {e}")
+ hyp = Hypothesis(score=0.0, yseq=[], dec_state=None)
+ results = [[" ", ["<space>"], [2], hyp]] * nbest
+
+ key = keys[0]
+ for n, (text, token, token_int, hyp) in zip(range(1, nbest + 1), results):
+ item = {'key': key, 'value': text}
+ asr_result_list.append(item)
+ if writer is not None:
ibest_writer = writer[f"{n}best_recog"]
ibest_writer["token"][key] = " ".join(token)
@@ -1438,6 +1450,8 @@
if text is not None:
ibest_writer["text"][key] = text
+ logging.info("decoding, utt: {}, predictions: {}".format(key, text))
+ return asr_result_list
return _forward
diff --git a/funasr/runtime/onnxruntime/CMakeLists.txt b/funasr/runtime/onnxruntime/CMakeLists.txt
index 0847d1f..637d765 100644
--- a/funasr/runtime/onnxruntime/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/CMakeLists.txt
@@ -27,6 +27,7 @@
endif()
ELSE()
link_directories(${ONNXRUNTIME_DIR}/lib)
+ link_directories(${FFMPEG_DIR}/lib)
endif()
include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi-native-fbank)
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
index 82668f8..caa8605 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -91,11 +91,8 @@
vector<string> wav_ids;
string default_id = "wav_default_id";
string wav_path_ = model_path.at(WAV_PATH);
- if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
- wav_list.emplace_back(wav_path_);
- wav_ids.emplace_back(default_id);
- }
- else if(is_target_file(wav_path_, "scp")){
+
+ if(is_target_file(wav_path_, "scp")){
ifstream in(wav_path_);
if (!in.is_open()) {
LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
@@ -112,8 +109,8 @@
}
in.close();
}else{
- LOG(ERROR)<<"Please check the wav extension!";
- exit(-1);
+ wav_list.emplace_back(wav_path_);
+ wav_ids.emplace_back(default_id);
}
float snippet_time = 0.0f;
diff --git a/funasr/runtime/onnxruntime/include/audio.h b/funasr/runtime/onnxruntime/include/audio.h
index d2100a4..a1b6312 100644
--- a/funasr/runtime/onnxruntime/include/audio.h
+++ b/funasr/runtime/onnxruntime/include/audio.h
@@ -55,6 +55,9 @@
bool LoadPcmwav(const char* buf, int n_file_len, int32_t* sampling_rate);
bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
bool LoadPcmwav2Char(const char* filename, int32_t* sampling_rate);
+ bool LoadOthers2Char(const char* filename);
+ bool FfmpegLoad(const char *filename);
+ bool FfmpegLoad(const char* buf, int n_file_len);
int FetchChunck(float *&dout, int len);
int Fetch(float *&dout, int &len, int &flag);
void Padding();
diff --git a/funasr/runtime/onnxruntime/include/funasrruntime.h b/funasr/runtime/onnxruntime/include/funasrruntime.h
index 98727bd..ddb65b9 100644
--- a/funasr/runtime/onnxruntime/include/funasrruntime.h
+++ b/funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -56,7 +56,7 @@
// ASR
_FUNASRAPI FUNASR_HANDLE FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
// buffer
-_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
+_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000, std::string wav_format="pcm");
// file, support wav & pcm
_FUNASRAPI FUNASR_RESULT FunASRInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
@@ -70,7 +70,7 @@
_FUNASRAPI FUNASR_HANDLE FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num);
_FUNASRAPI FUNASR_HANDLE FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle);
// buffer
-_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000);
+_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm");
// file, support wav & pcm
_FUNASRAPI FUNASR_RESULT FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate=16000);
@@ -89,7 +89,7 @@
//OfflineStream
_FUNASRAPI FUNASR_HANDLE FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
// buffer
-_FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
+_FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000, std::string wav_format="pcm");
// file, support wav & pcm
_FUNASRAPI FUNASR_RESULT FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
_FUNASRAPI void FunOfflineUninit(FUNASR_HANDLE handle);
diff --git a/funasr/runtime/onnxruntime/src/CMakeLists.txt b/funasr/runtime/onnxruntime/src/CMakeLists.txt
index d083d8e..c781ef0 100644
--- a/funasr/runtime/onnxruntime/src/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/src/CMakeLists.txt
@@ -15,8 +15,9 @@
target_compile_definitions(funasr PUBLIC -D_FUNASR_API_EXPORT)
else()
- set(EXTRA_LIBS pthread yaml-cpp csrc glog )
+ set(EXTRA_LIBS pthread yaml-cpp csrc glog avutil avcodec avformat swresample)
include_directories(${ONNXRUNTIME_DIR}/include)
+ include_directories(${FFMPEG_DIR}/include)
endif()
include_directories(${CMAKE_SOURCE_DIR}/include)
diff --git a/funasr/runtime/onnxruntime/src/audio.cpp b/funasr/runtime/onnxruntime/src/audio.cpp
index 23d0010..85633b7 100644
--- a/funasr/runtime/onnxruntime/src/audio.cpp
+++ b/funasr/runtime/onnxruntime/src/audio.cpp
@@ -9,6 +9,15 @@
#include "audio.h"
#include "precomp.h"
+extern "C" {
+#include <libavutil/opt.h>
+#include <libavcodec/avcodec.h>
+#include <libavformat/avformat.h>
+#include <libavutil/channel_layout.h>
+#include <libavutil/samplefmt.h>
+#include <libswresample/swresample.h>
+}
+
using namespace std;
namespace funasr {
@@ -220,6 +229,334 @@
memset(speech_data, 0, sizeof(float) * speech_len);
copy(samples.begin(), samples.end(), speech_data);
}
+
+bool Audio::FfmpegLoad(const char *filename){
+ // from file
+ AVFormatContext* formatContext = avformat_alloc_context();
+ if (avformat_open_input(&formatContext, filename, NULL, NULL) != 0) {
+ printf("Error: Could not open input file.");
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ return false;
+ }
+
+ if (avformat_find_stream_info(formatContext, NULL) < 0) {
+ printf("Error: Could not find stream information.");
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ return false;
+ }
+ const AVCodec* codec = NULL;
+ AVCodecParameters* codecParameters = NULL;
+ int audioStreamIndex = av_find_best_stream(formatContext, AVMEDIA_TYPE_AUDIO, -1, -1, &codec, 0);
+ if (audioStreamIndex >= 0) {
+ codecParameters = formatContext->streams[audioStreamIndex]->codecpar;
+ }
+ AVCodecContext* codecContext = avcodec_alloc_context3(codec);
+ if (!codecContext) {
+ fprintf(stderr, "Failed to allocate codec context\n");
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ return false;
+ }
+ if (avcodec_parameters_to_context(codecContext, codecParameters) != 0) {
+ printf("Error: Could not copy codec parameters to codec context.");
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ avcodec_free_context(&codecContext);
+ return false;
+ }
+ if (avcodec_open2(codecContext, codec, NULL) < 0) {
+ printf("Error: Could not open audio decoder.");
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ avcodec_free_context(&codecContext);
+ return false;
+ }
+ SwrContext *swr_ctx = swr_alloc_set_opts(
+ nullptr, // allocate a new context
+ AV_CH_LAYOUT_MONO, // output channel layout (mono)
+ AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
+ 16000, // output sample rate (resample to 16 kHz)
+ av_get_default_channel_layout(codecContext->channels), // input channel layout
+ codecContext->sample_fmt, // input sample format
+ codecContext->sample_rate, // input sample rate
+ 0, // log offset
+ nullptr // log context
+ );
+ if (swr_ctx == nullptr) {
+ std::cerr << "Could not initialize resampler" << std::endl;
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ avcodec_free_context(&codecContext);
+ return false;
+ }
+ if (swr_init(swr_ctx) != 0) {
+ std::cerr << "Could not initialize resampler" << std::endl;
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ avcodec_free_context(&codecContext);
+ swr_free(&swr_ctx);
+ return false;
+ }
+
+ // to pcm
+ AVPacket* packet = av_packet_alloc();
+ AVFrame* frame = av_frame_alloc();
+ std::vector<uint8_t> resampled_buffers;
+ while (av_read_frame(formatContext, packet) >= 0) {
+ if (packet->stream_index == audioStreamIndex) {
+ if (avcodec_send_packet(codecContext, packet) >= 0) {
+ while (avcodec_receive_frame(codecContext, frame) >= 0) {
+ // Resample audio if necessary
+ std::vector<uint8_t> resampled_buffer;
+ int in_samples = frame->nb_samples;
+ uint8_t **in_data = frame->extended_data;
+ int out_samples = av_rescale_rnd(in_samples,
+ 16000,
+ codecContext->sample_rate,
+ AV_ROUND_DOWN);
+
+ int resampled_size = out_samples * av_get_bytes_per_sample(AV_SAMPLE_FMT_S16);
+ if (resampled_buffer.size() < resampled_size) {
+ resampled_buffer.resize(resampled_size);
+ }
+ uint8_t *resampled_data = resampled_buffer.data();
+ int ret = swr_convert(
+ swr_ctx,
+ &resampled_data, // output buffer
+ out_samples, // output capacity in samples per channel (not bytes)
+ (const uint8_t **)(frame->data), //(const uint8_t **)(frame->extended_data)
+ in_samples // input buffer size
+ );
+ if (ret < 0) {
+ std::cerr << "Error resampling audio" << std::endl;
+ break;
+ }
+ std::copy(resampled_buffer.begin(), resampled_buffer.end(), std::back_inserter(resampled_buffers));
+ }
+ }
+ }
+ av_packet_unref(packet);
+ }
+
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ avcodec_free_context(&codecContext);
+ swr_free(&swr_ctx);
+ av_packet_free(&packet);
+ av_frame_free(&frame);
+
+ if (speech_data != NULL) {
+ free(speech_data);
+ }
+ if (speech_buff != NULL) {
+ free(speech_buff);
+ }
+ offset = 0;
+
+ speech_len = (resampled_buffers.size()) / 2;
+ speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
+ if (speech_buff)
+ {
+ memset(speech_buff, 0, sizeof(int16_t) * speech_len);
+ memcpy((void*)speech_buff, (const void*)resampled_buffers.data(), speech_len * sizeof(int16_t));
+
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
+
+ float scale = 1;
+ if (data_type == 1) {
+ scale = 32768;
+ }
+ for (int32_t i = 0; i != speech_len; ++i) {
+ speech_data[i] = (float)speech_buff[i] / scale;
+ }
+
+ AudioFrame* frame = new AudioFrame(speech_len);
+ frame_queue.push(frame);
+
+ return true;
+ }
+ else
+ return false;
+
+}
+
+bool Audio::FfmpegLoad(const char* buf, int n_file_len){
+ // from buf
+ char* buf_copy = (char *)malloc(n_file_len);
+ memcpy(buf_copy, buf, n_file_len);
+
+ AVIOContext* avio_ctx = avio_alloc_context(
+ (unsigned char*)buf_copy, // buffer
+ n_file_len, // buffer size
+ 0, // write flag (0 for read-only)
+ nullptr, // opaque pointer (not used here)
+ nullptr, // read callback (not used here)
+ nullptr, // write callback (not used here)
+ nullptr // seek callback (not used here)
+ );
+ AVFormatContext* formatContext = avformat_alloc_context();
+ formatContext->pb = avio_ctx;
+ if (avformat_open_input(&formatContext, "", NULL, NULL) != 0) {
+ printf("Error: Could not open input file.");
+ avio_context_free(&avio_ctx);
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ return false;
+ }
+
+ if (avformat_find_stream_info(formatContext, NULL) < 0) {
+ printf("Error: Could not find stream information.");
+ avio_context_free(&avio_ctx);
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ return false;
+ }
+ const AVCodec* codec = NULL;
+ AVCodecParameters* codecParameters = NULL;
+ int audioStreamIndex = av_find_best_stream(formatContext, AVMEDIA_TYPE_AUDIO, -1, -1, &codec, 0);
+ if (audioStreamIndex >= 0) {
+ codecParameters = formatContext->streams[audioStreamIndex]->codecpar;
+ }
+ AVCodecContext* codecContext = avcodec_alloc_context3(codec);
+ if (!codecContext) {
+ fprintf(stderr, "Failed to allocate codec context\n");
+ avio_context_free(&avio_ctx);
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ return false;
+ }
+ if (avcodec_parameters_to_context(codecContext, codecParameters) != 0) {
+ printf("Error: Could not copy codec parameters to codec context.");
+ avio_context_free(&avio_ctx);
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ avcodec_free_context(&codecContext);
+ return false;
+ }
+ if (avcodec_open2(codecContext, codec, NULL) < 0) {
+ printf("Error: Could not open audio decoder.");
+ avio_context_free(&avio_ctx);
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ avcodec_free_context(&codecContext);
+ return false;
+ }
+ SwrContext *swr_ctx = swr_alloc_set_opts(
+ nullptr, // allocate a new context
+ AV_CH_LAYOUT_MONO, // output channel layout (mono)
+ AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
+ 16000, // output sample rate (resample to 16 kHz)
+ av_get_default_channel_layout(codecContext->channels), // input channel layout
+ codecContext->sample_fmt, // input sample format
+ codecContext->sample_rate, // input sample rate
+ 0, // log offset
+ nullptr // log context
+ );
+ if (swr_ctx == nullptr) {
+ std::cerr << "Could not initialize resampler" << std::endl;
+ avio_context_free(&avio_ctx);
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ avcodec_free_context(&codecContext);
+ return false;
+ }
+ if (swr_init(swr_ctx) != 0) {
+ std::cerr << "Could not initialize resampler" << std::endl;
+ avio_context_free(&avio_ctx);
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ avcodec_free_context(&codecContext);
+ swr_free(&swr_ctx);
+ return false;
+ }
+
+ // to pcm
+ AVPacket* packet = av_packet_alloc();
+ AVFrame* frame = av_frame_alloc();
+ std::vector<uint8_t> resampled_buffers;
+ while (av_read_frame(formatContext, packet) >= 0) {
+ if (packet->stream_index == audioStreamIndex) {
+ if (avcodec_send_packet(codecContext, packet) >= 0) {
+ while (avcodec_receive_frame(codecContext, frame) >= 0) {
+ // Resample audio if necessary
+ std::vector<uint8_t> resampled_buffer;
+ int in_samples = frame->nb_samples;
+ uint8_t **in_data = frame->extended_data;
+ int out_samples = av_rescale_rnd(in_samples,
+ 16000,
+ codecContext->sample_rate,
+ AV_ROUND_DOWN);
+
+ int resampled_size = out_samples * av_get_bytes_per_sample(AV_SAMPLE_FMT_S16);
+ if (resampled_buffer.size() < resampled_size) {
+ resampled_buffer.resize(resampled_size);
+ }
+ uint8_t *resampled_data = resampled_buffer.data();
+ int ret = swr_convert(
+ swr_ctx,
+ &resampled_data, // output buffer
+ out_samples, // output capacity in samples per channel (not bytes)
+ (const uint8_t **)(frame->data), //(const uint8_t **)(frame->extended_data)
+ in_samples // input buffer size
+ );
+ if (ret < 0) {
+ std::cerr << "Error resampling audio" << std::endl;
+ break;
+ }
+ std::copy(resampled_buffer.begin(), resampled_buffer.end(), std::back_inserter(resampled_buffers));
+ }
+ }
+ }
+ av_packet_unref(packet);
+ }
+
+ avio_context_free(&avio_ctx);
+ avformat_close_input(&formatContext);
+ avformat_free_context(formatContext);
+ avcodec_free_context(&codecContext);
+ swr_free(&swr_ctx);
+ av_packet_free(&packet);
+ av_frame_free(&frame);
+
+ if (speech_data != NULL) {
+ free(speech_data);
+ }
+ if (speech_buff != NULL) {
+ free(speech_buff);
+ }
+ offset = 0;
+
+ speech_len = (resampled_buffers.size()) / 2;
+ speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
+ if (speech_buff)
+ {
+ memset(speech_buff, 0, sizeof(int16_t) * speech_len);
+ memcpy((void*)speech_buff, (const void*)resampled_buffers.data(), speech_len * sizeof(int16_t));
+
+ speech_data = (float*)malloc(sizeof(float) * speech_len);
+ memset(speech_data, 0, sizeof(float) * speech_len);
+
+ float scale = 1;
+ if (data_type == 1) {
+ scale = 32768;
+ }
+ for (int32_t i = 0; i != speech_len; ++i) {
+ speech_data[i] = (float)speech_buff[i] / scale;
+ }
+
+ AudioFrame* frame = new AudioFrame(speech_len);
+ frame_queue.push(frame);
+
+ return true;
+ }
+ else
+ return false;
+
+}
+
bool Audio::LoadWav(const char *filename, int32_t* sampling_rate)
{
@@ -507,6 +844,32 @@
return true;
}
+bool Audio::LoadOthers2Char(const char* filename)
+{
+ if (speech_char != NULL) {
+ free(speech_char);
+ }
+
+ FILE* fp;
+ fp = fopen(filename, "rb");
+ if (fp == nullptr)
+ {
+ LOG(ERROR) << "Failed to read " << filename;
+ return false;
+ }
+ fseek(fp, 0, SEEK_END);
+ uint32_t n_file_len = ftell(fp);
+ fseek(fp, 0, SEEK_SET);
+
+ speech_len = n_file_len;
+ speech_char = (char *)malloc(n_file_len);
+ memset(speech_char, 0, n_file_len);
+ fread(speech_char, 1, n_file_len, fp);
+ fclose(fp);
+
+ return true;
+}
+
int Audio::FetchChunck(float *&dout, int len)
{
if (offset >= speech_align_len) {
diff --git a/funasr/runtime/onnxruntime/src/funasrruntime.cpp b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
index 82fdd70..a1829fd 100644
--- a/funasr/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -36,15 +36,20 @@
}
// APIs for ASR Infer
- _FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
+ _FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate, std::string wav_format)
{
funasr::Model* recog_obj = (funasr::Model*)handle;
if (!recog_obj)
return nullptr;
funasr::Audio audio(1);
- if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
- return nullptr;
+ if(wav_format == "pcm" || wav_format == "PCM"){
+ if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
+ return nullptr;
+ }else{
+ if (!audio.FfmpegLoad(sz_buf, n_len))
+ return nullptr;
+ }
float* buff;
int len;
@@ -82,8 +87,8 @@
if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
return nullptr;
}else{
- LOG(ERROR)<<"Wrong wav extension";
- exit(-1);
+ if (!audio.FfmpegLoad(sz_filename))
+ return nullptr;
}
float* buff;
@@ -108,15 +113,20 @@
}
// APIs for VAD Infer
- _FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate)
+ _FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate, std::string wav_format)
{
funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
if (!vad_obj)
return nullptr;
funasr::Audio audio(1);
- if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
- return nullptr;
+ if(wav_format == "pcm" || wav_format == "PCM"){
+ if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
+ return nullptr;
+ }else{
+ if (!audio.FfmpegLoad(sz_buf, n_len))
+ return nullptr;
+ }
funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
p_result->snippet_time = audio.GetTimeLen();
@@ -146,8 +156,8 @@
if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
return nullptr;
}else{
- LOG(ERROR)<<"Wrong wav extension";
- exit(-1);
+ if (!audio.FfmpegLoad(sz_filename))
+ return nullptr;
}
funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
@@ -189,15 +199,21 @@
}
// APIs for Offline-stream Infer
- _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
+ _FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate, std::string wav_format)
{
funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
if (!offline_stream)
return nullptr;
funasr::Audio audio(1);
- if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
- return nullptr;
+ if(wav_format == "pcm" || wav_format == "PCM"){
+ if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
+ return nullptr;
+ }else{
+ if (!audio.FfmpegLoad(sz_buf, n_len))
+ return nullptr;
+ }
+
funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
p_result->snippet_time = audio.GetTimeLen();
if(p_result->snippet_time == 0){
@@ -243,8 +259,8 @@
if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
return nullptr;
}else{
- LOG(ERROR)<<"Wrong wav extension";
- exit(-1);
+ if (!audio.FfmpegLoad(sz_filename))
+ return nullptr;
}
funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
p_result->snippet_time = audio.GetTimeLen();
diff --git a/funasr/runtime/websocket/CMakeLists.txt b/funasr/runtime/websocket/CMakeLists.txt
index 51b0795..06ae59b 100644
--- a/funasr/runtime/websocket/CMakeLists.txt
+++ b/funasr/runtime/websocket/CMakeLists.txt
@@ -40,6 +40,7 @@
# Include generated *.pb.h files
link_directories(${ONNXRUNTIME_DIR}/lib)
+link_directories(${FFMPEG_DIR}/lib)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
diff --git a/funasr/runtime/websocket/funasr-wss-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp
index b6d69f2..231303f 100644
--- a/funasr/runtime/websocket/funasr-wss-client.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client.cpp
@@ -187,6 +187,7 @@
funasr::Audio audio(1);
int32_t sampling_rate = 16000;
+ std::string wav_format = "pcm";
if(IsTargetFile(wav_path.c_str(), "wav")){
int32_t sampling_rate = -1;
if(!audio.LoadWav(wav_path.c_str(), &sampling_rate))
@@ -195,8 +196,9 @@
if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate))
return ;
}else{
- printf("Wrong wav extension");
- exit(-1);
+ wav_format = "others";
+ if (!audio.LoadOthers2Char(wav_path.c_str()))
+ return ;
}
float* buff;
@@ -233,20 +235,54 @@
jsonbegin["chunk_size"] = chunk_size;
jsonbegin["chunk_interval"] = 10;
jsonbegin["wav_name"] = wav_id;
+ jsonbegin["wav_format"] = wav_format;
jsonbegin["is_speaking"] = true;
m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
ec);
// fetch wav data use asr engine api
- while (audio.Fetch(buff, len, flag) > 0) {
- short* iArray = new short[len];
- for (size_t i = 0; i < len; ++i) {
- iArray[i] = (short)(buff[i]*32768);
- }
+ if(wav_format == "pcm"){
+ while (audio.Fetch(buff, len, flag) > 0) {
+ short* iArray = new short[len];
+ for (size_t i = 0; i < len; ++i) {
+ iArray[i] = (short)(buff[i]*32768);
+ }
- // send data to server
+ // send data to server
+ int offset = 0;
+ int block_size = 102400;
+ while(offset < len){
+ int send_block = 0;
+ if (offset + block_size <= len){
+ send_block = block_size;
+ }else{
+ send_block = len - offset;
+ }
+ m_client.send(m_hdl, iArray+offset, send_block * sizeof(short),
+ websocketpp::frame::opcode::binary, ec);
+ offset += send_block;
+ }
+
+ LOG(INFO) << "sended data len=" << len * sizeof(short);
+ // The most likely error that we will get is that the connection is
+ // not in the right state. Usually this means we tried to send a
+ // message to a connection that was closed or in the process of
+ // closing. While many errors here can be easily recovered from,
+ // in this simple example, we'll stop the data loop.
+ if (ec) {
+ m_client.get_alog().write(websocketpp::log::alevel::app,
+ "Send Error: " + ec.message());
+ break;
+ }
+ delete[] iArray;
+ // WaitABit();
+ }
+ }else{
int offset = 0;
- int block_size = 102400;
+ int block_size = 204800;
+ len = audio.GetSpeechLen();
+ char* others_buff = audio.GetSpeechChar();
+
while(offset < len){
int send_block = 0;
if (offset + block_size <= len){
@@ -254,25 +290,23 @@
}else{
send_block = len - offset;
}
- m_client.send(m_hdl, iArray+offset, send_block * sizeof(short),
+ m_client.send(m_hdl, others_buff+offset, send_block,
websocketpp::frame::opcode::binary, ec);
offset += send_block;
}
- LOG(INFO) << "sended data len=" << len * sizeof(short);
+ LOG(INFO) << "sended data len=" << len;
// The most likely error that we will get is that the connection is
// not in the right state. Usually this means we tried to send a
// message to a connection that was closed or in the process of
// closing. While many errors here can be easily recovered from,
// in this simple example, we'll stop the data loop.
if (ec) {
- m_client.get_alog().write(websocketpp::log::alevel::app,
+ m_client.get_alog().write(websocketpp::log::alevel::app,
"Send Error: " + ec.message());
- break;
}
- delete[] iArray;
- // WaitABit();
}
+
nlohmann::json jsonresult;
jsonresult["is_speaking"] = false;
m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
@@ -332,11 +366,7 @@
std::vector<string> wav_list;
std::vector<string> wav_ids;
string default_id = "wav_default_id";
- if(IsTargetFile(wav_path, "wav") || IsTargetFile(wav_path, "pcm")){
- wav_list.emplace_back(wav_path);
- wav_ids.emplace_back(default_id);
- }
- else if(IsTargetFile(wav_path, "scp")){
+ if(IsTargetFile(wav_path, "scp")){
ifstream in(wav_path);
if (!in.is_open()) {
printf("Failed to open scp file");
@@ -353,8 +383,8 @@
}
in.close();
}else{
- printf("Please check the wav extension!");
- exit(-1);
+ wav_list.emplace_back(wav_path);
+ wav_ids.emplace_back(default_id);
}
for (size_t i = 0; i < threads_num; i++) {
diff --git a/funasr/runtime/websocket/readme.md b/funasr/runtime/websocket/readme.md
index b67a905..378f478 100644
--- a/funasr/runtime/websocket/readme.md
+++ b/funasr/runtime/websocket/readme.md
@@ -32,6 +32,15 @@
tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
```
+### Download ffmpeg
+```shell
+wget https://github.com/BtbN/FFmpeg-Builds/releases/download/autobuild-2023-07-09-12-50/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared.tar.xz
+tar -xvf ffmpeg-N-111383-g20b8688092-linux64-gpl-shared.tar.xz
+# 国内可以使用下述方式 (users in mainland China can use the mirror below)
+# wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared.tar.xz
+# tar -xvf ffmpeg-N-111383-g20b8688092-linux64-gpl-shared.tar.xz
+```
+
### Install openblas
```shell
sudo apt-get install libopenblas-dev #ubuntu
@@ -48,7 +57,7 @@
git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/funasr/runtime/websocket
mkdir build && cd build
-cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
+cmake -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 -DFFMPEG_DIR=/path/to/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared
make
```
## Run the websocket server
diff --git a/funasr/runtime/websocket/websocket-server.cpp b/funasr/runtime/websocket/websocket-server.cpp
index a311c23..59109b3 100644
--- a/funasr/runtime/websocket/websocket-server.cpp
+++ b/funasr/runtime/websocket/websocket-server.cpp
@@ -61,13 +61,13 @@
int num_samples = buffer.size(); // the size of the buf
if (!buffer.empty()) {
- // fout.write(buffer.data(), buffer.size());
// feed data to asr engine
FUNASR_RESULT Result = FunOfflineInferBuffer(
- asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, 16000);
+ asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, 16000, msg["wav_format"]);
std::string asr_result =
((FUNASR_RECOG_RESULT*)Result)->msg; // get decode result
+ FunASRFreeResult(Result);
websocketpp::lib::error_code ec;
nlohmann::json jsonresult; // result json
@@ -107,6 +107,7 @@
// connection
data_msg->samples = std::make_shared<std::vector<char>>();
data_msg->msg = nlohmann::json::parse("{}");
+ data_msg->msg["wav_format"] = "pcm";
data_map.emplace(hdl, data_msg);
LOG(INFO) << "on_open, active connections: " << data_map.size();
}
@@ -171,6 +172,9 @@
if (jsonresult["wav_name"] != nullptr) {
msg_data->msg["wav_name"] = jsonresult["wav_name"];
}
+ if (jsonresult["wav_format"] != nullptr) {
+ msg_data->msg["wav_format"] = jsonresult["wav_format"];
+ }
if (jsonresult["is_speaking"] == false ||
jsonresult["is_finished"] == true) {
@@ -180,9 +184,9 @@
// do_close(ws);
} else {
// add padding to the end of the wav data
- std::vector<short> padding(static_cast<short>(0.3 * 16000));
- sample_data_p->insert(sample_data_p->end(), padding.data(),
- padding.data() + padding.size());
+ // std::vector<short> padding(static_cast<short>(0.3 * 16000));
+ // sample_data_p->insert(sample_data_p->end(), padding.data(),
+ // padding.data() + padding.size());
// for offline, send all receive data to decoder engine
asio::post(io_decoder_,
std::bind(&WebSocketServer::do_decoder, this,
--
Gitblit v1.9.1