From 6ed27c64c96c6f8b148c6d4110716cba6a185452 Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Thu, 27 Apr 2023 17:19:39 +0800
Subject: [PATCH] update
---
funasr/models/specaug/abs_specaug.py | 16 ++++++++++++++++
funasr/models/e2e_asr.py | 3 ++-
funasr/models/encoder/conformer_encoder.py | 4 +++-
funasr/models/encoder/transformer_encoder.py | 3 ++-
funasr/models/frontend/abs_frontend.py | 17 +++++++++++++++++
funasr/models/specaug/specaug.py | 5 ++---
6 files changed, 42 insertions(+), 6 deletions(-)
diff --git a/funasr/models/e2e_asr.py b/funasr/models/e2e_asr.py
index d3d5dfd..779d703 100644
--- a/funasr/models/e2e_asr.py
+++ b/funasr/models/e2e_asr.py
@@ -17,6 +17,7 @@
LabelSmoothingLoss, # noqa: H301
)
from funasr.models.ctc import CTC
+from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.base_model import FunASRModel
@@ -41,7 +42,7 @@
self,
vocab_size: int,
token_list: Union[Tuple[str, ...], List[str]],
- frontend: Optional[torch.nn.Module],
+ frontend: Optional[AbsFrontend],
specaug: Optional[torch.nn.Module],
normalize: Optional[torch.nn.Module],
encoder: AbsEncoder,
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index e649eca..c05e38b 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -19,6 +19,7 @@
RelPositionMultiHeadedAttention, # noqa: H301
LegacyRelPositionMultiHeadedAttention, # noqa: H301
)
+from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.embedding import (
PositionalEncoding, # noqa: H301
ScaledPositionalEncoding, # noqa: H301
@@ -41,7 +42,8 @@
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.modules.subsampling import Conv2dSubsamplingPad
-class ConvolutionModule(nn.Module):
+
+class ConvolutionModule(AbsEncoder):
"""ConvolutionModule in Conformer model.
Args:
diff --git a/funasr/models/encoder/transformer_encoder.py b/funasr/models/encoder/transformer_encoder.py
index 55a65b3..2fac880 100644
--- a/funasr/models/encoder/transformer_encoder.py
+++ b/funasr/models/encoder/transformer_encoder.py
@@ -13,6 +13,7 @@
import logging
from funasr.models.ctc import CTC
+from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.layer_norm import LayerNorm
@@ -36,7 +37,7 @@
from funasr.modules.subsampling import check_short_utt
-class EncoderLayer(nn.Module):
+class EncoderLayer(AbsEncoder):
"""Encoder layer module.
Args:
diff --git a/funasr/models/frontend/abs_frontend.py b/funasr/models/frontend/abs_frontend.py
new file mode 100644
index 0000000..6049a01
--- /dev/null
+++ b/funasr/models/frontend/abs_frontend.py
@@ -0,0 +1,17 @@
+from abc import ABC
+from abc import abstractmethod
+from typing import Tuple
+
+import torch
+
+
+class AbsFrontend(torch.nn.Module, ABC):
+ @abstractmethod
+ def output_size(self) -> int:
+ raise NotImplementedError
+
+ @abstractmethod
+ def forward(
+ self, input: torch.Tensor, input_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ raise NotImplementedError
\ No newline at end of file
diff --git a/funasr/models/specaug/abs_specaug.py b/funasr/models/specaug/abs_specaug.py
new file mode 100644
index 0000000..da6637e
--- /dev/null
+++ b/funasr/models/specaug/abs_specaug.py
@@ -0,0 +1,16 @@
+from typing import Optional
+from typing import Tuple
+
+import torch
+
+
+class AbsSpecAug(torch.nn.Module):
+ """Abstract class for the augmentation of spectrogram
+ The process-flow:
+ Frontend -> SpecAug -> Normalization -> Encoder -> Decoder
+ """
+
+ def forward(
+ self, x: torch.Tensor, x_lengths: torch.Tensor = None
+ ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+ raise NotImplementedError
diff --git a/funasr/models/specaug/specaug.py b/funasr/models/specaug/specaug.py
index 75c5bb2..d4e305d 100644
--- a/funasr/models/specaug/specaug.py
+++ b/funasr/models/specaug/specaug.py
@@ -3,15 +3,14 @@
from typing import Sequence
from typing import Union
-import torch.nn
-
+from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.mask_along_axis import MaskAlongAxis
from funasr.layers.mask_along_axis import MaskAlongAxisVariableMaxWidth
from funasr.layers.mask_along_axis import MaskAlongAxisLFR
from funasr.layers.time_warp import TimeWarp
-class SpecAug(torch.nn.Module):
+class SpecAug(AbsSpecAug):
"""Implementation of SpecAug.
Reference:
--
Gitblit v1.9.1