zhifu gao
2024-06-11 20aa07268a7fafaaab7762b488615af32a0e82b4
funasr/models/paraformer/model.py
@@ -4,6 +4,7 @@
#  MIT License  (https://opensource.org/licenses/MIT)
import time
import copy
import torch
import logging
from torch.cuda.amp import autocast
@@ -21,6 +22,7 @@
from funasr.losses.label_smoothing_loss import LabelSmoothingLoss
from funasr.models.transformer.utils.add_sos_eos import add_sos_eos
from funasr.models.transformer.utils.nets_utils import make_pad_mask
from funasr.utils.timestamp_tools import ts_prediction_lfr6_standard
from funasr.utils.load_utils import load_audio_text_image_video, extract_fbank
@@ -452,6 +454,7 @@
        is_use_lm = (
            kwargs.get("lm_weight", 0.0) > 0.00001 and kwargs.get("lm_file", None) is not None
        )
        pred_timestamp = kwargs.get("pred_timestamp", False)
        if self.beam_search is None and (is_use_lm or is_use_ctc):
            logging.info("enable beam_search")
            self.init_beam_search(**kwargs)
@@ -506,6 +509,7 @@
            predictor_outs[2],
            predictor_outs[3],
        )
        pre_token_length = pre_token_length.round().long()
        if torch.max(pre_token_length) < 1:
            return []
@@ -564,10 +568,22 @@
                    # Change integer-ids to tokens
                    token = tokenizer.ids2tokens(token_int)
                    text_postprocessed = tokenizer.tokens2text(token)
                    if not hasattr(tokenizer, "bpemodel"):
                        text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                    result_i = {"key": key[i], "text": text_postprocessed}
                    if pred_timestamp:
                        timestamp_str, timestamp = ts_prediction_lfr6_standard(
                            pre_peak_index[i],
                            alphas[i],
                            copy.copy(token),
                            vad_offset=kwargs.get("begin_time", 0),
                            upsample_rate=1,
                        )
                        if not hasattr(tokenizer, "bpemodel"):
                            text_postprocessed, time_stamp_postprocessed, _ = postprocess_utils.sentence_postprocess(token, timestamp)
                        result_i = {"key": key[i], "text": text_postprocessed, "timestamp": time_stamp_postprocessed,}
                    else:
                        if not hasattr(tokenizer, "bpemodel"):
                            text_postprocessed, _ = postprocess_utils.sentence_postprocess(token)
                        result_i = {"key": key[i], "text": text_postprocessed}
                    if ibest_writer is not None:
                        ibest_writer["token"][key[i]] = " ".join(token)