From d5a80d642a5721eb1352cba59833a5cf4b91000f Mon Sep 17 00:00:00 2001
From: speech_asr <wangjiaming.wjm@alibaba-inc.com>
Date: Tue, 11 Apr 2023 00:09:29 +0800
Subject: [PATCH] Remove AbsEncoder and AbsESPnetModel abstract base classes
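
Delete the AbsEncoder abstract base class
(funasr/models/encoder/abs_encoder.py) and the AbsESPnetModel base
class (funasr/train/abs_espnet_model.py). The Conformer, Data2Vec,
MFCCA, ResNet34, RNN, SANM, and Transformer encoders now inherit
torch.nn.Module directly, and a stray `import pdb` is dropped from
mfcca_encoder.py.

As a minimal sketch of the convention after this change (the encoder
name and default size below are illustrative, not taken from this
patch), an encoder simply subclasses torch.nn.Module while keeping
the output_size()/forward() contract that AbsEncoder used to declare:

    import torch
    from typing import Optional, Tuple

    class MyEncoder(torch.nn.Module):
        def __init__(self, output_size: int = 256):
            super().__init__()
            self._output_size = output_size

        def output_size(self) -> int:
            return self._output_size

        def forward(
            self,
            xs_pad: torch.Tensor,             # padded inputs (batch, time, feat)
            ilens: torch.Tensor,              # valid length of each utterance
            prev_states: torch.Tensor = None,
        ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
            # Placeholder pass-through; real encoders transform xs_pad here.
            return xs_pad, ilens, None

Since the ABC no longer enforces this contract at class-definition
time, any isinstance(encoder, AbsEncoder) checks elsewhere in the
codebase, if they exist, would need to be updated as well.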
---
 funasr/models/encoder/abs_encoder.py         | 21 --------
 funasr/models/encoder/conformer_encoder.py   |  3 +--
 funasr/models/encoder/data2vec_encoder.py    |  3 +--
 funasr/models/encoder/mfcca_encoder.py       |  4 +---
 funasr/models/encoder/resnet34_encoder.py    |  3 +--
 funasr/models/encoder/rnn_encoder.py         |  3 +--
 funasr/models/encoder/sanm_encoder.py        |  7 +++----
 funasr/models/encoder/transformer_encoder.py |  3 +--
 funasr/train/abs_espnet_model.py             | 55 ---------------------
 9 files changed, 9 insertions(+), 93 deletions(-)
diff --git a/funasr/models/encoder/abs_encoder.py b/funasr/models/encoder/abs_encoder.py
deleted file mode 100644
index 1fb7c97..0000000
--- a/funasr/models/encoder/abs_encoder.py
+++ /dev/null
@@ -1,21 +0,0 @@
-from abc import ABC
-from abc import abstractmethod
-from typing import Optional
-from typing import Tuple
-
-import torch
-
-
-class AbsEncoder(torch.nn.Module, ABC):
-    @abstractmethod
-    def output_size(self) -> int:
-        raise NotImplementedError
-
-    @abstractmethod
-    def forward(
-        self,
-        xs_pad: torch.Tensor,
-        ilens: torch.Tensor,
-        prev_states: torch.Tensor = None,
-    ) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]:
-        raise NotImplementedError
diff --git a/funasr/models/encoder/conformer_encoder.py b/funasr/models/encoder/conformer_encoder.py
index 7c7f661..e649eca 100644
--- a/funasr/models/encoder/conformer_encoder.py
+++ b/funasr/models/encoder/conformer_encoder.py
@@ -14,7 +14,6 @@
from typeguard import check_argument_types
from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import (
MultiHeadedAttention, # noqa: H301
RelPositionMultiHeadedAttention, # noqa: H301
@@ -277,7 +276,7 @@
return x, mask
-class ConformerEncoder(AbsEncoder):
+class ConformerEncoder(torch.nn.Module):
"""Conformer encoder module.
Args:
diff --git a/funasr/models/encoder/data2vec_encoder.py b/funasr/models/encoder/data2vec_encoder.py
index fd1796c..a30e91e 100644
--- a/funasr/models/encoder/data2vec_encoder.py
+++ b/funasr/models/encoder/data2vec_encoder.py
@@ -12,7 +12,6 @@
import torch.nn.functional as F
from typeguard import check_argument_types
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.data2vec.data_utils import compute_mask_indices
from funasr.modules.data2vec.ema_module import EMAModule
from funasr.modules.data2vec.grad_multiply import GradMultiply
@@ -29,7 +28,7 @@
return end - r * pct_remaining
-class Data2VecEncoder(AbsEncoder):
+class Data2VecEncoder(torch.nn.Module):
def __init__(
self,
# for ConvFeatureExtractionModel
diff --git a/funasr/models/encoder/mfcca_encoder.py b/funasr/models/encoder/mfcca_encoder.py
index 83d0b0e..9ffd452 100644
--- a/funasr/models/encoder/mfcca_encoder.py
+++ b/funasr/models/encoder/mfcca_encoder.py
@@ -34,8 +34,6 @@
from funasr.modules.subsampling import Conv2dSubsampling8
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
-from funasr.models.encoder.abs_encoder import AbsEncoder
-import pdb
import math
class ConvolutionModule(nn.Module):
@@ -108,7 +106,7 @@
-class MFCCAEncoder(AbsEncoder):
+class MFCCAEncoder(torch.nn.Module):
"""Conformer encoder module.
Args:
diff --git a/funasr/models/encoder/resnet34_encoder.py b/funasr/models/encoder/resnet34_encoder.py
index 7d7179a..6f978eb 100644
--- a/funasr/models/encoder/resnet34_encoder.py
+++ b/funasr/models/encoder/resnet34_encoder.py
@@ -1,6 +1,5 @@
import torch
from torch.nn import functional as F
-from funasr.models.encoder.abs_encoder import AbsEncoder
from typing import Tuple, Optional
from funasr.models.pooling.statistic_pooling import statistic_pooling, windowed_statistic_pooling
from collections import OrderedDict
@@ -76,7 +75,7 @@
return xs_pad, ilens
-class ResNet34(AbsEncoder):
+class ResNet34(torch.nn.Module):
def __init__(
self,
input_size,
diff --git a/funasr/models/encoder/rnn_encoder.py b/funasr/models/encoder/rnn_encoder.py
index 7a3b053..6b75574 100644
--- a/funasr/models/encoder/rnn_encoder.py
+++ b/funasr/models/encoder/rnn_encoder.py
@@ -9,10 +9,9 @@
from funasr.modules.nets_utils import make_pad_mask
from funasr.modules.rnn.encoders import RNN
from funasr.modules.rnn.encoders import RNNP
-from funasr.models.encoder.abs_encoder import AbsEncoder
-class RNNEncoder(AbsEncoder):
+class RNNEncoder(torch.nn.Module):
"""RNNEncoder class.
Args:
diff --git a/funasr/models/encoder/sanm_encoder.py b/funasr/models/encoder/sanm_encoder.py
index 2a3a353..1462403 100644
--- a/funasr/models/encoder/sanm_encoder.py
+++ b/funasr/models/encoder/sanm_encoder.py
@@ -26,7 +26,6 @@
from funasr.modules.subsampling import TooShortUttError
from funasr.modules.subsampling import check_short_utt
from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.mask import subsequent_mask, vad_mask
class EncoderLayerSANM(nn.Module):
@@ -115,7 +114,7 @@
return x, mask, cache, mask_shfit_chunk, mask_att_chunk_encoder
-class SANMEncoder(AbsEncoder):
+class SANMEncoder(torch.nn.Module):
"""
author: Speech Lab, Alibaba Group, China
San-m: Memory equipped self-attention for end-to-end speech recognition
@@ -547,7 +546,7 @@
return var_dict_torch_update
-class SANMEncoderChunkOpt(AbsEncoder):
+class SANMEncoderChunkOpt(torch.nn.Module):
"""
author: Speech Lab, Alibaba Group, China
SCAMA: Streaming chunk-aware multihead attention for online end-to-end speech recognition
@@ -960,7 +959,7 @@
return var_dict_torch_update
-class SANMVadEncoder(AbsEncoder):
+class SANMVadEncoder(torch.nn.Module):
"""
author: Speech Lab, Alibaba Group, China
diff --git a/funasr/models/encoder/transformer_encoder.py b/funasr/models/encoder/transformer_encoder.py
index ff9c3db..55a65b3 100644
--- a/funasr/models/encoder/transformer_encoder.py
+++ b/funasr/models/encoder/transformer_encoder.py
@@ -13,7 +13,6 @@
import logging
from funasr.models.ctc import CTC
-from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.modules.attention import MultiHeadedAttention
from funasr.modules.embedding import PositionalEncoding
from funasr.modules.layer_norm import LayerNorm
@@ -144,7 +143,7 @@
return x, mask
-class TransformerEncoder(AbsEncoder):
+class TransformerEncoder(torch.nn.Module):
"""Transformer encoder module.
Args:
diff --git a/funasr/train/abs_espnet_model.py b/funasr/train/abs_espnet_model.py
deleted file mode 100644
index cc6a5a2..0000000
--- a/funasr/train/abs_espnet_model.py
+++ /dev/null
@@ -1,55 +0,0 @@
-# Copyright ESPnet (https://github.com/espnet/espnet). All Rights Reserved.
-# Apache 2.0 (http://www.apache.org/licenses/LICENSE-2.0)
-
-from abc import ABC
-from abc import abstractmethod
-from typing import Dict
-from typing import Tuple
-
-import torch
-
-
-class AbsESPnetModel(torch.nn.Module, ABC):
- """The common abstract class among each tasks
-
- "ESPnetModel" is referred to a class which inherits torch.nn.Module,
- and makes the dnn-models forward as its member field,
- a.k.a delegate pattern,
- and defines "loss", "stats", and "weight" for the task.
-
- If you intend to implement new task in ESPNet,
- the model must inherit this class.
- In other words, the "mediator" objects between
- our training system and the your task class are
- just only these three values, loss, stats, and weight.
-
- Example:
- >>> from funasr.tasks.abs_task import AbsTask
- >>> class YourESPnetModel(AbsESPnetModel):
- ... def forward(self, input, input_lengths):
- ... ...
- ... return loss, stats, weight
- >>> class YourTask(AbsTask):
- ... @classmethod
- ... def build_model(cls, args: argparse.Namespace) -> YourESPnetModel:
- """
-
-    def __init__(self):
-        super().__init__()
-        self.num_updates = 0
-
-    @abstractmethod
-    def forward(
-        self, **batch: torch.Tensor
-    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
-        raise NotImplementedError
-
-    @abstractmethod
-    def collect_feats(self, **batch: torch.Tensor) -> Dict[str, torch.Tensor]:
-        raise NotImplementedError
-
-    def set_num_updates(self, num_updates):
-        self.num_updates = num_updates
-
-    def get_num_updates(self):
-        return self.num_updates
--
Gitblit v1.9.1