嘉渊
2023-04-24 6427c834dfd97b1f05c6659cdc7ccf010bf82fe1
funasr/models/e2e_tp.py
@@ -2,22 +2,17 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import numpy as np
from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.predictor.cif import mae_loss
from funasr.models.base_model import FunASRModel
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.predictor.cif import CifPredictorV3
@@ -30,17 +25,18 @@
        yield
class TimestampPredictor(AbsESPnetModel):
class TimestampPredictor(FunASRModel):
    """
    Author: Speech Lab, Alibaba Group, China
    """
    def __init__(
            self,
            frontend: Optional[AbsFrontend],
            encoder: AbsEncoder,
            frontend: Optional[torch.nn.Module],
            encoder: torch.nn.Module,
            predictor: CifPredictorV3,
            predictor_bias: int = 0,
            token_list=None,
    ):
        assert check_argument_types()
@@ -54,6 +50,7 @@
        self.predictor = predictor
        self.predictor_bias = predictor_bias
        self.criterion_pre = mae_loss()
        self.token_list = token_list
    
    def forward(
            self,
@@ -148,7 +145,26 @@
    def calc_predictor_timestamp(self, encoder_out, encoder_out_lens, token_num):
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_cif_peak = self.predictor.get_upsample_timestamp(encoder_out,
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                               encoder_out_mask,
                                                                                               token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_cif_peak
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
    def collect_feats(
            self,
            speech: torch.Tensor,
            speech_lengths: torch.Tensor,
            text: torch.Tensor,
            text_lengths: torch.Tensor,
    ) -> Dict[str, torch.Tensor]:
        if self.extract_feats_in_collect_stats:
            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
        else:
            # Generate dummy stats if extract_feats_in_collect_stats is False
            logging.warning(
                "Generating dummy stats for feats and feats_lengths, "
                "because encoder_conf.extract_feats_in_collect_stats is "
                f"{self.extract_feats_in_collect_stats}"
            )
            feats, feats_lengths = speech, speech_lengths
        return {"feats": feats, "feats_lengths": feats_lengths}