From 33d3d2084403fd34b79c835d2f2fe04f6cd8f738 Mon Sep 17 00:00:00 2001
From: 游雁 <zhifu.gzf@alibaba-inc.com>
Date: Wed, 13 Sep 2023 09:33:54 +0800
Subject: [PATCH] Merge branch 'main' of github.com:alibaba-damo-academy/FunASR add
---
funasr/models/e2e_sa_asr.py | 19 ++-----------------
1 file changed, 2 insertions(+), 17 deletions(-)
diff --git a/funasr/models/e2e_sa_asr.py b/funasr/models/e2e_sa_asr.py
index 8304607..cf1587d 100644
--- a/funasr/models/e2e_sa_asr.py
+++ b/funasr/models/e2e_sa_asr.py
@@ -12,7 +12,6 @@
import torch
import torch.nn.functional as F
-from typeguard import check_argument_types
from funasr.layers.abs_normalize import AbsNormalize
from funasr.losses.label_smoothing_loss import (
@@ -40,7 +39,7 @@
yield
-class ESPnetASRModel(FunASRModel):
+class SAASRModel(FunASRModel):
"""CTC-attention hybrid Encoder-Decoder model"""
def __init__(
@@ -51,10 +50,8 @@
frontend: Optional[AbsFrontend],
specaug: Optional[AbsSpecAug],
normalize: Optional[AbsNormalize],
- preencoder: Optional[AbsPreEncoder],
asr_encoder: AbsEncoder,
spk_encoder: torch.nn.Module,
- postencoder: Optional[AbsPostEncoder],
decoder: AbsDecoder,
ctc: CTC,
spk_weight: float = 0.5,
@@ -69,7 +66,6 @@
sym_blank: str = "<blank>",
extract_feats_in_collect_stats: bool = True,
):
- assert check_argument_types()
assert 0.0 <= ctc_weight <= 1.0, ctc_weight
assert 0.0 <= interctc_weight < 1.0, interctc_weight
@@ -89,8 +85,6 @@
self.frontend = frontend
self.specaug = specaug
self.normalize = normalize
- self.preencoder = preencoder
- self.postencoder = postencoder
self.asr_encoder = asr_encoder
self.spk_encoder = spk_encoder
@@ -293,10 +287,6 @@
if self.normalize is not None:
feats, feats_lengths = self.normalize(feats, feats_lengths)
- # Pre-encoder, e.g. used for raw input data
- if self.preencoder is not None:
- feats, feats_lengths = self.preencoder(feats, feats_lengths)
-
# 4. Forward encoder
# feats: (Batch, Length, Dim)
# -> encoder_out: (Batch, Length2, Dim2)
@@ -317,11 +307,6 @@
encoder_out_spk=F.interpolate(encoder_out_spk_ori.transpose(-2,-1), size=(encoder_out.size(1)), mode='nearest').transpose(-2,-1)
else:
encoder_out_spk=encoder_out_spk_ori
- # Post-encoder, e.g. NLU
- if self.postencoder is not None:
- encoder_out, encoder_out_lens = self.postencoder(
- encoder_out, encoder_out_lens
- )
assert encoder_out.size(0) == speech.size(0), (
encoder_out.size(),
@@ -337,7 +322,7 @@
)
if intermediate_outs is not None:
- return (encoder_out, intermediate_outs), encoder_out_lens
+ return (encoder_out, intermediate_outs), encoder_out_lens, encoder_out_spk
return encoder_out, encoder_out_lens, encoder_out_spk
--
Gitblit v1.9.1