hnluo
2023-09-11 9fcb3cc06b4e324f0913d2f61b89becc2baeef1b
funasr/tasks/vad.py
@@ -1,77 +1,40 @@
import argparse
import logging
import os
from pathlib import Path
from typing import Callable
from typing import Collection
from typing import Dict
from typing import List
from typing import Optional
from typing import Tuple
import os
from pathlib import Path
from typing import Tuple
from typing import Union
import yaml
import numpy as np
import torch
from typeguard import check_argument_types
from typeguard import check_return_type
import yaml
from funasr.datasets.collate_fn import CommonCollateFn
from funasr.datasets.preprocessor import CommonPreprocessor
from funasr.models.ctc import CTC
from funasr.models.decoder.abs_decoder import AbsDecoder
from funasr.models.decoder.rnn_decoder import RNNDecoder
from funasr.models.decoder.transformer_decoder import (
    DynamicConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import DynamicConvolutionTransformerDecoder
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolution2DTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import (
    LightweightConvolutionTransformerDecoder,  # noqa: H301
)
from funasr.models.decoder.transformer_decoder import TransformerDecoder
from funasr.models.encoder.abs_encoder import AbsEncoder
from funasr.models.encoder.conformer_encoder import ConformerEncoder
from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
from funasr.models.encoder.rnn_encoder import RNNEncoder
from funasr.models.encoder.transformer_encoder import TransformerEncoder
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.wav_frontend import WavFrontend
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.postencoder.abs_postencoder import AbsPostEncoder
from funasr.models.postencoder.hugging_face_transformers_postencoder import (
    HuggingFaceTransformersPostEncoder,  # noqa: H301
)
from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
from funasr.models.preencoder.linear import LinearProjection
from funasr.models.preencoder.sinc import LightweightSincConvs
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.layers.abs_normalize import AbsNormalize
from funasr.layers.global_mvn import GlobalMVN
from funasr.layers.utterance_mvn import UtteranceMVN
from funasr.tasks.abs_task import AbsTask
from funasr.text.phoneme_tokenizer import g2p_choices
from funasr.train.abs_espnet_model import AbsESPnetModel
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.get_default_kwargs import get_default_kwargs
from funasr.utils.nested_dict_action import NestedDictAction
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str2bool
from funasr.utils.types import str_or_none
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.models.predictor.cif import CifPredictor, CifPredictorV2
from funasr.modules.subsampling import Conv1dSubsampling
from funasr.models.e2e_vad import E2EVadModel
from funasr.models.encoder.fsmn_encoder import FSMN
from funasr.models.frontend.abs_frontend import AbsFrontend
from funasr.models.frontend.default import DefaultFrontend
from funasr.models.frontend.fused import FusedFrontends
from funasr.models.frontend.s3prl import S3prlFrontend
from funasr.models.frontend.wav_frontend import WavFrontend, WavFrontendOnline
from funasr.models.frontend.windowing import SlidingWindow
from funasr.models.specaug.abs_specaug import AbsSpecAug
from funasr.models.specaug.specaug import SpecAug
from funasr.models.specaug.specaug import SpecAugLFR
from funasr.tasks.abs_task import AbsTask
from funasr.train.class_choices import ClassChoices
from funasr.train.trainer import Trainer
from funasr.utils.types import float_or_none
from funasr.utils.types import int_or_none
from funasr.utils.types import str_or_none
frontend_choices = ClassChoices(
    name="frontend",
@@ -81,6 +44,7 @@
        s3prl=S3prlFrontend,
        fused=FusedFrontends,
        wav_frontend=WavFrontend,
        wav_frontend_online=WavFrontendOnline,
    ),
    type_check=AbsFrontend,
    default="default",
@@ -226,7 +190,6 @@
        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
        Tuple[List[str], Dict[str, torch.Tensor]],
    ]:
        assert check_argument_types()
        # NOTE(kamo): int value = 0 is reserved by CTC-blank symbol
        return CommonCollateFn(float_pad_value=0.0, int_pad_value=-1)
@@ -234,8 +197,7 @@
    def build_preprocess_fn(
            cls, args: argparse.Namespace, train: bool
    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
        assert check_argument_types()
        #if args.use_preprocessor:
        # if args.use_preprocessor:
        #    retval = CommonPreprocessor(
        #        train=train,
        #        # NOTE(kamo): Check attribute existence for backward compatibility
@@ -254,10 +216,9 @@
        #        if hasattr(args, "rir_scp")
        #        else None,
        #    )
        #else:
        # else:
        #    retval = None
        retval = None
        assert check_return_type(retval)
        return retval
    @classmethod
@@ -276,12 +237,10 @@
            cls, train: bool = True, inference: bool = False
    ) -> Tuple[str, ...]:
        retval = ()
        assert check_return_type(retval)
        return retval
    @classmethod
    def build_model(cls, args: argparse.Namespace):
        assert check_argument_types()
        # 4. Encoder
        encoder_class = encoder_choices.get_class(args.encoder)
        encoder = encoder_class(**args.encoder_conf)
@@ -291,7 +250,24 @@
            model_class = model_choices.get_class(args.model)
        except AttributeError:
            model_class = model_choices.get_class("e2evad")
        model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf)
        # 1. frontend
        if args.input_size is None:
            # Extract features in the model
            frontend_class = frontend_choices.get_class(args.frontend)
            if args.frontend == 'wav_frontend':
                frontend = frontend_class(cmvn_file=args.cmvn_file, **args.frontend_conf)
            else:
                frontend = frontend_class(**args.frontend_conf)
            input_size = frontend.output_size()
        else:
            # Give features from data-loader
            args.frontend = None
            args.frontend_conf = {}
            frontend = None
            input_size = args.input_size
        model = model_class(encoder=encoder, vad_post_args=args.vad_post_conf, frontend=frontend)
        return model
@@ -302,6 +278,7 @@
            config_file: Union[Path, str] = None,
            model_file: Union[Path, str] = None,
            device: str = "cpu",
            cmvn_file: Union[Path, str] = None,
    ):
        """Build model from the files.
@@ -313,7 +290,6 @@
            device: Device type, "cpu", "cuda", or "cuda:N".
        """
        assert check_argument_types()
        if config_file is None:
            assert model_file is not None, (
                "The argument 'model_file' must be provided "
@@ -325,6 +301,8 @@
        with config_file.open("r", encoding="utf-8") as f:
            args = yaml.safe_load(f)
        # if cmvn_file is not None:
        args["cmvn_file"] = cmvn_file
        args = argparse.Namespace(**args)
        model = cls.build_model(args)
        model.to(device)