From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: 星期三, 13 九月 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add

---
 funasr/models/e2e_sa_asr.py |   19 ++-----------------
 1 files changed, 2 insertions(+), 17 deletions(-)

diff --git a/funasr/models/e2e_sa_asr.py b/funasr/models/e2e_sa_asr.py
index 8304607..cf1587d 100644
--- a/funasr/models/e2e_sa_asr.py
+++ b/funasr/models/e2e_sa_asr.py
@@ -12,7 +12,6 @@
 
 import torch
 import torch.nn.functional as F
-from typeguard import check_argument_types
 
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.losses.label_smoothing_loss import (
@@ -40,7 +39,7 @@
         yield
 
 
-class ESPnetASRModel(FunASRModel):
+class SAASRModel(FunASRModel):
     """CTC-attention hybrid Encoder-Decoder model"""
 
     def __init__(
@@ -51,10 +50,8 @@
             frontend: Optional[AbsFrontend],
             specaug: Optional[AbsSpecAug],
             normalize: Optional[AbsNormalize],
-            preencoder: Optional[AbsPreEncoder],
             asr_encoder: AbsEncoder,
             spk_encoder: torch.nn.Module,
-            postencoder: Optional[AbsPostEncoder],
             decoder: AbsDecoder,
             ctc: CTC,
             spk_weight: float = 0.5,
@@ -69,7 +66,6 @@
             sym_blank: str = "<blank>",
             extract_feats_in_collect_stats: bool = True,
     ):
-        assert check_argument_types()
         assert 0.0 <= ctc_weight <= 1.0, ctc_weight
         assert 0.0 <= interctc_weight < 1.0, interctc_weight
 
@@ -89,8 +85,6 @@
         self.frontend = frontend
         self.specaug = specaug
         self.normalize = normalize
-        self.preencoder = preencoder
-        self.postencoder = postencoder
         self.asr_encoder = asr_encoder
         self.spk_encoder = spk_encoder
 
@@ -293,10 +287,6 @@
             if self.normalize is not None:
                 feats, feats_lengths = self.normalize(feats, feats_lengths)
 
-        # Pre-encoder, e.g. used for raw input data
-        if self.preencoder is not None:
-            feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
         # 4. Forward encoder
         # feats: (Batch, Length, Dim)
         # -> encoder_out: (Batch, Length2, Dim2)
@@ -317,11 +307,6 @@
             encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
         else:
             encoder_out_spk=encoder_out_spk_ori
-        # Post-encoder, e.g. NLU
-        if self.postencoder is not None:
-            encoder_out, encoder_out_lens = self.postencoder(
-                encoder_out, encoder_out_lens
-            )
 
         assert encoder_out.size(0) == speech.size(0), (
             encoder_out.size(),
@@ -337,7 +322,7 @@
         )
 
         if intermediate_outs is not None:
-            return (encoder_out, intermediate_outs), encoder_out_lens
+            return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
 
         return encoder_out, encoder_out_lens, encoder_out_spk
 

--
Gitblit v1.9.1