From 6997763bf65705257fe6bca6ee63fcf006122abb Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Thu, 27 Apr 2023 17:51:13 +0800
Subject: [PATCH] Switch task modules to FunASRModel base class and remove unused code

---
 funasr/models/frontend/wav_frontend_kaldifeat.py |  112 ----------------------
 funasr/tasks/punctuation.py                      |    2 
 funasr/tasks/sv.py                               |    8 
 funasr/tasks/diar.py                             |   10 +-
 funasr/tasks/asr.py                              |   12 +-
 funasr/tasks/abs_task.py                         |   22 ++--
 funasr/tasks/vad.py                              |   79 ++++-----------
 7 files changed, 48 insertions(+), 197 deletions(-)

diff --git a/funasr/models/frontend/wav_frontend_kaldifeat.py b/funasr/models/frontend/wav_frontend_kaldifeat.py
index 85adbb7..5372de3 100644
--- a/funasr/models/frontend/wav_frontend_kaldifeat.py
+++ b/funasr/models/frontend/wav_frontend_kaldifeat.py
@@ -1,17 +1,9 @@
 # Copyright (c) Alibaba, Inc. and its affiliates.
 # Part of the implementation is borrowed from espnet/espnet.
 
-from typing import Tuple
-
 import numpy as np
 import torch
-import torchaudio.compliance.kaldi as kaldi
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from typeguard import check_argument_types
-from torch.nn.utils.rnn import pad_sequence
 
-
-# import kaldifeat
 
 def load_cmvn(cmvn_file):
     with open(cmvn_file, 'r', encoding='utf-8') as f:
@@ -75,107 +67,3 @@
             LFR_inputs.append(frame)
     LFR_outputs = torch.vstack(LFR_inputs)
     return LFR_outputs.type(torch.float32)
-
-# class WavFrontend_kaldifeat(AbsFrontend):
-#     """Conventional frontend structure for ASR.
-#     """
-#
-#     def __init__(
-#         self,
-#         cmvn_file: str = None,
-#         fs: int = 16000,
-#         window: str = 'hamming',
-#         n_mels: int = 80,
-#         frame_length: int = 25,
-#         frame_shift: int = 10,
-#         lfr_m: int = 1,
-#         lfr_n: int = 1,
-#         dither: float = 1.0,
-#         snip_edges: bool = True,
-#         upsacle_samples: bool = True,
-#         device: str = 'cpu',
-#         **kwargs,
-#     ):
-#         super().__init__()
-#
-#         opts = kaldifeat.FbankOptions()
-#         opts.device = device
-#         opts.frame_opts.samp_freq = fs
-#         opts.frame_opts.dither = dither
-#         opts.frame_opts.window_type = window
-#         opts.frame_opts.frame_shift_ms = float(frame_shift)
-#         opts.frame_opts.frame_length_ms = float(frame_length)
-#         opts.mel_opts.num_bins = n_mels
-#         opts.energy_floor = 0
-#         opts.frame_opts.snip_edges = snip_edges
-#         opts.mel_opts.debug_mel = False
-#         self.opts = opts
-#         self.fbank_fn = None
-#         self.fbank_beg_idx = 0
-#         self.reset_fbank_status()
-#
-#         self.lfr_m = lfr_m
-#         self.lfr_n = lfr_n
-#         self.cmvn_file = cmvn_file
-#         self.upsacle_samples = upsacle_samples
-#
-#     def output_size(self) -> int:
-#         return self.n_mels * self.lfr_m
-#
-#     def forward_fbank(
-#         self,
-#         input: torch.Tensor,
-#         input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-#         batch_size = input.size(0)
-#         feats = []
-#         feats_lens = []
-#         for i in range(batch_size):
-#             waveform_length = input_lengths[i]
-#             waveform = input[i][:waveform_length]
-#             waveform = waveform * (1 << 15)
-#
-#             self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
-#             frames = self.fbank_fn.num_frames_ready
-#             frames_cur = frames - self.fbank_beg_idx
-#             mat = torch.empty([frames_cur, self.opts.mel_opts.num_bins], dtype=torch.float32).to(
-#                 device=self.opts.device)
-#             for i in range(self.fbank_beg_idx, frames):
-#                 mat[i, :] = self.fbank_fn.get_frame(i)
-#             self.fbank_beg_idx += frames_cur
-#
-#             feat_length = mat.size(0)
-#             feats.append(mat)
-#             feats_lens.append(feat_length)
-#
-#         feats_lens = torch.as_tensor(feats_lens)
-#         feats_pad = pad_sequence(feats,
-#                                  batch_first=True,
-#                                  padding_value=0.0)
-#         return feats_pad, feats_lens
-#
-#     def reset_fbank_status(self):
-#         self.fbank_fn = kaldifeat.OnlineFbank(self.opts)
-#         self.fbank_beg_idx = 0
-#
-#     def forward_lfr_cmvn(
-#         self,
-#         input: torch.Tensor,
-#         input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-#         batch_size = input.size(0)
-#         feats = []
-#         feats_lens = []
-#         for i in range(batch_size):
-#             mat = input[i, :input_lengths[i], :]
-#             if self.lfr_m != 1 or self.lfr_n != 1:
-#                 mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
-#             if self.cmvn_file is not None:
-#                 mat = apply_cmvn(mat, self.cmvn_file)
-#             feat_length = mat.size(0)
-#             feats.append(mat)
-#             feats_lens.append(feat_length)
-#
-#         feats_lens = torch.as_tensor(feats_lens)
-#         feats_pad = pad_sequence(feats,
-#                                  batch_first=True,
-#                                  padding_value=0.0)
-#         return feats_pad, feats_lens
\ No newline at end of file
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 6922ae0..5f9e8fc 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -30,7 +30,7 @@
 import torch.nn
 import torch.optim
 import yaml
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from torch.utils.data import DataLoader
 from typeguard import check_argument_types
 from typeguard import check_return_type
@@ -230,8 +230,8 @@
         >>> cls.check_task_requirements()
         If your model is defined as following,
 
-        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
-        >>> class Model(AbsESPnetModel):
+        >>> from funasr.models.base_model import FunASRModel
+        >>> class Model(FunASRModel):
         ...     def forward(self, input, output, opt=None):  pass
 
         then "required_data_names" should be as
@@ -251,8 +251,8 @@
         >>> cls.check_task_requirements()
         If your model is defined as follows,
 
-        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
-        >>> class Model(AbsESPnetModel):
+        >>> from funasr.models.base_model import FunASRModel
+        >>> class Model(FunASRModel):
         ...     def forward(self, input, output, opt=None):  pass
 
         then "optional_data_names" should be as
@@ -263,7 +263,7 @@
 
     @classmethod
     @abstractmethod
-    def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
+    def build_model(cls, args: argparse.Namespace) -> FunASRModel:
         raise NotImplementedError
 
     @classmethod
@@ -1235,9 +1235,9 @@
 
         # 2. Build model
         model = cls.build_model(args=args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model = model.to(
             dtype=getattr(torch, args.train_dtype),
@@ -1921,7 +1921,7 @@
             model_file: Union[Path, str] = None,
             cmvn_file: Union[Path, str] = None,
             device: str = "cpu",
-    ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
+    ) -> Tuple[FunASRModel, argparse.Namespace]:
         """Build model from the files.
 
         This method is used for inference or fine-tuning.
@@ -1948,9 +1948,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model.to(device)
         if model_file is not None:
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index e151473..6d93032 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -72,7 +72,7 @@
 from funasr.tasks.abs_task import AbsTask
 from funasr.text.phoneme_tokenizer import g2p_choices
 from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.trainer import Trainer
 from funasr.utils.get_default_kwargs import get_default_kwargs
@@ -127,7 +127,7 @@
         mfcca=MFCCA,
         timestamp_prediction=TimestampPredictor,
     ),
-    type_check=AbsESPnetModel,
+    type_check=FunASRModel,
     default="asr",
 )
 preencoder_choices = ClassChoices(
@@ -810,9 +810,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model.to(device)
         model_dict = dict()
@@ -1057,9 +1057,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model.to(device)
         model_dict = dict()
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 096a5c8..0fa8c83 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -50,7 +50,7 @@
 from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
 from funasr.tasks.abs_task import AbsTask
 from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.trainer import Trainer
 from funasr.utils.types import float_or_none
@@ -536,9 +536,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model.to(device)
         model_dict = dict()
@@ -894,9 +894,9 @@
             args = yaml.safe_load(f)
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         if model_file is not None:
             if device == "cuda":
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index 0170f28..a63bbe4 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -14,7 +14,6 @@
 
 from funasr.datasets.collate_fn import CommonCollateFn
 from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
-from funasr.train.abs_model import AbsPunctuation
 from funasr.train.abs_model import PunctuationModel
 from funasr.models.target_delay_transformer import TargetDelayTransformer
 from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
@@ -31,7 +30,6 @@
 punc_choices = ClassChoices(
     "punctuation",
     classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
-    type_check=AbsPunctuation,
     default="target_delay",
 )
 
diff --git a/funasr/tasks/sv.py b/funasr/tasks/sv.py
index bef5dc5..d732e5a 100644
--- a/funasr/tasks/sv.py
+++ b/funasr/tasks/sv.py
@@ -45,7 +45,7 @@
 from funasr.models.specaug.specaug import SpecAug
 from funasr.tasks.abs_task import AbsTask
 from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
 from funasr.train.class_choices import ClassChoices
 from funasr.train.trainer import Trainer
 from funasr.utils.types import float_or_none
@@ -90,7 +90,7 @@
     classes=dict(
         espnet=ESPnetSVModel,
     ),
-    type_check=AbsESPnetModel,
+    type_check=FunASRModel,
     default="espnet",
 )
 preencoder_choices = ClassChoices(
@@ -484,9 +484,9 @@
             args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)
-        if not isinstance(model, AbsESPnetModel):
+        if not isinstance(model, FunASRModel):
             raise RuntimeError(
-                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
             )
         model.to(device)
         model_dict = dict()
diff --git a/funasr/tasks/vad.py b/funasr/tasks/vad.py
index d07acf1..ec95596 100644
--- a/funasr/tasks/vad.py
+++ b/funasr/tasks/vad.py
@@ -1,77 +1,42 @@
 import argparse
 import logging
+import os
+from pathlib import Path
 from typing import Callable
 from typing import Collection
 from typing import Dict
 from typing import List
 from typing import Optional
 from typing import Tuple
-import os
-from pathlib import Path
-from typing import Tuple
 from typing import Union
-import yaml
+
 import numpy as np
 import torch
+import yaml
 from typeguard import check_argument_types
 from typeguard import check_return_type
 
 from funasr.datasets.collate_fn import CommonCollateFn
-from funasr.datasets.preprocessor import CommonPreprocessor
-from funasr.models.ctc import CTC
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.decoder.rnn_decoder import RNNDecoder
-from funasr.models.decoder.transformer_decoder import (
-    DynamicConvolution2DTransformerDecoder,  # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
-from funasr.models.decoder.transformer_decoder import (
-    LightweightConvolution2DTransformerDecoder,  # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import (
-    LightweightConvolutionTransformerDecoder,  # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import TransformerDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
-from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.transformer_encoder import TransformerEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.postencoder.hugging_face_transformers_postencoder import (
-    HuggingFaceTransformersPostEncoder,  # noqa: H301
-)
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.preencoder.linear import LinearProjection
-from funasr.models.preencoder.sinc import LightweightSincConvs
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.specaug.specaug import SpecAug
 from funasr.layers.abs_normalize import AbsNormalize
 from funasr.layers.global_mvn import GlobalMVN
 from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
-from funasr.train.abs_espnet_model import AbsESPnetModel
-from funasr.train.class_choices import ClassChoices
-from funasr.train.trainer import Trainer
-from funasr.utils.get_default_kwargs import get_default_kwargs
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import float_or_none
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str_or_none
-
-from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.models.predictor.cif import CifPredictor, CifPredictorV2
-from funasr.modules.subsampling import Conv1dSubsampling
 from funasr.models.e2e_vad import E2EVadModel
 from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.tasks.abs_task import AbsTask
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str_or_none
 
 frontend_choices = ClassChoices(
     name="frontend",
@@ -292,7 +257,7 @@
             model_class = model_choices.get_class(args.model)
         except AttributeError:
             model_class = model_choices.get_class("e2evad")
-        
+
         # 1. frontend
         if args.input_size is None:
             # Extract features in the model
@@ -308,7 +273,7 @@
             args.frontend_conf = {}
             frontend = None
             input_size = args.input_size
-        
+
         model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
 
         return model
@@ -344,7 +309,7 @@
 
         with config_file.open("r", encoding="utf-8") as f:
             args = yaml.safe_load(f)
-        #if cmvn_file is not None:
+        # if cmvn_file is not None:
         args["cmvn_file"] = cmvn_file
         args = argparse.Namespace(**args)
         model = cls.build_model(args)

--
Gitblit v1.9.1