雾聪
2023-06-28 54931dd4e1a099d7d6f144c4e12e5453deb3aa26
funasr/models/e2e_sa_asr.py
@@ -40,7 +40,7 @@
        yield
class ESPnetASRModel(FunASRModel):
class SAASRModel(FunASRModel):
    """CTC-attention hybrid Encoder-Decoder model"""
    def __init__(
@@ -51,10 +51,8 @@
            frontend: Optional[AbsFrontend],
            specaug: Optional[AbsSpecAug],
            normalize: Optional[AbsNormalize],
            preencoder: Optional[AbsPreEncoder],
            asr_encoder: AbsEncoder,
            spk_encoder: torch.nn.Module,
            postencoder: Optional[AbsPostEncoder],
            decoder: AbsDecoder,
            ctc: CTC,
            spk_weight: float = 0.5,
@@ -89,8 +87,6 @@
        self.frontend = frontend
        self.specaug = specaug
        self.normalize = normalize
        self.preencoder = preencoder
        self.postencoder = postencoder
        self.asr_encoder = asr_encoder
        self.spk_encoder = spk_encoder
@@ -293,10 +289,6 @@
            if self.normalize is not None:
                feats, feats_lengths = self.normalize(feats, feats_lengths)
        # Pre-encoder, e.g. used for raw input data
        if self.preencoder is not None:
            feats, feats_lengths = self.preencoder(feats, feats_lengths)
        # 4. Forward encoder
        # feats: (Batch, Length, Dim)
        # -> encoder_out: (Batch, Length2, Dim2)
@@ -317,11 +309,6 @@
            encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
        else:
            encoder_out_spk=encoder_out_spk_ori
        # Post-encoder, e.g. NLU
        if self.postencoder is not None:
            encoder_out, encoder_out_lens = self.postencoder(
                encoder_out, encoder_out_lens
            )
        assert encoder_out.size(0) == speech.size(0), (
            encoder_out.size(),
@@ -337,7 +324,7 @@
        )
        if intermediate_outs is not None:
            return (encoder_out, intermediate_outs), encoder_out_lens
            return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
        return encoder_out, encoder_out_lens, encoder_out_spk