zhifu gao
2023-03-23 d50d98b0a9675beeaca1c8fdad0264f4334af8f1
Merge pull request #287 from alibaba-damo-academy/dev_gzf

Dev gzf
3个文件已修改
55 ■■■■■ 已修改文件
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py 7 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/bin/punctuation_infer_vadrealtime.py 30 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/runtime/python/websocket/ASR_server.py 18 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
egs_modelscope/punctuation/punc_ct-transformer_zh-cn-common-vadrealtime-vocab272727/infer.py
@@ -17,13 +17,10 @@
)
# Split the input text on "|" and feed the pieces to the pipeline one at a time.
vads = inputs.split("|")
cache_out = []
rec_result_all="outputs:"
# Dict passed to the pipeline on every call; it starts with an empty "cache" list.
param_dict = {"cache": []}
for vad in vads:
    # NOTE(review): two pipeline calls per segment (cache=... and param_dict=...)
    # look like the removed and added lines of a diff shown together — confirm
    # which one is live.
    rec_result = inference_pipeline(text_in=vad, cache=cache_out)
    #print(rec_result)
    cache_out = rec_result['cache']
    rec_result = inference_pipeline(text_in=vad, param_dict=param_dict)
    # Append this segment's text to the running output string.
    rec_result_all += rec_result['text']
print(rec_result_all)
funasr/bin/punctuation_infer_vadrealtime.py
@@ -226,7 +226,7 @@
    ):
        results = []
        split_size = 10
        cache_in = param_dict["cache"]
        if raw_inputs != None:
            line = raw_inputs.strip()
            key = "demo"
@@ -234,34 +234,12 @@
                item = {'key': key, 'value': ""}
                results.append(item)
                return results
            result, _, cache = text2punc(line, cache)
            item = {'key': key, 'value': result, 'cache': cache}
            result, _, cache = text2punc(line, cache_in)
            param_dict["cache"] = cache
            item = {'key': key, 'value': result}
            results.append(item)
            return results
        for inference_text, _, _ in data_path_and_name_and_type:
            with open(inference_text, "r", encoding="utf-8") as fin:
                for line in fin:
                    line = line.strip()
                    segs = line.split("\t")
                    if len(segs) != 2:
                        continue
                    key = segs[0]
                    if len(segs[1]) == 0:
                        continue
                    result, _ = text2punc(segs[1])
                    item = {'key': key, 'value': result}
                    results.append(item)
        output_path = output_dir_v2 if output_dir_v2 is not None else output_dir
        if output_path != None:
            output_file_name = "infer.out"
            Path(output_path).mkdir(parents=True, exist_ok=True)
            output_file_path = (Path(output_path) / output_file_name).absolute()
            with open(output_file_path, "w", encoding="utf-8") as fout:
                for item_i in results:
                    key_out = item_i["key"]
                    value_out = item_i["value"]
                    fout.write(f"{key_out}\t{value_out}\n")
        return results
    return _forward
funasr/runtime/python/websocket/ASR_server.py
@@ -53,7 +53,7 @@
inference_pipeline_vad = pipeline(
    task=Tasks.voice_activity_detection,
    model=args.vad_model,
    model_revision="v1.2.0",
    model_revision=None,
    output_dir=None,
    batch_size=1,
    mode='online',
@@ -62,7 +62,7 @@
# Starts with an empty 'in_cache' dict and "is_final" set to False;
# presumably passed to the VAD pipeline on each call — verify against the caller.
param_dict_vad = {'in_cache': dict(), "is_final": False}
  
# asr
# NOTE(review): param_dict_asr is assigned an empty dict twice — likely the
# removed and added lines of a diff shown together; confirm which one is live.
param_dict_asr = dict()
param_dict_asr = {}
# param_dict["hotword"] = "小五 小五月"  # set hotwords, separated by spaces
inference_pipeline_asr = pipeline(
    task=Tasks.auto_speech_recognition,
@@ -71,10 +71,11 @@
    ngpu=args.ngpu,
)
inference_pipline_punc = pipeline(
param_dict_punc = {'cache': list()}
inference_pipeline_punc = pipeline(
    task=Tasks.punctuation,
    model=args.punc_model,
    model_revision="v1.0.1",
    model_revision=None,
    ngpu=args.ngpu,
)
@@ -116,13 +117,16 @@
def asr():  # inference
    """Poll `speek` in an endless loop and run recognition on each item taken from it.

    Each item is passed to the ASR pipeline. When the result contains a
    'text' entry, that text is sent to the punctuation pipeline together
    with the module-level `param_dict_punc`, and the resulting text is
    printed. The loop has no break or return.

    NOTE(review): this span appears to show removed and added diff lines
    side by side (two `global speek` statements, two ASR calls per item) —
    confirm which lines are live before relying on it.
    """
    global inference_pipeline2
    global speek
    global speek, param_dict_punc
    while True:
        # Drain everything currently available before idling again.
        while not speek.empty():
            audio_in = speek.get()
            speek.task_done()
            rec_result = inference_pipeline_asr(audio_in=audio_in)
            print(rec_result)
            # Zero-length audio skips the recognition and punctuation steps below.
            if len(audio_in) > 0:
                rec_result = inference_pipeline_asr(audio_in=audio_in)
                if 'text' in rec_result:
                    rec_result = inference_pipeline_punc(text_in=rec_result['text'], param_dict=param_dict_punc)
                # NOTE(review): "text" is indexed even when the membership check
                # above failed — this would raise KeyError in that case; confirm.
                print(rec_result["text"])
            # Short pause after each processed item.
            time.sleep(0.1)
        # Short pause between polls when nothing is waiting.
        time.sleep(0.1)