嘉渊
2023-04-27 6997763bf65705257fe6bca6ee63fcf006122abb
update
7个文件已修改
245 ■■■■ 已修改文件
funasr/models/frontend/wav_frontend_kaldifeat.py 112 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/abs_task.py 22 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/asr.py 12 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/diar.py 10 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/punctuation.py 2 ●●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/sv.py 8 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/tasks/vad.py 79 ●●●● 补丁 | 查看 | 原始文档 | blame | 历史
funasr/models/frontend/wav_frontend_kaldifeat.py
@@ -1,17 +1,9 @@
# Copyright (c) Alibaba, Inc. and its affiliates.
# Part of the implementation is borrowed from espnet/espnet.
from typing import Tuple
import numpy as np
import torch
import torchaudio.compliance.kaldi as kaldi
from funasr.models.frontend.abs_frontend import AbsFrontend
from typeguard import check_argument_types
from torch.nn.utils.rnn import pad_sequence
# import kaldifeat
def load_cmvn(cmvn_file):
    with open(cmvn_file, 'r', encoding='utf-8') as f:
@@ -75,107 +67,3 @@
            LFR_inputs.append(frame)
    LFR_outputs = torch.vstack(LFR_inputs)
    return LFR_outputs.type(torch.float32)
# class WavFrontend_kaldifeat(AbsFrontend):
#     """Conventional frontend structure for ASR.
#     """
#
#     def __init__(
#         self,
#         cmvn_file: str = None,
#         fs: int = 16000,
#         window: str = 'hamming',
#         n_mels: int = 80,
#         frame_length: int = 25,
#         frame_shift: int = 10,
#         lfr_m: int = 1,
#         lfr_n: int = 1,
#         dither: float = 1.0,
#         snip_edges: bool = True,
#         upsacle_samples: bool = True,
#         device: str = 'cpu',
#         **kwargs,
#     ):
#         super().__init__()
#
#         opts = kaldifeat.FbankOptions()
#         opts.device = device
#         opts.frame_opts.samp_freq = fs
#         opts.frame_opts.dither = dither
#         opts.frame_opts.window_type = window
#         opts.frame_opts.frame_shift_ms = float(frame_shift)
#         opts.frame_opts.frame_length_ms = float(frame_length)
#         opts.mel_opts.num_bins = n_mels
#         opts.energy_floor = 0
#         opts.frame_opts.snip_edges = snip_edges
#         opts.mel_opts.debug_mel = False
#         self.opts = opts
#         self.fbank_fn = None
#         self.fbank_beg_idx = 0
#         self.reset_fbank_status()
#
#         self.lfr_m = lfr_m
#         self.lfr_n = lfr_n
#         self.cmvn_file = cmvn_file
#         self.upsacle_samples = upsacle_samples
#
#     def output_size(self) -> int:
#         return self.n_mels * self.lfr_m
#
#     def forward_fbank(
#         self,
#         input: torch.Tensor,
#         input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         batch_size = input.size(0)
#         feats = []
#         feats_lens = []
#         for i in range(batch_size):
#             waveform_length = input_lengths[i]
#             waveform = input[i][:waveform_length]
#             waveform = waveform * (1 << 15)
#
#             self.fbank_fn.accept_waveform(self.opts.frame_opts.samp_freq, waveform.tolist())
#             frames = self.fbank_fn.num_frames_ready
#             frames_cur = frames - self.fbank_beg_idx
#             mat = torch.empty([frames_cur, self.opts.mel_opts.num_bins], dtype=torch.float32).to(
#                 device=self.opts.device)
#             for i in range(self.fbank_beg_idx, frames):
#                 mat[i, :] = self.fbank_fn.get_frame(i)
#             self.fbank_beg_idx += frames_cur
#
#             feat_length = mat.size(0)
#             feats.append(mat)
#             feats_lens.append(feat_length)
#
#         feats_lens = torch.as_tensor(feats_lens)
#         feats_pad = pad_sequence(feats,
#                                  batch_first=True,
#                                  padding_value=0.0)
#         return feats_pad, feats_lens
#
#     def reset_fbank_status(self):
#         self.fbank_fn = kaldifeat.OnlineFbank(self.opts)
#         self.fbank_beg_idx = 0
#
#     def forward_lfr_cmvn(
#         self,
#         input: torch.Tensor,
#         input_lengths: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
#         batch_size = input.size(0)
#         feats = []
#         feats_lens = []
#         for i in range(batch_size):
#             mat = input[i, :input_lengths[i], :]
#             if self.lfr_m != 1 or self.lfr_n != 1:
#                 mat = apply_lfr(mat, self.lfr_m, self.lfr_n)
#             if self.cmvn_file is not None:
#                 mat = apply_cmvn(mat, self.cmvn_file)
#             feat_length = mat.size(0)
#             feats.append(mat)
#             feats_lens.append(feat_length)
#
#         feats_lens = torch.as_tensor(feats_lens)
#         feats_pad = pad_sequence(feats,
#                                  batch_first=True,
#                                  padding_value=0.0)
#         return feats_pad, feats_lens
funasr/tasks/abs_task.py
@@ -30,7 +30,7 @@
import torch.nn
import torch.optim
import yaml
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
from torch.utils.data import DataLoader
from typeguard import check_argument_types
from typeguard import check_return_type
@@ -230,8 +230,8 @@
        >>> cls.check_task_requirements()
        If your model is defined as following,
        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
        >>> class Model(AbsESPnetModel):
        >>> from funasr.models.base_model import FunASRModel
        >>> class Model(FunASRModel):
        ...     def forward(self, input, output, opt=None):  pass
        then "required_data_names" should be as
@@ -251,8 +251,8 @@
        >>> cls.check_task_requirements()
        If your model is defined as follows,
        >>> from funasr.train.abs_espnet_model import AbsESPnetModel
        >>> class Model(AbsESPnetModel):
        >>> from funasr.models.base_model import FunASRModel
        >>> class Model(FunASRModel):
        ...     def forward(self, input, output, opt=None):  pass
        then "optional_data_names" should be as
@@ -263,7 +263,7 @@
    @classmethod
    @abstractmethod
    def build_model(cls, args: argparse.Namespace) -> AbsESPnetModel:
    def build_model(cls, args: argparse.Namespace) -> FunASRModel:
        raise NotImplementedError
    @classmethod
@@ -1235,9 +1235,9 @@
        # 2. Build model
        model = cls.build_model(args=args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model = model.to(
            dtype=getattr(torch, args.train_dtype),
@@ -1921,7 +1921,7 @@
            model_file: Union[Path, str] = None,
            cmvn_file: Union[Path, str] = None,
            device: str = "cpu",
    ) -> Tuple[AbsESPnetModel, argparse.Namespace]:
    ) -> Tuple[FunASRModel, argparse.Namespace]:
        """Build model from the files.
        This method is used for inference or fine-tuning.
@@ -1948,9 +1948,9 @@
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        if model_file is not None:
funasr/tasks/asr.py
@@ -72,7 +72,7 @@
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
@@ -127,7 +127,7 @@
        mfcca=MFCCA,
        timestamp_prediction=TimestampPredictor,
    ),
    type_check=AbsESPnetModel,
    type_check=FunASRModel,
    default="asr",
)
preencoder_choices = ClassChoices(
@@ -810,9 +810,9 @@
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
@@ -1057,9 +1057,9 @@
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
funasr/tasks/diar.py
@@ -50,7 +50,7 @@
from funasr.modules.eend_ola.encoder_decoder_attractor import EncoderDecoderAttractor
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
@@ -536,9 +536,9 @@
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
@@ -894,9 +894,9 @@
            args = yaml.safe_load(f)
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        if model_file is not None:
            if device == "cuda":
funasr/tasks/punctuation.py
@@ -14,7 +14,6 @@
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import PuncTrainTokenizerCommonPreprocessor
from funasr.train.abs_model import AbsPunctuation
from funasr.train.abs_model import PunctuationModel
from funasr.models.target_delay_transformer import TargetDelayTransformer
from funasr.models.vad_realtime_transformer import VadRealtimeTransformer
@@ -31,7 +30,6 @@
punc_choices = ClassChoices(
    "punctuation",
    classes=dict(target_delay=TargetDelayTransformer, vad_realtime=VadRealtimeTransformer),
    type_check=AbsPunctuation,
    default="target_delay",
)
funasr/tasks/sv.py
@@ -45,7 +45,7 @@
from funasr.models.specaug.specaug import SpecAug
from funasr.tasks.abs_task import AbsTask
from funasr.torch_utils.initialize import initialize
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.models.base_model import FunASRModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
@@ -90,7 +90,7 @@
    classes=dict(
        espnet=ESPnetSVModel,
    ),
    type_check=AbsESPnetModel,
    type_check=FunASRModel,
    default="espnet",
)
preencoder_choices = ClassChoices(
@@ -484,9 +484,9 @@
            args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        if not isinstance(model, AbsESPnetModel):
        if not isinstance(model, FunASRModel):
            raise RuntimeError(
                f"model must inherit {AbsESPnetModel.__name__}, but got {type(model)}"
                f"model must inherit {FunASRModel.__name__}, but got {type(model)}"
            )
        model.to(device)
        model_dict = dict()
funasr/tasks/vad.py
@@ -1,77 +1,42 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import os
from pathlib import Path
from typing import Tuple
from typing import Union
import yaml
import numpy as np
import torch
import yaml
from typeguard import check_argument_types
from typeguard import check_return_type
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.transformer_decoder import (
    DynamicConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolutionTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.models.e2e_vad import E2EVadModel
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.tasks.abs_task import AbsTask
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str_or_none
frontend_choices = ClassChoices(
    name="frontend",
@@ -292,7 +257,7 @@
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("e2evad")
        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
@@ -308,7 +273,7 @@
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size
        model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
        return model
@@ -344,7 +309,7 @@
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        #if cmvn_file is not None:
        # if cmvn_file is not None:
        args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)