From ae609ca0c64056622888d5eddfca09a92defc30b Mon Sep 17 00:00:00 2001
From: Yabin Li <wucong.lyb@alibaba-inc.com>
Date: 星期二, 11 七月 2023 10:57:30 +0800
Subject: [PATCH] Dev ffmpeg (#727)

---
 funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp |   11 
 funasr/runtime/onnxruntime/src/audio.cpp               |  363 ++++++++++++++++++++++++++++++
 funasr/runtime/websocket/CMakeLists.txt                |    1 
 funasr/runtime/websocket/funasr-wss-client.cpp         |   74 ++++-
 funasr/runtime/onnxruntime/src/CMakeLists.txt          |    3 
 funasr/runtime/websocket/websocket-server.cpp          |   14 
 funasr/runtime/onnxruntime/include/audio.h             |    3 
 funasr/runtime/onnxruntime/src/funasrruntime.cpp       |   46 ++-
 funasr/runtime/onnxruntime/bin/ffmpeg.cpp              |  167 +++++++++++++
 funasr/runtime/websocket/readme.md                     |   11 
 funasr/runtime/onnxruntime/include/funasrruntime.h     |    6 
 funasr/runtime/onnxruntime/CMakeLists.txt              |    1 
 12 files changed, 646 insertions(+), 54 deletions(-)

diff --git a/funasr/runtime/onnxruntime/CMakeLists.txt b/funasr/runtime/onnxruntime/CMakeLists.txt
index 0847d1f..637d765 100644
--- a/funasr/runtime/onnxruntime/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/CMakeLists.txt
@@ -27,6 +27,7 @@
 	endif()
 ELSE()
     link_directories(${ONNXRUNTIME_DIR}/lib)
+    link_directories(${FFMPEG_DIR}/lib)
 endif()
 
 include_directories(${PROJECT_SOURCE_DIR}/third_party/kaldi-native-fbank)
diff --git a/funasr/runtime/onnxruntime/bin/ffmpeg.cpp b/funasr/runtime/onnxruntime/bin/ffmpeg.cpp
new file mode 100644
index 0000000..65976a0
--- /dev/null
+++ b/funasr/runtime/onnxruntime/bin/ffmpeg.cpp
@@ -0,0 +1,167 @@
+#include <iostream>
+#include <vector>
+#include <cstring>
+#include <fstream>
+
+extern "C" {
+#include <libavutil/opt.h>
+#include <libavcodec/avcodec.h>
+#include <libavformat/avformat.h>
+#include <libavutil/channel_layout.h>
+#include <libavutil/samplefmt.h>
+#include <libswresample/swresample.h>
+}
+
+int main(int argc, char* argv[]) {
+    // from buff
+    FILE* fp;
+    fp = fopen(argv[1], "rb");
+    if (fp == nullptr)
+	{
+        return -1;
+	}
+    fseek(fp, 0, SEEK_END);
+    uint32_t n_file_len = ftell(fp);
+    fseek(fp, 0, SEEK_SET);
+
+    char* buf = (char *)av_malloc(n_file_len); // AVIO buffers must come from av_malloc: FFmpeg may av_realloc/av_free them internally
+    memset(buf, 0, n_file_len);
+    fread(buf, 1, n_file_len, fp);
+    fclose(fp);
+
+    AVIOContext* avio_ctx = avio_alloc_context(
+        (unsigned char*)buf, // buffer
+        n_file_len, // buffer size
+        0, // write flag (0 for read-only)
+        nullptr, // opaque pointer (not used here)
+        nullptr, // read callback (not used here)
+        nullptr, // write callback (not used here)
+        nullptr // seek callback (not used here)
+    );
+    AVFormatContext* formatContext = avformat_alloc_context();
+    formatContext->pb = avio_ctx;
+    if (avformat_open_input(&formatContext, "", NULL, NULL) != 0) {
+        printf("Error: Could not open input file.");
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        return -1;
+    }
+
+    // from file
+    // AVFormatContext* formatContext = avformat_alloc_context();
+    // if (avformat_open_input(&formatContext, argv[1], NULL, NULL) != 0) {
+    //     printf("Error: Could not open input file.");
+    //     avformat_close_input(&formatContext);
+    //     avformat_free_context(formatContext);
+    //     return -1;
+    // }
+
+
+    if (avformat_find_stream_info(formatContext, NULL) < 0) {
+        printf("Error: Could not find stream information.");
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        return -1;
+    }
+    const AVCodec* codec = NULL;
+    AVCodecParameters* codecParameters = NULL;
+    int audioStreamIndex = av_find_best_stream(formatContext, AVMEDIA_TYPE_AUDIO, -1, -1, &codec, 0);
+    if (audioStreamIndex >= 0) {
+        codecParameters = formatContext->streams[audioStreamIndex]->codecpar;
+    }
+    AVCodecContext* codecContext = avcodec_alloc_context3(codec);
+    if (!codecContext) {
+        fprintf(stderr, "Failed to allocate codec context\n");
+        avformat_close_input(&formatContext);
+        return -1;
+    }
+    if (avcodec_parameters_to_context(codecContext, codecParameters) != 0) {
+        printf("Error: Could not copy codec parameters to codec context.");
+        avcodec_free_context(&codecContext);
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        return -1;
+    }
+    if (avcodec_open2(codecContext, codec, NULL) < 0) {
+        printf("Error: Could not open audio decoder.");
+        avcodec_free_context(&codecContext);
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        return -1;
+    }
+    SwrContext *swr_ctx = swr_alloc_set_opts(
+        nullptr, // allocate a new context
+        AV_CH_LAYOUT_MONO, // output channel layout (mono)
+        AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
+        16000, // output sample rate (fixed 16 kHz)
+        av_get_default_channel_layout(codecContext->channels), // input channel layout
+        codecContext->sample_fmt, // input sample format
+        codecContext->sample_rate, // input sample rate
+        0, // logging level
+        nullptr // parent context
+    );
+    if (swr_ctx == nullptr) {
+        std::cerr << "Could not initialize resampler" << std::endl;
+        avcodec_free_context(&codecContext);
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        return -1;
+    }
+    if (swr_init(swr_ctx) != 0) {
+        std::cerr << "Could not initialize resampler" << std::endl;
+        swr_free(&swr_ctx);
+        avcodec_free_context(&codecContext);
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        return -1;
+    }
+
+    // to pcm
+    FILE *out_file = fopen("output.pcm", "wb");
+    AVPacket* packet = av_packet_alloc();
+    AVFrame* frame = av_frame_alloc();
+    std::vector<uint8_t> resampled_buffer;
+    while (av_read_frame(formatContext, packet) >= 0) {
+        if (packet->stream_index == audioStreamIndex) {
+            if (avcodec_send_packet(codecContext, packet) >= 0) {
+                while (avcodec_receive_frame(codecContext, frame) >= 0) {
+                    // Resample audio if necessary
+                    int in_samples = frame->nb_samples;
+                    uint8_t **in_data = frame->extended_data;
+                    int out_samples = av_rescale_rnd(
+                                        swr_get_delay(swr_ctx, codecContext->sample_rate) + in_samples,
+                                        16000,
+                                        codecContext->sample_rate, AV_ROUND_UP);
+                    
+                    int resampled_size = out_samples * av_get_bytes_per_sample(AV_SAMPLE_FMT_S16);
+                    if (resampled_buffer.size() < resampled_size) {
+                        resampled_buffer.resize(resampled_size);
+                    }                    
+                    uint8_t *resampled_data = resampled_buffer.data();
+                    int ret = swr_convert(
+                        swr_ctx,
+                        &resampled_data, // output buffer
+                        out_samples, // output capacity in samples per channel (NOT bytes)
+                        (const uint8_t **)(frame->data), //(const uint8_t **)(frame->extended_data)
+                        in_samples // input buffer size
+                    );
+                    if (ret < 0) {
+                        std::cerr << "Error resampling audio" << std::endl;
+                        break;
+                    }
+                    fwrite(resampled_buffer.data(), sizeof(int16_t), ret, out_file); // write only the samples swr_convert actually produced
+                }
+            }
+        }
+        av_packet_unref(packet);
+    }
+    fclose(out_file);
+
+    avio_context_free(&avio_ctx);
+    avformat_close_input(&formatContext);
+    avformat_free_context(formatContext);
+    avcodec_free_context(&codecContext);
+    swr_free(&swr_ctx);
+    av_packet_free(&packet);
+    av_frame_free(&frame);
+}
diff --git a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
index 82668f8..caa8605 100644
--- a/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
+++ b/funasr/runtime/onnxruntime/bin/funasr-onnx-offline.cpp
@@ -91,11 +91,8 @@
     vector<string> wav_ids;
     string default_id = "wav_default_id";
     string wav_path_ = model_path.at(WAV_PATH);
-    if(is_target_file(wav_path_, "wav") || is_target_file(wav_path_, "pcm")){
-        wav_list.emplace_back(wav_path_);
-        wav_ids.emplace_back(default_id);
-    }
-    else if(is_target_file(wav_path_, "scp")){
+
+    if(is_target_file(wav_path_, "scp")){
         ifstream in(wav_path_);
         if (!in.is_open()) {
             LOG(ERROR) << "Failed to open file: " << model_path.at(WAV_SCP) ;
@@ -112,8 +109,8 @@
         }
         in.close();
     }else{
-        LOG(ERROR)<<"Please check the wav extension!";
-        exit(-1);
+        wav_list.emplace_back(wav_path_);
+        wav_ids.emplace_back(default_id);
     }
     
     float snippet_time = 0.0f;
diff --git a/funasr/runtime/onnxruntime/include/audio.h b/funasr/runtime/onnxruntime/include/audio.h
index d2100a4..a1b6312 100644
--- a/funasr/runtime/onnxruntime/include/audio.h
+++ b/funasr/runtime/onnxruntime/include/audio.h
@@ -55,6 +55,9 @@
     bool LoadPcmwav(const char* buf, int n_file_len, int32_t* sampling_rate);
     bool LoadPcmwav(const char* filename, int32_t* sampling_rate);
     bool LoadPcmwav2Char(const char* filename, int32_t* sampling_rate);
+    bool LoadOthers2Char(const char* filename);
+    bool FfmpegLoad(const char *filename);
+    bool FfmpegLoad(const char* buf, int n_file_len);
     int FetchChunck(float *&dout, int len);
     int Fetch(float *&dout, int &len, int &flag);
     void Padding();
diff --git a/funasr/runtime/onnxruntime/include/funasrruntime.h b/funasr/runtime/onnxruntime/include/funasrruntime.h
index 98727bd..ddb65b9 100644
--- a/funasr/runtime/onnxruntime/include/funasrruntime.h
+++ b/funasr/runtime/onnxruntime/include/funasrruntime.h
@@ -56,7 +56,7 @@
 // ASR
 _FUNASRAPI FUNASR_HANDLE  	FunASRInit(std::map<std::string, std::string>& model_path, int thread_num);
 // buffer
-_FUNASRAPI FUNASR_RESULT	FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
+_FUNASRAPI FUNASR_RESULT	FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000, std::string wav_format="pcm");
 // file, support wav & pcm
 _FUNASRAPI FUNASR_RESULT	FunASRInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
 
@@ -70,7 +70,7 @@
 _FUNASRAPI FUNASR_HANDLE  	FsmnVadInit(std::map<std::string, std::string>& model_path, int thread_num);
 _FUNASRAPI FUNASR_HANDLE  	FsmnVadOnlineInit(FUNASR_HANDLE fsmnvad_handle);
 // buffer
-_FUNASRAPI FUNASR_RESULT	FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000);
+_FUNASRAPI FUNASR_RESULT	FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished=true, int sampling_rate=16000, std::string wav_format="pcm");
 // file, support wav & pcm
 _FUNASRAPI FUNASR_RESULT	FsmnVadInfer(FUNASR_HANDLE handle, const char* sz_filename, QM_CALLBACK fn_callback, int sampling_rate=16000);
 
@@ -89,7 +89,7 @@
 //OfflineStream
 _FUNASRAPI FUNASR_HANDLE  	FunOfflineInit(std::map<std::string, std::string>& model_path, int thread_num);
 // buffer
-_FUNASRAPI FUNASR_RESULT	FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
+_FUNASRAPI FUNASR_RESULT	FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000, std::string wav_format="pcm");
 // file, support wav & pcm
 _FUNASRAPI FUNASR_RESULT	FunOfflineInfer(FUNASR_HANDLE handle, const char* sz_filename, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate=16000);
 _FUNASRAPI void				FunOfflineUninit(FUNASR_HANDLE handle);
diff --git a/funasr/runtime/onnxruntime/src/CMakeLists.txt b/funasr/runtime/onnxruntime/src/CMakeLists.txt
index d083d8e..c781ef0 100644
--- a/funasr/runtime/onnxruntime/src/CMakeLists.txt
+++ b/funasr/runtime/onnxruntime/src/CMakeLists.txt
@@ -15,8 +15,9 @@
     
     target_compile_definitions(funasr PUBLIC -D_FUNASR_API_EXPORT)
 else()
-    set(EXTRA_LIBS pthread yaml-cpp csrc glog )
+    set(EXTRA_LIBS pthread yaml-cpp csrc glog avutil avcodec avformat swresample)
     include_directories(${ONNXRUNTIME_DIR}/include)
+    include_directories(${FFMPEG_DIR}/include)
 endif()
 
 include_directories(${CMAKE_SOURCE_DIR}/include)
diff --git a/funasr/runtime/onnxruntime/src/audio.cpp b/funasr/runtime/onnxruntime/src/audio.cpp
index 23d0010..85633b7 100644
--- a/funasr/runtime/onnxruntime/src/audio.cpp
+++ b/funasr/runtime/onnxruntime/src/audio.cpp
@@ -9,6 +9,15 @@
 #include "audio.h"
 #include "precomp.h"
 
+extern "C" {
+#include <libavutil/opt.h>
+#include <libavcodec/avcodec.h>
+#include <libavformat/avformat.h>
+#include <libavutil/channel_layout.h>
+#include <libavutil/samplefmt.h>
+#include <libswresample/swresample.h>
+}
+
 using namespace std;
 
 namespace funasr {
@@ -220,6 +229,334 @@
     memset(speech_data, 0, sizeof(float) * speech_len);
     copy(samples.begin(), samples.end(), speech_data);
 }
+
+bool Audio::FfmpegLoad(const char *filename){
+    // from file
+    AVFormatContext* formatContext = avformat_alloc_context();
+    if (avformat_open_input(&formatContext, filename, NULL, NULL) != 0) {
+        printf("Error: Could not open input file.");
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        return false;
+    }
+
+    if (avformat_find_stream_info(formatContext, NULL) < 0) {
+        printf("Error: Could not find stream information.");
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        return false;
+    }
+    const AVCodec* codec = NULL;
+    AVCodecParameters* codecParameters = NULL;
+    int audioStreamIndex = av_find_best_stream(formatContext, AVMEDIA_TYPE_AUDIO, -1, -1, &codec, 0);
+    if (audioStreamIndex >= 0) {
+        codecParameters = formatContext->streams[audioStreamIndex]->codecpar;
+    }
+    AVCodecContext* codecContext = avcodec_alloc_context3(codec);
+    if (!codecContext) {
+        fprintf(stderr, "Failed to allocate codec context\n");
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        return false;
+    }
+    if (avcodec_parameters_to_context(codecContext, codecParameters) != 0) {
+        printf("Error: Could not copy codec parameters to codec context.");
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        avcodec_free_context(&codecContext);
+        return false;
+    }
+    if (avcodec_open2(codecContext, codec, NULL) < 0) {
+        printf("Error: Could not open audio decoder.");
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        avcodec_free_context(&codecContext);
+        return false;
+    }
+    SwrContext *swr_ctx = swr_alloc_set_opts(
+        nullptr, // allocate a new context
+        AV_CH_LAYOUT_MONO, // output channel layout (mono)
+        AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
+        16000, // output sample rate (fixed 16 kHz)
+        av_get_default_channel_layout(codecContext->channels), // input channel layout
+        codecContext->sample_fmt, // input sample format
+        codecContext->sample_rate, // input sample rate
+        0, // logging level
+        nullptr // parent context
+    );
+    if (swr_ctx == nullptr) {
+        std::cerr << "Could not initialize resampler" << std::endl;
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        avcodec_free_context(&codecContext);
+        return false;
+    }
+    if (swr_init(swr_ctx) != 0) {
+        std::cerr << "Could not initialize resampler" << std::endl;
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        avcodec_free_context(&codecContext);
+        swr_free(&swr_ctx);
+        return false;
+    }
+
+    // to pcm
+    AVPacket* packet = av_packet_alloc();
+    AVFrame* frame = av_frame_alloc();
+    std::vector<uint8_t> resampled_buffers;
+    while (av_read_frame(formatContext, packet) >= 0) {
+        if (packet->stream_index == audioStreamIndex) {
+            if (avcodec_send_packet(codecContext, packet) >= 0) {
+                while (avcodec_receive_frame(codecContext, frame) >= 0) {
+                    // Resample audio if necessary
+                    std::vector<uint8_t> resampled_buffer;
+                    int in_samples = frame->nb_samples;
+                    uint8_t **in_data = frame->extended_data;
+                    int out_samples = av_rescale_rnd(
+                                        swr_get_delay(swr_ctx, codecContext->sample_rate) + in_samples,
+                                        16000,
+                                        codecContext->sample_rate, AV_ROUND_UP);
+                    
+                    int resampled_size = out_samples * av_get_bytes_per_sample(AV_SAMPLE_FMT_S16);
+                    if (resampled_buffer.size() < resampled_size) {
+                        resampled_buffer.resize(resampled_size);
+                    }                    
+                    uint8_t *resampled_data = resampled_buffer.data();
+                    int ret = swr_convert(
+                        swr_ctx,
+                        &resampled_data, // output buffer
+                        out_samples, // output capacity in samples per channel (NOT bytes)
+                        (const uint8_t **)(frame->data), //(const uint8_t **)(frame->extended_data)
+                        in_samples // input buffer size
+                    );
+                    if (ret < 0) {
+                        std::cerr << "Error resampling audio" << std::endl;
+                        break;
+                    }
+                    std::copy(resampled_buffer.begin(), resampled_buffer.begin() + std::min<size_t>(resampled_buffer.size(), (size_t)ret * av_get_bytes_per_sample(AV_SAMPLE_FMT_S16)), std::back_inserter(resampled_buffers)); // append only the converted samples
+                }
+            }
+        }
+        av_packet_unref(packet);
+    }
+
+    avformat_close_input(&formatContext);
+    avformat_free_context(formatContext);
+    avcodec_free_context(&codecContext);
+    swr_free(&swr_ctx);
+    av_packet_free(&packet);
+    av_frame_free(&frame);
+
+    if (speech_data != NULL) {
+        free(speech_data);
+    }
+    if (speech_buff != NULL) {
+        free(speech_buff);
+    }
+    offset = 0;
+    
+    speech_len = (resampled_buffers.size()) / 2;
+    speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
+    if (speech_buff)
+    {
+        memset(speech_buff, 0, sizeof(int16_t) * speech_len);
+        memcpy((void*)speech_buff, (const void*)resampled_buffers.data(), speech_len * sizeof(int16_t));
+
+        speech_data = (float*)malloc(sizeof(float) * speech_len);
+        memset(speech_data, 0, sizeof(float) * speech_len);
+
+        float scale = 1;
+        if (data_type == 1) {
+            scale = 32768;
+        }
+        for (int32_t i = 0; i != speech_len; ++i) {
+            speech_data[i] = (float)speech_buff[i] / scale;
+        }
+
+        AudioFrame* frame = new AudioFrame(speech_len);
+        frame_queue.push(frame);
+    
+        return true;
+    }
+    else
+        return false;
+    
+}
+
+bool Audio::FfmpegLoad(const char* buf, int n_file_len){
+    // from buf
+    char* buf_copy = (char *)av_malloc(n_file_len); // AVIO buffers must come from av_malloc: FFmpeg may av_realloc/av_free them internally
+    memcpy(buf_copy, buf, n_file_len);
+
+    AVIOContext* avio_ctx = avio_alloc_context(
+        (unsigned char*)buf_copy, // buffer
+        n_file_len, // buffer size
+        0, // write flag (0 for read-only)
+        nullptr, // opaque pointer (not used here)
+        nullptr, // read callback (not used here)
+        nullptr, // write callback (not used here)
+        nullptr // seek callback (not used here)
+    );
+    AVFormatContext* formatContext = avformat_alloc_context();
+    formatContext->pb = avio_ctx;
+    if (avformat_open_input(&formatContext, "", NULL, NULL) != 0) {
+        printf("Error: Could not open input file.");
+        avio_context_free(&avio_ctx);
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        return false;
+    }
+
+    if (avformat_find_stream_info(formatContext, NULL) < 0) {
+        printf("Error: Could not find stream information.");
+        avio_context_free(&avio_ctx);
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        return false;
+    }
+    const AVCodec* codec = NULL;
+    AVCodecParameters* codecParameters = NULL;
+    int audioStreamIndex = av_find_best_stream(formatContext, AVMEDIA_TYPE_AUDIO, -1, -1, &codec, 0);
+    if (audioStreamIndex >= 0) {
+        codecParameters = formatContext->streams[audioStreamIndex]->codecpar;
+    }
+    AVCodecContext* codecContext = avcodec_alloc_context3(codec);
+    if (!codecContext) {
+        fprintf(stderr, "Failed to allocate codec context\n");
+        avio_context_free(&avio_ctx);
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        return false;
+    }
+    if (avcodec_parameters_to_context(codecContext, codecParameters) != 0) {
+        printf("Error: Could not copy codec parameters to codec context.");
+        avio_context_free(&avio_ctx);
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        avcodec_free_context(&codecContext);
+        return false;
+    }
+    if (avcodec_open2(codecContext, codec, NULL) < 0) {
+        printf("Error: Could not open audio decoder.");
+        avio_context_free(&avio_ctx);
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        avcodec_free_context(&codecContext);
+        return false;
+    }
+    SwrContext *swr_ctx = swr_alloc_set_opts(
+        nullptr, // allocate a new context
+        AV_CH_LAYOUT_MONO, // output channel layout (mono)
+        AV_SAMPLE_FMT_S16, // output sample format (signed 16-bit)
+        16000, // output sample rate (fixed 16 kHz)
+        av_get_default_channel_layout(codecContext->channels), // input channel layout
+        codecContext->sample_fmt, // input sample format
+        codecContext->sample_rate, // input sample rate
+        0, // logging level
+        nullptr // parent context
+    );
+    if (swr_ctx == nullptr) {
+        std::cerr << "Could not initialize resampler" << std::endl;
+        avio_context_free(&avio_ctx);
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        avcodec_free_context(&codecContext);
+        return false;
+    }
+    if (swr_init(swr_ctx) != 0) {
+        std::cerr << "Could not initialize resampler" << std::endl;
+        avio_context_free(&avio_ctx);
+        avformat_close_input(&formatContext);
+        avformat_free_context(formatContext);
+        avcodec_free_context(&codecContext);
+        swr_free(&swr_ctx);
+        return false;
+    }
+
+    // to pcm
+    AVPacket* packet = av_packet_alloc();
+    AVFrame* frame = av_frame_alloc();
+    std::vector<uint8_t> resampled_buffers;
+    while (av_read_frame(formatContext, packet) >= 0) {
+        if (packet->stream_index == audioStreamIndex) {
+            if (avcodec_send_packet(codecContext, packet) >= 0) {
+                while (avcodec_receive_frame(codecContext, frame) >= 0) {
+                    // Resample audio if necessary
+                    std::vector<uint8_t> resampled_buffer;
+                    int in_samples = frame->nb_samples;
+                    uint8_t **in_data = frame->extended_data;
+                    int out_samples = av_rescale_rnd(
+                                        swr_get_delay(swr_ctx, codecContext->sample_rate) + in_samples,
+                                        16000,
+                                        codecContext->sample_rate, AV_ROUND_UP);
+                    
+                    int resampled_size = out_samples * av_get_bytes_per_sample(AV_SAMPLE_FMT_S16);
+                    if (resampled_buffer.size() < resampled_size) {
+                        resampled_buffer.resize(resampled_size);
+                    }                    
+                    uint8_t *resampled_data = resampled_buffer.data();
+                    int ret = swr_convert(
+                        swr_ctx,
+                        &resampled_data, // output buffer
+                        out_samples, // output capacity in samples per channel (NOT bytes)
+                        (const uint8_t **)(frame->data), //(const uint8_t **)(frame->extended_data)
+                        in_samples // input buffer size
+                    );
+                    if (ret < 0) {
+                        std::cerr << "Error resampling audio" << std::endl;
+                        break;
+                    }
+                    std::copy(resampled_buffer.begin(), resampled_buffer.begin() + std::min<size_t>(resampled_buffer.size(), (size_t)ret * av_get_bytes_per_sample(AV_SAMPLE_FMT_S16)), std::back_inserter(resampled_buffers)); // append only the converted samples
+                }
+            }
+        }
+        av_packet_unref(packet);
+    }
+
+    avio_context_free(&avio_ctx);
+    avformat_close_input(&formatContext);
+    avformat_free_context(formatContext);
+    avcodec_free_context(&codecContext);
+    swr_free(&swr_ctx);
+    av_packet_free(&packet);
+    av_frame_free(&frame);
+
+    if (speech_data != NULL) {
+        free(speech_data);
+    }
+    if (speech_buff != NULL) {
+        free(speech_buff);
+    }
+    offset = 0;
+
+    speech_len = (resampled_buffers.size()) / 2;
+    speech_buff = (int16_t*)malloc(sizeof(int16_t) * speech_len);
+    if (speech_buff)
+    {
+        memset(speech_buff, 0, sizeof(int16_t) * speech_len);
+        memcpy((void*)speech_buff, (const void*)resampled_buffers.data(), speech_len * sizeof(int16_t));
+
+        speech_data = (float*)malloc(sizeof(float) * speech_len);
+        memset(speech_data, 0, sizeof(float) * speech_len);
+
+        float scale = 1;
+        if (data_type == 1) {
+            scale = 32768;
+        }
+        for (int32_t i = 0; i != speech_len; ++i) {
+            speech_data[i] = (float)speech_buff[i] / scale;
+        }
+
+        AudioFrame* frame = new AudioFrame(speech_len);
+        frame_queue.push(frame);
+    
+        return true;
+    }
+    else
+        return false;
+    
+}
+
 
 bool Audio::LoadWav(const char *filename, int32_t* sampling_rate)
 {
@@ -507,6 +844,32 @@
     return true;
 }
 
+bool Audio::LoadOthers2Char(const char* filename)
+{
+    if (speech_char != NULL) {
+        free(speech_char);
+    }
+
+    FILE* fp;
+    fp = fopen(filename, "rb");
+    if (fp == nullptr)
+	{
+        LOG(ERROR) << "Failed to read " << filename;
+        return false;
+	}
+    fseek(fp, 0, SEEK_END);
+    uint32_t n_file_len = ftell(fp);
+    fseek(fp, 0, SEEK_SET);
+
+    speech_len = n_file_len;
+    speech_char = (char *)malloc(n_file_len);
+    memset(speech_char, 0, n_file_len);
+    fread(speech_char, 1, n_file_len, fp);
+    fclose(fp);
+    
+    return true;
+}
+
 int Audio::FetchChunck(float *&dout, int len)
 {
     if (offset >= speech_align_len) {
diff --git a/funasr/runtime/onnxruntime/src/funasrruntime.cpp b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
index 82fdd70..a1829fd 100644
--- a/funasr/runtime/onnxruntime/src/funasrruntime.cpp
+++ b/funasr/runtime/onnxruntime/src/funasrruntime.cpp
@@ -36,15 +36,20 @@
 	}
 
 	// APIs for ASR Infer
-	_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
+	_FUNASRAPI FUNASR_RESULT FunASRInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate, std::string wav_format)
 	{
 		funasr::Model* recog_obj = (funasr::Model*)handle;
 		if (!recog_obj)
 			return nullptr;
 
 		funasr::Audio audio(1);
-		if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
-			return nullptr;
+		if(wav_format == "pcm" || wav_format == "PCM"){
+			if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
+				return nullptr;
+		}else{
+			if (!audio.FfmpegLoad(sz_buf, n_len))
+				return nullptr;
+		}
 
 		float* buff;
 		int len;
@@ -82,8 +87,8 @@
 			if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
 				return nullptr;
 		}else{
-			LOG(ERROR)<<"Wrong wav extension";
-			exit(-1);
+			if (!audio.FfmpegLoad(sz_filename))
+				return nullptr;
 		}
 
 		float* buff;
@@ -108,15 +113,20 @@
 	}
 
 	// APIs for VAD Infer
-	_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate)
+	_FUNASRAPI FUNASR_RESULT FsmnVadInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, QM_CALLBACK fn_callback, bool input_finished, int sampling_rate, std::string wav_format)
 	{
 		funasr::VadModel* vad_obj = (funasr::VadModel*)handle;
 		if (!vad_obj)
 			return nullptr;
 
 		funasr::Audio audio(1);
-		if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
-			return nullptr;
+		if(wav_format == "pcm" || wav_format == "PCM"){
+			if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
+				return nullptr;
+		}else{
+			if (!audio.FfmpegLoad(sz_buf, n_len))
+				return nullptr;
+		}
 
 		funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
 		p_result->snippet_time = audio.GetTimeLen();
@@ -146,8 +156,8 @@
 			if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
 				return nullptr;
 		}else{
-			LOG(ERROR)<<"Wrong wav extension";
-			exit(-1);
+			if (!audio.FfmpegLoad(sz_filename))
+				return nullptr;
 		}
 
 		funasr::FUNASR_VAD_RESULT* p_result = new funasr::FUNASR_VAD_RESULT;
@@ -189,15 +199,21 @@
 	}
 
 	// APIs for Offline-stream Infer
-	_FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate)
+	_FUNASRAPI FUNASR_RESULT FunOfflineInferBuffer(FUNASR_HANDLE handle, const char* sz_buf, int n_len, FUNASR_MODE mode, QM_CALLBACK fn_callback, int sampling_rate, std::string wav_format)
 	{
 		funasr::OfflineStream* offline_stream = (funasr::OfflineStream*)handle;
 		if (!offline_stream)
 			return nullptr;
 
 		funasr::Audio audio(1);
-		if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
-			return nullptr;
+		if(wav_format == "pcm" || wav_format == "PCM"){
+			if (!audio.LoadPcmwav(sz_buf, n_len, &sampling_rate))
+				return nullptr;
+		}else{
+			if (!audio.FfmpegLoad(sz_buf, n_len))
+				return nullptr;
+		}
+
 		funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
 		p_result->snippet_time = audio.GetTimeLen();
 		if(p_result->snippet_time == 0){
@@ -243,8 +259,8 @@
 			if (!audio.LoadPcmwav(sz_filename, &sampling_rate))
 				return nullptr;
 		}else{
-			LOG(ERROR)<<"Wrong wav extension";
-			exit(-1);
+			if (!audio.FfmpegLoad(sz_filename))
+				return nullptr;
 		}
 		funasr::FUNASR_RECOG_RESULT* p_result = new funasr::FUNASR_RECOG_RESULT;
 		p_result->snippet_time = audio.GetTimeLen();
diff --git a/funasr/runtime/websocket/CMakeLists.txt b/funasr/runtime/websocket/CMakeLists.txt
index 51b0795..06ae59b 100644
--- a/funasr/runtime/websocket/CMakeLists.txt
+++ b/funasr/runtime/websocket/CMakeLists.txt
@@ -40,6 +40,7 @@
 
 # Include generated *.pb.h files
 link_directories(${ONNXRUNTIME_DIR}/lib)
+link_directories(${FFMPEG_DIR}/lib)
 
 include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/include/)
 include_directories(${PROJECT_SOURCE_DIR}/../onnxruntime/third_party/yaml-cpp/include/)
diff --git a/funasr/runtime/websocket/funasr-wss-client.cpp b/funasr/runtime/websocket/funasr-wss-client.cpp
index b6d69f2..231303f 100644
--- a/funasr/runtime/websocket/funasr-wss-client.cpp
+++ b/funasr/runtime/websocket/funasr-wss-client.cpp
@@ -187,6 +187,7 @@
 
 		funasr::Audio audio(1);
         int32_t sampling_rate = 16000;
+        std::string wav_format = "pcm";
 		if(IsTargetFile(wav_path.c_str(), "wav")){
 			int32_t sampling_rate = -1;
 			if(!audio.LoadWav(wav_path.c_str(), &sampling_rate))
@@ -195,8 +196,9 @@
 			if (!audio.LoadPcmwav(wav_path.c_str(), &sampling_rate))
 				return ;
 		}else{
-			printf("Wrong wav extension");
-			exit(-1);
+			wav_format = "others";
+            if (!audio.LoadOthers2Char(wav_path.c_str()))
+				return ;
 		}
 
         float* buff;
@@ -233,20 +235,54 @@
         jsonbegin["chunk_size"] = chunk_size;
         jsonbegin["chunk_interval"] = 10;
         jsonbegin["wav_name"] = wav_id;
+        jsonbegin["wav_format"] = wav_format;
         jsonbegin["is_speaking"] = true;
         m_client.send(m_hdl, jsonbegin.dump(), websocketpp::frame::opcode::text,
                       ec);
 
         // fetch wav data use asr engine api
-        while (audio.Fetch(buff, len, flag) > 0) {
-            short* iArray = new short[len];
-            for (size_t i = 0; i < len; ++i) {
-              iArray[i] = (short)(buff[i]*32768);
-            }
+        if(wav_format == "pcm"){
+            while (audio.Fetch(buff, len, flag) > 0) {
+                short* iArray = new short[len];
+                for (size_t i = 0; i < len; ++i) {
+                iArray[i] = (short)(buff[i]*32768);
+                }
 
-            // send data to server
+                // send data to server
+                int offset = 0;
+                int block_size = 102400;
+                while(offset < len){
+                    int send_block = 0;
+                    if (offset + block_size <= len){
+                        send_block = block_size;
+                    }else{
+                        send_block = len - offset;
+                    }
+                    m_client.send(m_hdl, iArray+offset, send_block * sizeof(short),
+                        websocketpp::frame::opcode::binary, ec);
+                    offset += send_block;
+                }
+
+                LOG(INFO) << "sended data len=" << len * sizeof(short);
+                // The most likely error that we will get is that the connection is
+                // not in the right state. Usually this means we tried to send a
+                // message to a connection that was closed or in the process of
+                // closing. While many errors here can be easily recovered from,
+                // in this simple example, we'll stop the data loop.
+                if (ec) {
+                m_client.get_alog().write(websocketpp::log::alevel::app,
+                                            "Send Error: " + ec.message());
+                break;
+                }
+                delete[] iArray;
+                // WaitABit();
+            }
+        }else{
             int offset = 0;
-            int block_size = 102400;
+            int block_size = 204800;
+            len = audio.GetSpeechLen();
+            char* others_buff = audio.GetSpeechChar();
+
             while(offset < len){
                 int send_block = 0;
                 if (offset + block_size <= len){
@@ -254,25 +290,23 @@
                 }else{
                     send_block = len - offset;
                 }
-                m_client.send(m_hdl, iArray+offset, send_block * sizeof(short),
+                m_client.send(m_hdl, others_buff+offset, send_block,
                     websocketpp::frame::opcode::binary, ec);
                 offset += send_block;
             }
 
-            LOG(INFO) << "sended data len=" << len * sizeof(short);
+            LOG(INFO) << "sended data len=" << len;
             // The most likely error that we will get is that the connection is
             // not in the right state. Usually this means we tried to send a
             // message to a connection that was closed or in the process of
             // closing. While many errors here can be easily recovered from,
             // in this simple example, we'll stop the data loop.
             if (ec) {
-              m_client.get_alog().write(websocketpp::log::alevel::app,
+                m_client.get_alog().write(websocketpp::log::alevel::app,
                                         "Send Error: " + ec.message());
-              break;
             }
-            delete[] iArray;
-            // WaitABit();
         }
+
         nlohmann::json jsonresult;
         jsonresult["is_speaking"] = false;
         m_client.send(m_hdl, jsonresult.dump(), websocketpp::frame::opcode::text,
@@ -332,11 +366,7 @@
     std::vector<string> wav_list;
     std::vector<string> wav_ids;
     string default_id = "wav_default_id";
-    if(IsTargetFile(wav_path, "wav") || IsTargetFile(wav_path, "pcm")){
-        wav_list.emplace_back(wav_path);
-        wav_ids.emplace_back(default_id);
-    }
-    else if(IsTargetFile(wav_path, "scp")){
+    if(IsTargetFile(wav_path, "scp")){
         ifstream in(wav_path);
         if (!in.is_open()) {
             printf("Failed to open scp file");
@@ -353,8 +383,8 @@
         }
         in.close();
     }else{
-        printf("Please check the wav extension!");
-        exit(-1);
+        wav_list.emplace_back(wav_path);
+        wav_ids.emplace_back(default_id);
     }
     
     for (size_t i = 0; i < threads_num; i++) {
diff --git a/funasr/runtime/websocket/readme.md b/funasr/runtime/websocket/readme.md
index b67a905..291126b 100644
--- a/funasr/runtime/websocket/readme.md
+++ b/funasr/runtime/websocket/readme.md
@@ -32,6 +32,15 @@
 tar -zxvf onnxruntime-linux-x64-1.14.0.tgz
 ```
 
+### Download ffmpeg
+```shell
+wget https://github.com/BtbN/FFmpeg-Builds/releases/download/autobuild-2023-07-09-12-50/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared.tar.xz
+tar -xvf ffmpeg-N-111383-g20b8688092-linux64-gpl-shared.tar.xz
+# 鍥藉唴鍙互浣跨敤涓嬭堪鏂瑰紡
+# wget https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/dep_libs/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared.tar.xz
+# tar -xvf ffmpeg-N-111383-g20b8688092-linux64-gpl-shared.tar.xz
+```
+
 ### Install openblas
 ```shell
 sudo apt-get install libopenblas-dev #ubuntu
@@ -48,7 +57,7 @@
 
 git clone https://github.com/alibaba-damo-academy/FunASR.git && cd FunASR/funasr/runtime/websocket
 mkdir build && cd build
-cmake  -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0
+cmake  -DCMAKE_BUILD_TYPE=release .. -DONNXRUNTIME_DIR=/path/to/onnxruntime-linux-x64-1.14.0 -DFFMPEG_DIR=/path/to/ffmpeg-N-111383-g20b8688092-linux64-gpl-shared
 make
 ```
 ## Run the websocket server
diff --git a/funasr/runtime/websocket/websocket-server.cpp b/funasr/runtime/websocket/websocket-server.cpp
index a311c23..59109b3 100644
--- a/funasr/runtime/websocket/websocket-server.cpp
+++ b/funasr/runtime/websocket/websocket-server.cpp
@@ -61,13 +61,13 @@
     int num_samples = buffer.size();  // the size of the buf
 
     if (!buffer.empty()) {
-      // fout.write(buffer.data(), buffer.size());
       // feed data to asr engine
       FUNASR_RESULT Result = FunOfflineInferBuffer(
-          asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, 16000);
+          asr_hanlde, buffer.data(), buffer.size(), RASR_NONE, NULL, 16000, msg["wav_format"]);
 
       std::string asr_result =
           ((FUNASR_RECOG_RESULT*)Result)->msg;  // get decode result
+      FunASRFreeResult(Result);
 
       websocketpp::lib::error_code ec;
       nlohmann::json jsonresult;        // result json
@@ -107,6 +107,7 @@
                                            // connection
   data_msg->samples = std::make_shared<std::vector<char>>();
   data_msg->msg = nlohmann::json::parse("{}");
+  data_msg->msg["wav_format"] = "pcm";
   data_map.emplace(hdl, data_msg);
   LOG(INFO) << "on_open, active connections: " << data_map.size();
 }
@@ -171,6 +172,9 @@
       if (jsonresult["wav_name"] != nullptr) {
         msg_data->msg["wav_name"] = jsonresult["wav_name"];
       }
+      if (jsonresult["wav_format"] != nullptr) {
+        msg_data->msg["wav_format"] = jsonresult["wav_format"];
+      }
 
       if (jsonresult["is_speaking"] == false ||
           jsonresult["is_finished"] == true) {
@@ -180,9 +184,9 @@
           // do_close(ws);
         } else {
           // add padding to the end of the wav data
-          std::vector<short> padding(static_cast<short>(0.3 * 16000));
-          sample_data_p->insert(sample_data_p->end(), padding.data(),
-                                padding.data() + padding.size());
+          // std::vector<short> padding(static_cast<short>(0.3 * 16000));
+          // sample_data_p->insert(sample_data_p->end(), padding.data(),
+          //                       padding.data() + padding.size());
           // for offline, send all receive data to decoder engine
           asio::post(io_decoder_,
                      std::bind(&WebSocketServer::do_decoder, this,

--
Gitblit v1.9.1