From ce6b70e4795fb5afb685d1fead898589c4970990 Mon Sep 17 00:00:00 2001
From: 维石 <shixian.shi@alibaba-inc.com>
Date: Thu, 06 Jun 2024 10:08:17 +0800
Subject: [PATCH] update paraformer timestamp
---
funasr/models/paraformer/model.py | 24 ++++++++++++++++++++----
funasr/utils/timestamp_tools.py | 6 +++---
2 files changed, 23 insertions(+), 7 deletions(-)
diff --git a/funasr/models/paraformer/model.py b/funasr/models/paraformer/model.py
index 0d9bb2b..85967af 100644
--- a/funasr/models/paraformer/model.py
+++ b/funasr/models/paraformer/model.py
@@ -4,6 +4,7 @@
# MIT License (https://opensource.org/licenses/MIT)
import time
+import copy
import torch
import logging
from torch.cuda.amp import autocast
@@ -21,6 +22,7 @@
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask
+from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
@@ -452,6 +454,7 @@
is_use_lm = (
kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
)
+ pred_timestamp = kwargs.get("pred_timestamp", False)
if self.beam_search is None and (is_use_lm or is_use_ctc):
logging.info("enable beam_search")
self.init_beam_search(**kwargs)
@@ -506,6 +509,7 @@
predictor_outs[2],
predictor_outs[3],
)
+
pre_token_length = pre_token_length.round().long()
if torch.max(pre_token_length) < 1:
return []
@@ -564,10 +568,22 @@
# Change integer-ids to tokens
token = tokenizer.ids2tokens(token_int)
text_postprocessed = tokenizer.tokens2text(token)
- if not hasattr(tokenizer, "bpemodel"):
- text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
-
- result_i = {"key": key[i], "text": text_postprocessed}
+
+ if pred_timestamp:
+ timestamp_str, timestamp = ts_prediction_lfr6_standard(
+ pre_peak_index[i],
+ alphas[i],
+ copy.copy(token),
+ vad_offset=kwargs.get("begin_time", 0),
+ upsample_rate=1,
+ )
+ if not hasattr(tokenizer, "bpemodel"):
+ text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
+ result_i = {"key": key[i], "text": text_postprocessed, "timestamp": time_stamp_postprocessed,}
+ else:
+ if not hasattr(tokenizer, "bpemodel"):
+ text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
+ result_i = {"key": key[i], "text": text_postprocessed}
if ibest_writer is not None:
ibest_writer["token"][key[i]] = " ".join(token)
diff --git a/funasr/utils/timestamp_tools.py b/funasr/utils/timestamp_tools.py
index 831d773..af61e5a 100644
--- a/funasr/utils/timestamp_tools.py
+++ b/funasr/utils/timestamp_tools.py
@@ -29,13 +29,13 @@
def ts_prediction_lfr6_standard(
- us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True
+ us_alphas, us_peaks, char_list, vad_offset=0.0, force_time_shift=-1.5, sil_in_str=True, upsample_rate=3,
):
if not len(char_list):
return "", []
START_END_THRESHOLD = 5
- MAX_TOKEN_DURATION = 12
- TIME_RATE = 10.0 * 6 / 1000 / 3 # 3 times upsampled
+ MAX_TOKEN_DURATION = 12 # 3 times upsampled
+ TIME_RATE=10.0 * 6 / 1000 / upsample_rate
if len(us_alphas.shape) == 2:
alphas, peaks = us_alphas[0], us_peaks[0] # support inference batch_size=1 only
else:
--
Gitblit v1.9.1