From b4598f30a54c3a8d5e6084d983fac0fa5a51992b Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Wed, 15 Mar 2023 17:20:24 +0800
Subject: [PATCH] update
---
egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py | 5 +++--
funasr/bin/asr_inference_launch.py | 3 +++
funasr/bin/eend_ola_inference.py | 9 +++++----
3 files changed, 11 insertions(+), 6 deletions(-)
diff --git a/egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py b/egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py
index dfcb8e6..e0ac08c 100644
--- a/egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py
+++ b/egs_modelscope/speaker_diarization/speech_diarization_eend-ola-en-us-callhome-8k/infer.py
@@ -2,8 +2,9 @@
from modelscope.utils.constant import Tasks
inference_diar_pipline = pipeline(
- task=Tasks.speaker_diarization,
+ task=Tasks.auto_speech_recognition,
model='damo/speech_diarization_eend-ola-en-us-callhome-8k',
model_revision="v1.0.0",
)
-results = inference_diar_pipline(audio_in=["https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record.wav"])
\ No newline at end of file
+results = inference_diar_pipline(audio_in=["https://isv-data.oss-cn-hangzhou.aliyuncs.com/ics/MaaS/ASR/test_data/record2.wav"])
+print(results)
\ No newline at end of file
diff --git a/funasr/bin/asr_inference_launch.py b/funasr/bin/asr_inference_launch.py
index 1fae766..0ab6b1a 100644
--- a/funasr/bin/asr_inference_launch.py
+++ b/funasr/bin/asr_inference_launch.py
@@ -234,6 +234,9 @@
elif mode == "rnnt":
from funasr.bin.asr_inference_rnnt import inference_modelscope
return inference_modelscope(**kwargs)
+ elif mode == "eend-ola":
+ from funasr.bin.eend_ola_inference import inference_modelscope
+ return inference_modelscope(mode=mode, **kwargs)
else:
logging.info("Unknown decoding mode: {}".format(mode))
return None
diff --git a/funasr/bin/eend_ola_inference.py b/funasr/bin/eend_ola_inference.py
index b35824a..0483278 100755
--- a/funasr/bin/eend_ola_inference.py
+++ b/funasr/bin/eend_ola_inference.py
@@ -16,8 +16,8 @@
import numpy as np
import torch
-from typeguard import check_argument_types
from scipy.signal import medfilt
+from typeguard import check_argument_types
from funasr.models.frontend.wav_frontend import WavFrontendMel23
from funasr.tasks.diar import EENDOLADiarTask
@@ -27,6 +27,7 @@
from funasr.utils.types import str2bool
from funasr.utils.types import str2triple_str
from funasr.utils.types import str_or_none
+
class Speech2Diarization:
"""Speech2Diarlization class
@@ -237,7 +238,7 @@
results = speech2diar(**batch)
# post process
- a = results[0].cpu().numpy()
+ a = results[0][0].cpu().numpy()
a = medfilt(a, (11, 1))
rst = []
for spkid, frames in enumerate(a.T):
@@ -246,8 +247,8 @@
fmt = "SPEAKER {:s} 1 {:7.2f} {:7.2f} <NA> <NA> {:s} <NA>"
for s, e in zip(changes[::2], changes[1::2]):
st = s / 10.
- ed = e / 10.
- rst.append(fmt.format(keys[0], st, ed, "{}_{}".format(keys[0],str(spkid))))
+ dur = (e - s) / 10.
+ rst.append(fmt.format(keys[0], st, dur, "{}_{}".format(keys[0], str(spkid))))
# Only supporting batch_size==1
value = "\n".join(rst)
--
Gitblit v1.9.1