From 6997763bf65705257fe6bca6ee63fcf006122abb Mon Sep 17 00:00:00 2001
From: 嘉渊 <wangjiaming.wjm@alibaba-inc.com>
Date: Thu, 27 Apr 2023 17:51:13 +0800
Subject: [PATCH] update
---
funasr/models/frontend/wav_frontend_kaldifeat.py | 112 ----------------------
funasr/tasks/punctuation.py | 2
funasr/tasks/sv.py | 8
funasr/tasks/diar.py | 10 +-
funasr/tasks/asr.py | 12 +-
funasr/tasks/abs_task.py | 22 ++--
funasr/tasks/vad.py | 79 ++++-----------
7 files changed, 48 insertions(+), 197 deletions(-)
diff --git a/funasr/models/frontend/wav_frontend_kaldifeat.py b/funasr/models/frontend/wav_frontend_kaldifeat.py
index 85adbb7..5372de3 100644
--- a/funasr/models/frontend/wav_frontend_kaldifeat.py
+++ b/funasr/models/frontend/wav_frontend_kaldifeat.py
@@ -1,17 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
-from typing import Tuple
-
import numpy as np
import torch
-import torchaudio.compliance.kaldi as kaldi
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from typeguard import check_argument_types
-from torch.nn.utils.rnn import pad_sequence
-
-# import kaldifeat
def load_cmvn(cmvn_file):
with open(cmvn_file, 'r', encoding='utf-8') as f:
@@ -75,107 +67,3 @@
LFR_inputs.append(frame)
LFR_outputs = torch.vstack(LFR_inputs)
return LFR_outputs.type(torch.float32)
-
-# class WavFrontend_kaldifeat(AbsFrontend):
-# """Conventional frontend structure for ASR.
-# """
-#
-# def __init__(
-# self,
-# cmvn_file: str = None,
-# fs: int = 16000,
-# window: str = 'hamming',
-# n_mels: int = 80,
-# frame_length: int = 25,
-# frame_shift: int = 10,
-# lfr_m: int = 1,
-# lfr_n: int = 1,
-# dither: float = 1.0,
-# snip_edges: bool = True,
-# upsacle_samples: bool = True,
-# device: str = 'cpu',
-# **kwargs,
-# ):
-# super().__init__()
-#
-# opts = kaldifeat.FbankOptions()
-# opts.device = device
-# opts.frame_opts.samp_freq = fs
-# opts.frame_opts.dither = dither
-# opts.frame_opts.window_type = window
-# opts.frame_opts.frame_shift_ms = float(frame_shift)
-# opts.frame_opts.frame_length_ms = float(frame_length)
-# opts.mel_opts.num_bins = n_mels
-# opts.energy_floor = 0
-# opts.frame_opts.snip_edges = snip_edges
-# opts.mel_opts.debug_mel = False
-# self.opts = opts
-# self.fbank_fn = None
-# self.fbank_beg_idx = 0
-# self.reset_fbank_status()
-#
-# self.lfr_m = lfr_m
-# self.lfr_n = lfr_n
-# self.cmvn_file = cmvn_file
-# self.upsacle_samples = upsacle_samples
-#
-# def output_size(self) -> int:
-# return self.n_mels * self.lfr_m
-#
-# def forward_fbank(
-# self,
-# input: torch.Tensor,
-# input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-# batch_size = input.size(0)
-# feats = []
-# feats_lens = []
-# for i in range(batch_size):
-# waveform_length = input_lengths[i]
-# waveform = input[i][:waveform_length]
-# waveform = waveform * (1 << 15)
-#
-# self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
-# frames = self.fbank_fn.num_frames_ready
-# frames_cur = frames - self.fbank_beg_idx
-# mat = torch.empty([frames_cur, self.opts.mel_opts.num_bins], dtype=torch.float32).to(
-# device=self.opts.device)
-# for i in range(self.fbank_beg_idx, frames):
-# mat[i, :] = self.fbank_fn.get_frame(i)
-# self.fbank_beg_idx += frames_cur
-#
-# feat_length = mat.size(0)
-# feats.append(mat)
-# feats_lens.append(feat_length)
-#
-# feats_lens = torch.as_tensor(feats_lens)
-# feats_pad = pad_sequence(feats,
-# batch_first=True,
-# padding_value=0.0)
-# return feats_pad, feats_lens
-#
-# def reset_fbank_status(self):
-# self.fbank_fn = kaldifeat.OnlineFbank(self.opts)
-# self.fbank_beg_idx = 0
-#
-# def forward_lfr_cmvn(
-# self,
-# input: torch.Tensor,
-# input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-# batch_size = input.size(0)
-# feats = []
-# feats_lens = []
-# for i in range(batch_size):
-# mat = input[i, :input_lengths[i], :]
-# if self.lfr_m != 1 or self.lfr_n != 1:
-# mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
-# if self.cmvn_file is not None:
-# mat = apply_cmvn(mat, self.cmvn_file)
-# feat_length = mat.size(0)
-# feats.append(mat)
-# feats_lens.append(feat_length)
-#
-# feats_lens = torch.as_tensor(feats_lens)
-# feats_pad = pad_sequence(feats,
-# batch_first=True,
-# padding_value=0.0)
-# return feats_pad, feats_lens
\ No newline at end of file
diff --git a/funasr/tasks/abs_task.py b/funasr/tasks/abs_task.py
index 6922ae0..5f9e8fc 100644
--- a/funasr/tasks/abs_task.py
+++ b/funasr/tasks/abs_task.py
@@ -30,7 +30,7 @@
import torch.nn
import torch.optim
import yaml
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from typeguard import check_return_type
@@ -230,8 +230,8 @@
>>> cls.check_task_requirements()
If your model is defined as following,
- >>> from funasr.train.abs_espnet_model import AbsESPnetModel
- >>> class Model(AbsESPnetModel):
+ >>> from funasr.models.base_model import FunASRModel
+ >>> class Model(FunASRModel):
... def forward(self, input, output, opt=None): pass
then "required_data_names" should be as
@@ -251,8 +251,8 @@
>>> cls.check_task_requirements()
If your model is defined as follows,
- >>> from funasr.train.abs_espnet_model import AbsESPnetModel
- >>> class Model(AbsESPnetModel):
+ >>> from funasr.models.base_model import FunASRModel
+ >>> class Model(FunASRModel):
... def forward(self, input, output, opt=None): pass
then "optional_data_names" should be as
@@ -263,7 +263,7 @@
@classmethod
@abstractmethod
- def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
+ def build_model(cls, args: argparse.Namespace) -> FunASRModel:
raise NotImplementedError
@classmethod
@@ -1235,9 +1235,9 @@
# 2. Build model
model = cls.build_model(args=args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model = model.to(
dtype=getattr(torch, args.train_dtype),
@@ -1921,7 +1921,7 @@
model_file: Union[Path, str] = None,
cmvn_file: Union[Path, str] = None,
device: str = "cpu",
- ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
+ ) -> Tuple[FunASRModel, argparse.Namespace]:
"""Build model from the files.
This method is used for inference or fine-tuning.
@@ -1948,9 +1948,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
if model_file is not None:
diff --git a/funasr/tasks/asr.py b/funasr/tasks/asr.py
index e151473..6d93032 100644
--- a/funasr/tasks/asr.py
+++ b/funasr/tasks/asr.py
@@ -72,7 +72,7 @@
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
@@ -127,7 +127,7 @@
mfcca=MFCCA,
timestamp_prediction=TimestampPredictor,
),
- type_check=AbsESPnetModel,
+ type_check=FunASRModel,
default="asr",
)
preencoder_choices = ClassChoices(
@@ -810,9 +810,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
@@ -1057,9 +1057,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
diff --git a/funasr/tasks/diar.py b/funasr/tasks/diar.py
index 096a5c8..0fa8c83 100644
--- a/funasr/tasks/diar.py
+++ b/funasr/tasks/diar.py
@@ -50,7 +50,7 @@
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
@@ -536,9 +536,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
@@ -894,9 +894,9 @@
args = yaml.safe_load(f)
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
if model_file is not None:
if device == "cuda":
diff --git a/funasr/tasks/punctuation.py b/funasr/tasks/punctuation.py
index 0170f28..a63bbe4 100644
--- a/funasr/tasks/punctuation.py
+++ b/funasr/tasks/punctuation.py
@@ -14,7 +14,6 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
-from funasr.train.abs_model import AbsPunctuation
from funasr.train.abs_model import PunctuationModel
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
@@ -31,7 +30,6 @@
punc_choices = ClassChoices(
"punctuation",
classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
- type_check=AbsPunctuation,
default="target_delay",
)
diff --git a/funasr/tasks/sv.py b/funasr/tasks/sv.py
index bef5dc5..d732e5a 100644
--- a/funasr/tasks/sv.py
+++ b/funasr/tasks/sv.py
@@ -45,7 +45,7 @@
from funasr.models.specaug.specaug import SpecAug
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
-from funasr.train.abs_espnet_model import AbsESPnetModel
+from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
@@ -90,7 +90,7 @@
classes=dict(
espnet=ESPnetSVModel,
),
- type_check=AbsESPnetModel,
+ type_check=FunASRModel,
default="espnet",
)
preencoder_choices = ClassChoices(
@@ -484,9 +484,9 @@
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
- if not isinstance(model, AbsESPnetModel):
+ if not isinstance(model, FunASRModel):
raise RuntimeError(
- f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
+ f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
)
model.to(device)
model_dict = dict()
diff --git a/funasr/tasks/vad.py b/funasr/tasks/vad.py
index d07acf1..ec95596 100644
--- a/funasr/tasks/vad.py
+++ b/funasr/tasks/vad.py
@@ -1,77 +1,42 @@
import argparse
import logging
+import os
+from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
-import os
-from pathlib import Path
-from typing import Tuple
from typing import Union
-import yaml
+
import numpy as np
import torch
+import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
-from funasr.datasets.preprocessor import CommonPreprocessor
-from funasr.models.ctc import CTC
-from funasr.models.decoder.abs_decoder import AbsDecoder
-from funasr.models.decoder.rnn_decoder import RNNDecoder
-from funasr.models.decoder.transformer_decoder import (
- DynamicConvolution2DTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
-from funasr.models.decoder.transformer_decoder import (
- LightweightConvolution2DTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import (
- LightweightConvolutionTransformerDecoder, # noqa: H301
-)
-from funasr.models.decoder.transformer_decoder import TransformerDecoder
-from funasr.models.encoder.abs_encoder import AbsEncoder
-from funasr.models.encoder.conformer_encoder import ConformerEncoder
-from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
-from funasr.models.encoder.rnn_encoder import RNNEncoder
-from funasr.models.encoder.transformer_encoder import TransformerEncoder
-from funasr.models.frontend.abs_frontend import AbsFrontend
-from funasr.models.frontend.default import DefaultFrontend
-from funasr.models.frontend.fused import FusedFrontends
-from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
-from funasr.models.frontend.s3prl import S3prlFrontend
-from funasr.models.frontend.windowing import SlidingWindow
-from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
-from funasr.models.postencoder.hugging_face_transformers_postencoder import (
- HuggingFaceTransformersPostEncoder, # noqa: H301
-)
-from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
-from funasr.models.preencoder.linear import LinearProjection
-from funasr.models.preencoder.sinc import LightweightSincConvs
-from funasr.models.specaug.abs_specaug import AbsSpecAug
-from funasr.models.specaug.specaug import SpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
-from funasr.tasks.abs_task import AbsTask
-from funasr.text.phoneme_tokenizer import g2p_choices
-from funasr.train.abs_espnet_model import AbsESPnetModel
-from funasr.train.class_choices import ClassChoices
-from funasr.train.trainer import Trainer
-from funasr.utils.get_default_kwargs import get_default_kwargs
-from funasr.utils.nested_dict_action import NestedDictAction
-from funasr.utils.types import float_or_none
-from funasr.utils.types import int_or_none
-from funasr.utils.types import str2bool
-from funasr.utils.types import str_or_none
-
-from funasr.models.specaug.specaug import SpecAugLFR
-from funasr.models.predictor.cif import CifPredictor, CifPredictorV2
-from funasr.modules.subsampling import Conv1dSubsampling
from funasr.models.e2e_vad import E2EVadModel
from funasr.models.encoder.fsmn_encoder import FSMN
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.fused import FusedFrontends
+from funasr.models.frontend.s3prl import S3prlFrontend
+from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.models.specaug.specaug import SpecAugLFR
+from funasr.tasks.abs_task import AbsTask
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str_or_none
frontend_choices = ClassChoices(
name="frontend",
@@ -292,7 +257,7 @@
model_class = model_choices.get_class(args.model)
except AttributeError:
model_class = model_choices.get_class("e2evad")
-
+
# 1. frontend
if args.input_size is None:
# Extract features in the model
@@ -308,7 +273,7 @@
args.frontend_conf = {}
frontend = None
input_size = args.input_size
-
+
model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
return model
@@ -344,7 +309,7 @@
with config_file.open("r", encoding="utf-8") as f:
args = yaml.safe_load(f)
- #if cmvn_file is not None:
+ # if cmvn_file is not None:
args["cmvn_file"] = cmvn_file
args = argparse.Namespace(**args)
model = cls.build_model(args)
--
Gitblit v1.9.1