游雁
2023-11-16 4ace5a95b052d338947fc88809a440ccd55cf6b4
funasr/models/e2e_tp.py
@@ -2,19 +2,22 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
from typeguard import check_argument_types
import numpy as np
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.predictor.cif import mae_loss
from funasr.models.base_model import FunASRModel
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.predictor.cif import CifPredictorV3
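# torch.cuda.amp.autocast is only available in torch >= 1.6.0, hence the guard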
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
@@ -27,18 +30,17 @@
class TimestampPredictor(FunASRModel):
    """
    Author: Speech Lab of DAMO Academy, Alibaba Group
    """
    def __init__(
            self,
            frontend: Optional[AbsFrontend],
            encoder: AbsEncoder,
            predictor: CifPredictorV3,
            predictor_bias: int = 0,
            token_list=None,
    ):
        assert check_argument_types()
        super().__init__()
        # note that eos is the same as sos (equivalent ID)
@@ -51,7 +53,7 @@
        self.predictor_bias = predictor_bias
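        # MAE between the predicted token count (sum of CIF alphas) and the target length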
        self.criterion_pre = mae_loss()
        self.token_list = token_list
    def forward(
            self,
            speech: torch.Tensor,
@@ -60,7 +62,6 @@
            text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
@@ -108,7 +109,6 @@
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
@@ -123,7 +123,7 @@
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        return encoder_out, encoder_out_lens
    def _extract_feats(
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -146,8 +146,8 @@
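        # mask of valid (non-padded) encoder frames, shape (B, 1, T_enc)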
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                            encoder_out_mask,
                                                                                            token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
    def collect_feats(
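For orientation, a minimal sketch of how the timestamp path is driven: encode() produces encoder frames, and calc_predictor_timestamp() turns them into downsampled/upsampled CIF alphas and peaks. The helper name get_token_timestamps is hypothetical, and the calc_predictor_timestamp signature is inferred from the method body shown above, not confirmed by this excerpt:

import torch

def get_token_timestamps(model: TimestampPredictor,
                         speech: torch.Tensor,          # (Batch, Length) raw waveform
                         speech_lengths: torch.Tensor,  # (Batch,)
                         token_num: torch.Tensor):      # per-utterance token counts
    """Hypothetical helper: frontend + encoder, then upsampled CIF peaks.

    us_peaks marks upsampled CIF firing positions; per-token timestamps
    can be derived from the frame indices of those peaks.
    """
    with torch.no_grad():
        encoder_out, encoder_out_lens = model.encode(speech, speech_lengths)
        # call signature inferred from the body shown in this diff
        return model.calc_predictor_timestamp(encoder_out, encoder_out_lens, token_num)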