From 7bea61862379d056658d2d7cadae132328c0c76e Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 06 二月 2023 16:09:24 +0800
Subject: [PATCH] update data2vec pretrain
---
funasr/bin/data2vec_train.py | 45 ++++
funasr/tasks/data2vec.py | 376 +++++++++++++++++++++++++++++++++++++
funasr/models/data2vec.py | 160 ++++++++++++++++
3 files changed, 581 insertions(+), 0 deletions(-)
diff --git a/funasr/bin/data2vec_train.py b/funasr/bin/data2vec_train.py
new file mode 100755
index 0000000..b9dbdff
--- /dev/null
+++ b/funasr/bin/data2vec_train.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+
+import os
+
+from funasr.tasks.data2vec import Data2VecTask
+
+
+def parse_args():
+ parser = Data2VecTask.get_parser()
+ parser.add_argument(
+ "--gpu_id",
+ type=int,
+ default=0,
+ help="local gpu id.",
+ )
+ args = parser.parse_args()
+ return args
+
+
+def main(args=None, cmd=None):
+ # for data2vec Training
+ Data2VecTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+ args = parse_args()
+
+ # setup local gpu_id
+ os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+ # DDP settings
+ if args.ngpu > 1:
+ args.distributed = True
+ else:
+ args.distributed = False
+ assert args.num_worker_count == 1
+
+ # re-compute batch size: when dataset type is small
+ if args.dataset_type == "small":
+ if args.batch_size is not None:
+ args.batch_size = args.batch_size * args.ngpu
+ if args.batch_bins is not None:
+ args.batch_bins = args.batch_bins * args.ngpu
+
+ main(args=args)
diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py
new file mode 100644
index 0000000..fcd6bd2
--- /dev/null
+++ b/funasr/models/data2vec.py
@@ -0,0 +1,160 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+ from torch.cuda.amp import autocast
+else:
+ # Nothing to do if torch<1.6.0
+ @contextmanager
+ def autocast(enabled=True):
+ yield
+
+
+class Data2VecPretrainModel(AbsESPnetModel):
+ """Data2Vec Pretrain model"""
+
+ def __init__(
+ self,
+ frontend: Optional[AbsFrontend],
+ specaug: Optional[AbsSpecAug],
+ normalize: Optional[AbsNormalize],
+ preencoder: Optional[AbsPreEncoder],
+ encoder: AbsEncoder,
+ ):
+ assert check_argument_types()
+
+ super().__init__()
+
+ self.frontend = frontend
+ self.specaug = specaug
+ self.normalize = normalize
+ self.preencoder = preencoder
+ self.encoder = encoder
+ self.num_updates = 0
+
+ def forward(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+ """Frontend + Encoder + Calc loss
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ # Check that batch_size is unified
+ assert (
+ speech.shape[0]
+ == speech_lengths.shape[0]
+ ), (speech.shape, speech_lengths.shape)
+
+ self.encoder.set_num_updates(self.num_updates)
+
+ # 1. Encoder
+ encoder_out = self.encode(speech, speech_lengths)
+
+ losses = encoder_out["losses"]
+ loss = sum(losses.values())
+ sample_size = encoder_out["sample_size"]
+ loss = loss.sum() / sample_size
+
+ target_var = float(encoder_out["target_var"])
+ pred_var = float(encoder_out["pred_var"])
+ ema_decay = float(encoder_out["ema_decay"])
+
+ stats = dict(
+ loss=torch.clone(loss.detach()),
+ target_var=target_var,
+ pred_var=pred_var,
+ ema_decay=ema_decay,
+ )
+
+ loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
+ return loss, stats, weight
+
+ def collect_feats(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor
+ ) -> Dict[str, torch.Tensor]:
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+ return {"feats": feats, "feats_lengths": feats_lengths}
+
+ def encode(
+ self,
+ speech: torch.Tensor,
+ speech_lengths: torch.Tensor,
+ ):
+ """Frontend + Encoder.
+
+ Args:
+ speech: (Batch, Length, ...)
+ speech_lengths: (Batch, )
+ """
+ with autocast(False):
+ # 1. Extract feats
+ feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+ # 2. Data augmentation
+ if self.specaug is not None and self.training:
+ feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+ # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+ if self.normalize is not None:
+ feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+ # Pre-encoder, e.g. used for raw input data
+ if self.preencoder is not None:
+ feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+ # 4. Forward encoder
+ if min(speech_lengths) == max(speech_lengths): # for clipping, set speech_lengths as None
+ speech_lengths = None
+ encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
+
+ return encoder_out
+
+ def _extract_feats(
+ self, speech: torch.Tensor, speech_lengths: torch.Tensor
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+ assert speech_lengths.dim() == 1, speech_lengths.shape
+
+ # for data-parallel
+ speech = speech[:, : speech_lengths.max()]
+
+ if self.frontend is not None:
+ # Frontend
+ # e.g. STFT and Feature extract
+ # data_loader may send time-domain signal in this case
+ # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+ feats, feats_lengths = self.frontend(speech, speech_lengths)
+ else:
+ # No frontend and no feature extract
+ feats, feats_lengths = speech, speech_lengths
+ return feats, feats_lengths
+
+ def set_num_updates(self, num_updates):
+ self.num_updates = num_updates
+
+ def get_num_updates(self):
+ return self.num_updates
diff --git a/funasr/tasks/data2vec.py b/funasr/tasks/data2vec.py
new file mode 100644
index 0000000..9a64e1f
--- /dev/null
+++ b/funasr/tasks/data2vec.py
@@ -0,0 +1,376 @@
+import argparse
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.data2vec import Data2VecPretrainModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.tasks.abs_task import AbsTask
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+frontend_choices = ClassChoices(
+ name="frontend",
+ classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
+ type_check=AbsFrontend,
+ default="default",
+)
+specaug_choices = ClassChoices(
+ name="specaug",
+ classes=dict(specaug=SpecAug),
+ type_check=AbsSpecAug,
+ default=None,
+ optional=True,
+)
+normalize_choices = ClassChoices(
+ "normalize",
+ classes=dict(
+ global_mvn=GlobalMVN,
+ utterance_mvn=UtteranceMVN,
+ ),
+ type_check=AbsNormalize,
+ default=None,
+ optional=True,
+)
+preencoder_choices = ClassChoices(
+ name="preencoder",
+ classes=dict(
+ sinc=LightweightSincConvs,
+ ),
+ type_check=AbsPreEncoder,
+ default=None,
+ optional=True,
+)
+encoder_choices = ClassChoices(
+ "encoder",
+ classes=dict(
+ data2vec_encoder=Data2VecEncoder,
+ ),
+ type_check=AbsEncoder,
+ default="data2vec_encoder",
+)
+model_choices = ClassChoices(
+ "model",
+ classes=dict(
+ data2vec=Data2VecPretrainModel,
+ ),
+ default="data2vec",
+)
+
+
+class Data2VecTask(AbsTask):
+ # If you need more than one optimizers, change this value
+ num_optimizers: int = 1
+
+ # Add variable objects configurations
+ class_choices_list = [
+ # --frontend and --frontend_conf
+ frontend_choices,
+ # --specaug and --specaug_conf
+ specaug_choices,
+ # --normalize and --normalize_conf
+ normalize_choices,
+ # --preencoder and --preencoder_conf
+ preencoder_choices,
+ # --encoder and --encoder_conf
+ encoder_choices,
+ # --model and --model_conf
+ model_choices,
+ ]
+
+ # If you need to modify train() or eval() procedures, change Trainer class here
+ trainer = Trainer
+
+ @classmethod
+ def add_task_arguments(cls, parser: argparse.ArgumentParser):
+ group = parser.add_argument_group(description="Task related")
+
+ # NOTE(kamo): add_arguments(..., required=True) can't be used
+ # to provide --print_config mode. Instead of it, do as
+ group.add_argument(
+ "--token_list",
+ type=str_or_none,
+ default=None,
+ help="A text mapping int-id to token",
+ )
+ group.add_argument(
+ "--init",
+ type=lambda x: str_or_none(x.lower()),
+ default=None,
+ help="The initialization method",
+ choices=[
+ "chainer",
+ "xavier_uniform",
+ "xavier_normal",
+ "kaiming_uniform",
+ "kaiming_normal",
+ None,
+ ],
+ )
+
+ group.add_argument(
+ "--input_size",
+ type=int_or_none,
+ default=None,
+ help="The number of input dimension of the feature",
+ )
+
+ group = parser.add_argument_group(description="Preprocess related")
+ group.add_argument(
+ "--use_preprocessor",
+ type=str2bool,
+ default=True,
+ help="Apply preprocessing to data or not",
+ )
+ group.add_argument(
+ "--token_type",
+ type=str,
+ default=None,
+ choices=["bpe", "char", "word", "phn"],
+ help="The text will be tokenized " "in the specified level token",
+ )
+
+ group.add_argument(
+ "--feats_type",
+ type=str,
+ default='fbank',
+ help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
+ )
+
+ group.add_argument(
+ "--bpemodel",
+ type=str_or_none,
+ default=None,
+ help="The model file of sentencepiece",
+ )
+ parser.add_argument(
+ "--non_linguistic_symbols",
+ type=str_or_none,
+ help="non_linguistic_symbols file path",
+ )
+ parser.add_argument(
+ "--cleaner",
+ type=str_or_none,
+ choices=[None, "tacotron", "jaconv", "vietnamese"],
+ default=None,
+ help="Apply text cleaning",
+ )
+ parser.add_argument(
+ "--g2p",
+ type=str_or_none,
+ choices=g2p_choices,
+ default=None,
+ help="Specify g2p method if --token_type=phn",
+ )
+ parser.add_argument(
+ "--speech_volume_normalize",
+ type=float_or_none,
+ default=None,
+ help="Scale the maximum amplitude to the given value.",
+ )
+ parser.add_argument(
+ "--rir_scp",
+ type=str_or_none,
+ default=None,
+ help="The file path of rir scp file.",
+ )
+ parser.add_argument(
+ "--rir_apply_prob",
+ type=float,
+ default=1.0,
+ help="THe probability for applying RIR convolution.",
+ )
+ parser.add_argument(
+ "--noise_scp",
+ type=str_or_none,
+ default=None,
+ help="The file path of noise scp file.",
+ )
+ parser.add_argument(
+ "--noise_apply_prob",
+ type=float,
+ default=1.0,
+ help="The probability applying Noise adding.",
+ )
+ parser.add_argument(
+ "--noise_db_range",
+ type=str,
+ default="13_15",
+ help="The range of noise decibel level.",
+ )
+ parser.add_argument(
+ "--pred_masked_weight",
+ type=float,
+ default=1.0,
+ help="weight for predictive loss for masked frames",
+ )
+ parser.add_argument(
+ "--pred_nomask_weight",
+ type=float,
+ default=0.0,
+ help="weight for predictive loss for unmasked frames",
+ )
+ parser.add_argument(
+ "--loss_weights",
+ type=float,
+ default=0.0,
+ help="weights for additional loss terms (not first one)",
+ )
+
+ for class_choices in cls.class_choices_list:
+ # Append --<name> and --<name>_conf.
+ # e.g. --encoder and --encoder_conf
+ class_choices.add_arguments(group)
+
+ @classmethod
+ def build_collate_fn(
+ cls, args: argparse.Namespace, train: bool
+ ) -> Callable[
+ [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+ Tuple[List[str], Dict[str, torch.Tensor]],
+ ]:
+ assert check_argument_types()
+ return CommonCollateFn(clipping=True)
+
+ @classmethod
+ def build_preprocess_fn(
+ cls, args: argparse.Namespace, train: bool
+ ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+ assert check_argument_types()
+ if args.use_preprocessor:
+ retval = CommonPreprocessor(
+ train=train,
+ bpemodel=args.bpemodel,
+ non_linguistic_symbols=args.non_linguistic_symbols,
+ text_cleaner=args.cleaner,
+ g2p_type=args.g2p,
+ # NOTE(kamo): Check attribute existence for backward compatibility
+ rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+ rir_apply_prob=args.rir_apply_prob
+ if hasattr(args, "rir_apply_prob")
+ else 1.0,
+ noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+ noise_apply_prob=args.noise_apply_prob
+ if hasattr(args, "noise_apply_prob")
+ else 1.0,
+ noise_db_range=args.noise_db_range
+ if hasattr(args, "noise_db_range")
+ else "13_15",
+ speech_volume_normalize=args.speech_volume_normalize
+ if hasattr(args, "rir_scp")
+ else None,
+ )
+ else:
+ retval = None
+ assert check_return_type(retval)
+ return retval
+
+ @classmethod
+ def required_data_names(
+ cls, train: bool = True, inference: bool = False
+ ) -> Tuple[str, ...]:
+ # for pre-training
+ retval = ("speech",)
+ return retval
+
+ @classmethod
+ def optional_data_names(
+ cls, train: bool = True, inference: bool = False
+ ) -> Tuple[str, ...]:
+ retval = ()
+ assert check_return_type(retval)
+ return retval
+
+ @classmethod
+ def build_model(cls, args: argparse.Namespace):
+ assert check_argument_types()
+
+ # 1. frontend
+ if args.input_size is None:
+ # Extract features in the model
+ frontend_class = frontend_choices.get_class(args.frontend)
+ frontend = frontend_class(**args.frontend_conf)
+ input_size = frontend.output_size()
+ else:
+ # Give features from data-loader
+ args.frontend = None
+ args.frontend_conf = {}
+ frontend = None
+ input_size = args.input_size
+
+ # 2. Data augmentation for spectrogram
+ if args.specaug is not None:
+ specaug_class = specaug_choices.get_class(args.specaug)
+ specaug = specaug_class(**args.specaug_conf)
+ else:
+ specaug = None
+
+ # 3. Normalization layer
+ if args.normalize is not None:
+ normalize_class = normalize_choices.get_class(args.normalize)
+ normalize = normalize_class(**args.normalize_conf)
+ else:
+ normalize = None
+
+ # 4. Pre-encoder input block
+ # NOTE(kan-bayashi): Use getattr to keep the compatibility
+ if getattr(args, "preencoder", None) is not None:
+ preencoder_class = preencoder_choices.get_class(args.preencoder)
+ preencoder = preencoder_class(**args.preencoder_conf)
+ input_size = preencoder.output_size()
+ else:
+ preencoder = None
+
+ # 5. Encoder
+ encoder_class = encoder_choices.get_class(args.encoder)
+ encoder = encoder_class(
+ input_size=input_size,
+ **args.encoder_conf,
+ )
+
+ # 6. Build model
+ try:
+ model_class = model_choices.get_class(args.model)
+ except AttributeError:
+ model_class = model_choices.get_class("data2vec")
+ model = model_class(
+ frontend=frontend,
+ specaug=specaug,
+ normalize=normalize,
+ preencoder=preencoder,
+ encoder=encoder,
+ )
+
+ # 7. Initialize
+ if args.init is not None:
+ initialize(model, args.init)
+
+ assert check_return_type(model)
+ return model
--
Gitblit v1.9.1