嘉渊
2023-04-27 607073619cedf2c114e1589aa6d5953d171f33bf
funasr/models/e2e_tp.py
@@ -2,19 +2,23 @@
from contextlib import contextmanager
from distutils.version import LooseVersion
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
from typing import Union
import torch
import numpy as np
from typeguard import check_argument_types
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.predictor.cif import mae_loss
from funasr.models.base_model import FunASRModel
from funasr.modules.add_sos_eos import add_sos_eos
from funasr.modules.nets_utils import make_pad_mask, pad_list
from funasr.torch_utils.device_funcs import force_gatherable
from funasr.models.base_model import FunASRModel
from funasr.models.predictor.cif import CifPredictorV3
if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
    from torch.cuda.amp import autocast
@@ -25,15 +29,15 @@
        yield
class TimestampPredictor(FunASRModel):
class TimestampPredictor(AbsESPnetModel):
    """
    Author: Speech Lab, Alibaba Group, China
    Author: Speech Lab of DAMO Academy, Alibaba Group
    """
    def __init__(
            self,
            frontend: Optional[torch.nn.Module],
            encoder: torch.nn.Module,
            frontend: Optional[AbsFrontend],
            encoder: AbsEncoder,
            predictor: CifPredictorV3,
            predictor_bias: int = 0,
            token_list=None,
@@ -51,7 +55,7 @@
        self.predictor_bias = predictor_bias
        self.criterion_pre = mae_loss()
        self.token_list = token_list
    def forward(
            self,
            speech: torch.Tensor,
@@ -60,7 +64,6 @@
            text_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
        """Frontend + Encoder + Decoder + Calc loss
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
@@ -108,7 +111,6 @@
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Frontend + Encoder. Note that this method is used by asr_inference.py
        Args:
                speech: (Batch, Length, ...)
                speech_lengths: (Batch, )
@@ -123,7 +125,7 @@
        encoder_out, encoder_out_lens, _ = self.encoder(feats, feats_lengths)
        return encoder_out, encoder_out_lens
    def _extract_feats(
            self, speech: torch.Tensor, speech_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -146,8 +148,8 @@
        encoder_out_mask = (~make_pad_mask(encoder_out_lens, maxlen=encoder_out.size(1))[:, None, :]).to(
            encoder_out.device)
        ds_alphas, ds_cif_peak, us_alphas, us_peaks = self.predictor.get_upsample_timestamp(encoder_out,
                                                                                               encoder_out_mask,
                                                                                               token_num)
                                                                                            encoder_out_mask,
                                                                                            token_num)
        return ds_alphas, ds_cif_peak, us_alphas, us_peaks
    def collect_feats(