嘉渊
2023-04-27 6ed27c64c96c6f8b148c6d4110716cba6a185452
update
4个文件已修改
2个文件已添加
48 ■■■■ 已修改文件
funasr/models/e2e_asr.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/conformer_encoder.py 4 ●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/encoder/transformer_encoder.py 3 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/abs_frontend.py 17 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/specaug/abs_specaug.py 16 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/specaug/specaug.py 5 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/e2e_asr.py
@@ -17,6 +17,7 @@
    LabelSmoothingLoss,  # noqa: H301
)
from funasr.models.ctc import CTC
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.base_model import FunASRModel
@@ -41,7 +42,7 @@
            self,
            vocab_size: int,
            token_list: Union[Tuple[str, ...], List[str]],
            frontend: Optional[torch.nn.Module],
            frontend: Optional[AbsFrontend],
            specaug: Optional[torch.nn.Module],
            normalize: Optional[torch.nn.Module],
            encoder: AbsEncoder,
funasr/models/encoder/conformer_encoder.py
@@ -19,6 +19,7 @@
    RelPositionMultiHeadedAttention,  # noqa: H301
    LegacyRelPositionMultiHeadedAttention,  # noqa: H301
)
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.embedding import (
    PositionalEncoding,  # noqa: H301
    ScaledPositionalEncoding,  # noqa: H301
@@ -41,7 +42,8 @@
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.modules.subsampling import Conv2dSubsamplingPad
class ConvolutionModule(nn.Module):
class ConvolutionModule(AbsEncoder):
    """ConvolutionModule in Conformer model.
    Args:
funasr/models/encoder/transformer_encoder.py
@@ -13,6 +13,7 @@
import logging
from funasr.models.ctc import CTC
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.layer_norm import LayerNorm
@@ -36,7 +37,7 @@
from funasr.modules.subsampling import check_short_utt
class EncoderLayer(nn.Module):
class EncoderLayer(AbsEncoder):
    """Encoder layer module.
    Args:
funasr/models/frontend/abs_frontend.py
New file
@@ -0,0 +1,17 @@
from abc import ABC
from abc import abstractmethod
from typing import Tuple
import torch
class AbsFrontend(torch.nn.Module, ABC):
    @abstractmethod
    def output_size(self) -> int:
        raise NotImplementedError
    @abstractmethod
    def forward(
        self, input: torch.Tensor, input_lengths: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        raise NotImplementedError
funasr/models/specaug/abs_specaug.py
New file
@@ -0,0 +1,16 @@
from typing import Optional
from typing import Tuple
import torch
class AbsSpecAug(torch.nn.Module):
    """Abstract class for the augmentation of spectrogram
    The process-flow:
    Frontend  -> SpecAug -> Normalization -> Encoder -> Decoder
    """
    def forward(
        self, x: torch.Tensor, x_lengths: torch.Tensor = None
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
        raise NotImplementedError
funasr/models/specaug/specaug.py
@@ -3,15 +3,14 @@
from typing import Sequence
from typing import Union
import torch.nn
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.layers.mask_along_axis import MaskAlongAxis
from funasr.layers.mask_along_axis import MaskAlongAxisVariableMaxWidth
from funasr.layers.mask_along_axis import MaskAlongAxisLFR
from funasr.layers.time_warp import TimeWarp
class SpecAug(torch.nn.Module):
class SpecAug(AbsSpecAug):
    """Implementation of SpecAug.
    Reference: