From 09a28d19df5854bdd4bd4d3a05dcb6f502ec6b07 Mon Sep 17 00:00:00 2001
From: shixian.shi <shixian.shi@alibaba-inc.com>
Date: 星期五, 12 一月 2024 18:02:10 +0800
Subject: [PATCH] update
---
funasr/models/fsmn_vad_streaming/model.py | 92 ++++++++++++++--------------------------------
1 files changed, 28 insertions(+), 64 deletions(-)
diff --git a/funasr/models/fsmn_vad_streaming/model.py b/funasr/models/fsmn_vad_streaming/model.py
index 4c7e943..e0d104a 100644
--- a/funasr/models/fsmn_vad_streaming/model.py
+++ b/funasr/models/fsmn_vad_streaming/model.py
@@ -496,7 +496,7 @@
def forward(self, feats: torch.Tensor, waveform: torch.tensor, cache: Dict[str, torch.Tensor] = dict(),
is_final: bool = False
):
- if not cache:
+ if len(cache) == 0:
self.AllResetDetection()
self.waveform = waveform # compute decibel for each frame
self.ComputeDecibel()
@@ -521,13 +521,15 @@
if is_final:
# reset class variables and clear the dict for the next query
self.AllResetDetection()
- return segments, cache
+ return segments
def init_cache(self, cache: dict = {}, **kwargs):
cache["frontend"] = {}
cache["prev_samples"] = torch.empty(0)
+ cache["encoder"] = {}
return cache
+
def generate(self,
data_in,
data_lengths=None,
@@ -543,7 +545,7 @@
meta_data = {}
chunk_size = kwargs.get("chunk_size", 50) # 50ms
- chunk_stride_samples = chunk_size * 16
+ chunk_stride_samples = int(chunk_size * frontend.fs / 1000)
time1 = time.perf_counter()
cfg = {"is_final": kwargs.get("is_final", False)}
@@ -552,7 +554,7 @@
audio_fs=kwargs.get("fs", 16000),
data_type=kwargs.get("data_type", "sound"),
tokenizer=tokenizer,
- **cfg,
+ cache=cfg,
)
_is_final = cfg["is_final"] # if data_in is a file or url, set is_final=True
@@ -562,9 +564,9 @@
audio_sample = torch.cat((cache["prev_samples"], audio_sample_list[0]))
- n = len(audio_sample) // chunk_stride_samples + int(_is_final)
- m = len(audio_sample) % chunk_stride_samples * (1 - int(_is_final))
- tokens = []
+ n = int(len(audio_sample) // chunk_stride_samples + int(_is_final))
+ m = int(len(audio_sample) % chunk_stride_samples * (1 - int(_is_final)))
+ segments = []
for i in range(n):
kwargs["is_final"] = _is_final and i == n - 1
audio_sample_i = audio_sample[i * chunk_stride_samples:(i + 1) * chunk_stride_samples]
@@ -576,58 +578,21 @@
time3 = time.perf_counter()
meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
meta_data["batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-
- meta_data = {}
- audio_sample_list = [data_in]
- if isinstance(data_in, torch.Tensor): # fbank
- speech, speech_lengths = data_in, data_lengths
- if len(speech.shape) < 3:
- speech = speech[None, :, :]
- if speech_lengths is None:
- speech_lengths = speech.shape[1]
- else:
- # extract fbank feats
- time1 = time.perf_counter()
- audio_sample_list = load_audio_text_image_video(data_in, fs=frontend.fs, audio_fs=kwargs.get("fs", 16000))
- time2 = time.perf_counter()
- meta_data["load_data"] = f"{time2 - time1:0.3f}"
- speech, speech_lengths = extract_fbank(audio_sample_list, data_type=kwargs.get("data_type", "sound"),
- frontend=frontend)
- time3 = time.perf_counter()
- meta_data["extract_feat"] = f"{time3 - time2:0.3f}"
- meta_data[
- "batch_data_time"] = speech_lengths.sum().item() * frontend.frame_shift * frontend.lfr_n / 1000
-
- speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
-
- # b. Forward Encoder streaming
- t_offset = 0
- feats = speech
- feats_len = speech_lengths.max().item()
- waveform = pad_sequence(audio_sample_list, batch_first=True).to(device=kwargs["device"]) # data: [batch, N]
- cache = kwargs.get("cache", {})
- batch_size = kwargs.get("batch_size", 1)
- step = min(feats_len, 6000)
- segments = [[]] * batch_size
-
- for t_offset in range(0, feats_len, min(step, feats_len - t_offset)):
- if t_offset + step >= feats_len - 1:
- step = feats_len - t_offset
- is_final = True
- else:
- is_final = False
+ speech.to(device=kwargs["device"]), speech_lengths.to(device=kwargs["device"])
+
batch = {
- "feats": feats[:, t_offset:t_offset + step, :],
- "waveform": waveform[:, t_offset * 160:min(waveform.shape[-1], (t_offset + step - 1) * 160 + 400)],
- "is_final": is_final,
- "cache": cache
+ "feats": speech,
+ "waveform": cache["frontend"]["waveforms"],
+ "is_final": kwargs["is_final"],
+ "cache": cache["encoder"]
}
+ segments_i = self.forward(**batch)
+ segments.extend(segments_i)
- segments_part, cache = self.forward(**batch)
- if segments_part:
- for batch_num in range(0, batch_size):
- segments[batch_num] += segments_part[batch_num]
+ cache["prev_samples"] = audio_sample[:-m]
+ if _is_final:
+ self.init_cache(cache, **kwargs)
ibest_writer = None
if ibest_writer is None and kwargs.get("output_dir") is not None:
@@ -635,16 +600,15 @@
ibest_writer = writer[f"{1}best_recog"]
results = []
- for i in range(batch_size):
-
- if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
- results[i] = json.dumps(results[i])
-
- if ibest_writer is not None:
- ibest_writer["text"][key[i]] = segments[i]
+ result_i = {"key": key[0], "value": segments}
+ if "MODELSCOPE_ENVIRONMENT" in os.environ and os.environ["MODELSCOPE_ENVIRONMENT"] == "eas":
+ result_i = json.dumps(result_i)
- result_i = {"key": key[i], "value": segments[i]}
- results.append(result_i)
+ results.append(result_i)
+
+ if ibest_writer is not None:
+ ibest_writer["text"][key[0]] = segments
+
return results, meta_data
--
Gitblit v1.9.1