From d5a80d642a5721eb1352cba59833a5cf4b91000f Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: 星期二, 11 四月 2023 00:09:29 +0800
Subject: [PATCH] update

---
 funasr/models/encoder/rnn_encoder.py         |    3 -
 /dev/null                                    |   55 ---------------------------
 funasr/models/encoder/resnet34_encoder.py    |    3 -
 funasr/models/encoder/data2vec_encoder.py    |    3 -
 funasr/models/encoder/conformer_encoder.py   |    3 -
 funasr/models/encoder/sanm_encoder.py        |    7 +--
 funasr/models/encoder/mfcca_encoder.py       |    4 -
 funasr/models/encoder/transformer_encoder.py |    3 -
 8 files changed, 9 insertions(+), 72 deletions(-)

diff --git a/funasr/models/encoder/abs_encoder.py b/funasr/models/encoder/abs_encoder.py
deleted file mode 100644
index 1fb7c97..0000000
--- a/funasr/models/encoder/abs_encoder.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Optional
-from typing import Tuple
-
-import torch
-
-
-class AbsEncoder(torch.nn.Module, ABC):
-    @abstractmethod
-    def output_size(self) -> int:
-        raise NotImplementedError
-
-    @abstractmethod
-    def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        prev_states: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        raise NotImplementedError
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 7c7f661..e649eca 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -14,7 +14,6 @@
 from typeguard import check_argument_types
 
 from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.attention import (
     MultiHeadedAttention,  # noqa: H301
     RelPositionMultiHeadedAttention,  # noqa: H301
@@ -277,7 +276,7 @@
         return x, mask
 
 
-class ConformerEncoder(AbsEncoder):
+class ConformerEncoder(torch.nn.Module):
     """Conformer encoder module.
 
     Args:
diff --git a/funasr/models/encoder/data2vec_encoder.py b/funasr/models/encoder/data2vec_encoder.py
index fd1796c..a30e91e 100644
--- a/funasr/models/encoder/data2vec_encoder.py
+++ b/funasr/models/encoder/data2vec_encoder.py
@@ -12,7 +12,6 @@
 import torch.nn.functional as F
 from typeguard import check_argument_types
 
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.data2vec.data_utils import compute_mask_indices
 from funasr.modules.data2vec.ema_module import EMAModule
 from funasr.modules.data2vec.grad_multiply import GradMultiply
@@ -29,7 +28,7 @@
     return end - r * pct_remaining
 
 
-class Data2VecEncoder(AbsEncoder):
+class Data2VecEncoder(torch.nn.Module):
     def __init__(
             self,
             # for ConvFeatureExtractionModel
diff --git a/funasr/models/encoder/mfcca_encoder.py b/funasr/models/encoder/mfcca_encoder.py
index 83d0b0e..9ffd452 100644
--- a/funasr/models/encoder/mfcca_encoder.py
+++ b/funasr/models/encoder/mfcca_encoder.py
@@ -34,8 +34,6 @@
 from funasr.modules.subsampling import Conv2dSubsampling8
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
-from funasr.models.encoder.abs_encoder import AbsEncoder
-import pdb
 import math
 
 class ConvolutionModule(nn.Module):
@@ -108,7 +106,7 @@
 
 
 
-class MFCCAEncoder(AbsEncoder):
+class MFCCAEncoder(torch.nn.Module):
     """Conformer encoder module.
 
     Args:
diff --git a/funasr/models/encoder/resnet34_encoder.py b/funasr/models/encoder/resnet34_encoder.py
index 7d7179a..6f978eb 100644
--- a/funasr/models/encoder/resnet34_encoder.py
+++ b/funasr/models/encoder/resnet34_encoder.py
@@ -1,6 +1,5 @@
 import torch
 from torch.nn import functional as F
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from typing import Tuple, Optional
 from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
 from collections import OrderedDict
@@ -76,7 +75,7 @@
         return xs_pad, ilens
 
 
-class ResNet34(AbsEncoder):
+class ResNet34(torch.nn.Module):
     def __init__(
             self,
             input_size,
diff --git a/funasr/models/encoder/rnn_encoder.py b/funasr/models/encoder/rnn_encoder.py
index 7a3b053..6b75574 100644
--- a/funasr/models/encoder/rnn_encoder.py
+++ b/funasr/models/encoder/rnn_encoder.py
@@ -9,10 +9,9 @@
 from funasr.modules.nets_utils import make_pad_mask
 from funasr.modules.rnn.encoders import RNN
 from funasr.modules.rnn.encoders import RNNP
-from funasr.models.encoder.abs_encoder import AbsEncoder
 
 
-class RNNEncoder(AbsEncoder):
+class RNNEncoder(torch.nn.Module):
     """RNNEncoder class.
 
     Args:
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 2a3a353..1462403 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -26,7 +26,6 @@
 from funasr.modules.subsampling import TooShortUttError
 from funasr.modules.subsampling import check_short_utt
 from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.mask import subsequent_mask, vad_mask
 
 class EncoderLayerSANM(nn.Module):
@@ -115,7 +114,7 @@
 
         return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
 
-class SANMEncoder(AbsEncoder):
+class SANMEncoder(torch.nn.Module):
     """
     author: Speech Lab, Alibaba Group, China
     San-m: Memory equipped self-attention for end-to-end speech recognition
@@ -547,7 +546,7 @@
         return var_dict_torch_update
 
 
-class SANMEncoderChunkOpt(AbsEncoder):
+class SANMEncoderChunkOpt(torch.nn.Module):
     """
     author: Speech Lab, Alibaba Group, China
     SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
@@ -960,7 +959,7 @@
         return var_dict_torch_update
 
 
-class SANMVadEncoder(AbsEncoder):
+class SANMVadEncoder(torch.nn.Module):
     """
     author: Speech Lab, Alibaba Group, China
 
diff --git a/funasr/models/encoder/transformer_encoder.py b/funasr/models/encoder/transformer_encoder.py
index ff9c3db..55a65b3 100644
--- a/funasr/models/encoder/transformer_encoder.py
+++ b/funasr/models/encoder/transformer_encoder.py
@@ -13,7 +13,6 @@
 import logging
 
 from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
 from funasr.modules.attention import MultiHeadedAttention
 from funasr.modules.embedding import PositionalEncoding
 from funasr.modules.layer_norm import LayerNorm
@@ -144,7 +143,7 @@
         return x, mask
 
 
-class TransformerEncoder(AbsEncoder):
+class TransformerEncoder(torch.nn.Module):
     """Transformer encoder module.
 
     Args:
diff --git a/funasr/train/abs_espnet_model.py b/funasr/train/abs_espnet_model.py
deleted file mode 100644
index cc6a5a2..0000000
--- a/funasr/train/abs_espnet_model.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-#  Apache 2.0  (http://www.apache.org/licenses/LICENSE-2.0)
-
-from abc import ABC
-from abc import abstractmethod
-from typing import Dict
-from typing import Tuple
-
-import torch
-
-
-class AbsESPnetModel(torch.nn.Module, ABC):
-    """The common abstract class among each tasks
-
-    "ESPnetModel" is referred to a class which inherits torch.nn.Module,
-    and makes the dnn-models forward as its member field,
-    a.k.a delegate pattern,
-    and defines "loss", "stats", and "weight" for the task.
-
-    If you intend to implement new task in ESPNet,
-    the model must inherit this class.
-    In other words, the "mediator" objects between
-    our training system and the your task class are
-    just only these three values, loss, stats, and weight.
-
-    Example:
-        >>> from funasr.tasks.abs_task import AbsTask
-        >>> class YourESPnetModel(AbsESPnetModel):
-        ...     def forward(self, input, input_lengths):
-        ...         ...
-        ...         return loss, stats, weight
-        >>> class YourTask(AbsTask):
-        ...     @classmethod
-        ...     def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
-    """
-
-    def __init__(self):
-        super().__init__()
-        self.num_updates = 0
-
-    @abstractmethod
-    def forward(
-        self, **batch: torch.Tensor
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
-        raise NotImplementedError
-
-    def set_num_updates(self, num_updates):
-        self.num_updates = num_updates
-
-    def get_num_updates(self):
-        return self.num_updates

--
Gitblit v1.9.1