nichongjia-2007
2023-06-30 012903e42ec890ab5c50137beb365c3d94e731d1
funasr/models/e2e_sa_asr.py
@@ -12,7 +12,6 @@
import torch
import torch.nn.functional as F
from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@@ -40,7 +39,7 @@
        yield
class ESPnetASRModel(FunASRModel):
class SAASRModel(FunASRModel):
    """CTC-attention hybrid Encoder-Decoder model"""
    def __init__(
@@ -51,10 +50,8 @@
            frontend: Optional[AbsFrontend],
            specaug: Optional[AbsSpecAug],
            normalize: Optional[AbsNormalize],
            preencoder: Optional[AbsPreEncoder],
            asr_encoder: AbsEncoder,
            spk_encoder: torch.nn.Module,
            postencoder: Optional[AbsPostEncoder],
            decoder: AbsDecoder,
            ctc: CTC,
            spk_weight: float = 0.5,
@@ -69,7 +66,6 @@
            sym_blank: str = "<blank>",
            extract_feats_in_collect_stats: bool = True,
    ):
        assert check_argument_types()
        assert 0.0 <= ctc_weight <= 1.0, ctc_weight
        assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -89,8 +85,6 @@
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.asr_encoder = asr_encoder
        self.spk_encoder = spk_encoder
@@ -293,10 +287,6 @@
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
@@ -317,11 +307,6 @@
            encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
        else:
            encoder_out_spk=encoder_out_spk_ori
        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
@@ -337,7 +322,7 @@
        )
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
            return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
        return encoder_out, encoder_out_lens, encoder_out_spk