From 7bea61862379d056658d2d7cadae132328c0c76e Mon Sep 17 00:00:00 2001
From: jmwang66 <wangjiaming.wjm@alibaba-inc.com>
Date: 星期一, 06 二月 2023 16:09:24 +0800
Subject: [PATCH] update data2vec pretrain

---
 funasr/bin/data2vec_train.py |   45 ++++
 funasr/tasks/data2vec.py     |  376 +++++++++++++++++++++++++++++++++++++
 funasr/models/data2vec.py    |  160 ++++++++++++++++
 3 files changed, 581 insertions(+), 0 deletions(-)

diff --git a/funasr/bin/data2vec_train.py b/funasr/bin/data2vec_train.py
new file mode 100755
index 0000000..b9dbdff
--- /dev/null
+++ b/funasr/bin/data2vec_train.py
@@ -0,0 +1,45 @@
+#!/usr/bin/env python3
+
+import os
+
+from funasr.tasks.data2vec import Data2VecTask
+
+
+def parse_args():
+    parser = Data2VecTask.get_parser()
+    parser.add_argument(
+        "--gpu_id",
+        type=int,
+        default=0,
+        help="local gpu id.",
+    )
+    args = parser.parse_args()
+    return args
+
+
+def main(args=None, cmd=None):
+    # for data2vec Training
+    Data2VecTask.main(args=args, cmd=cmd)
+
+
+if __name__ == '__main__':
+    args = parse_args()
+
+    # setup local gpu_id
+    os.environ['CUDA_VISIBLE_DEVICES'] = str(args.gpu_id)
+
+    # DDP settings
+    if args.ngpu > 1:
+        args.distributed = True
+    else:
+        args.distributed = False
+    assert args.num_worker_count == 1
+
+    # re-compute batch size: when dataset type is small
+    if args.dataset_type == "small":
+        if args.batch_size is not None:
+            args.batch_size = args.batch_size * args.ngpu
+        if args.batch_bins is not None:
+            args.batch_bins = args.batch_bins * args.ngpu
+
+    main(args=args)
diff --git a/funasr/models/data2vec.py b/funasr/models/data2vec.py
new file mode 100644
index 0000000..fcd6bd2
--- /dev/null
+++ b/funasr/models/data2vec.py
@@ -0,0 +1,160 @@
+# Copyright (c) Facebook, Inc. and its affiliates.
+#
+# This source code is licensed under the MIT license found in the
+# LICENSE file in the root directory of this source tree.
+
+from contextlib import contextmanager
+from distutils.version import LooseVersion
+from typing import Dict
+from typing import Optional
+from typing import Tuple
+
+import torch
+from typeguard import check_argument_types
+
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.torch_utils.device_funcs import force_gatherable
+from funasr.train.abs_espnet_model import AbsESPnetModel
+
+if LooseVersion(torch.__version__) >= LooseVersion("1.6.0"):
+    from torch.cuda.amp import autocast
+else:
+    # Nothing to do if torch<1.6.0
+    @contextmanager
+    def autocast(enabled=True):
+        yield
+
+
+class Data2VecPretrainModel(AbsESPnetModel):
+    """Data2Vec Pretrain model"""
+
+    def __init__(
+            self,
+            frontend: Optional[AbsFrontend],
+            specaug: Optional[AbsSpecAug],
+            normalize: Optional[AbsNormalize],
+            preencoder: Optional[AbsPreEncoder],
+            encoder: AbsEncoder,
+    ):
+        assert check_argument_types()
+
+        super().__init__()
+
+        self.frontend = frontend
+        self.specaug = specaug
+        self.normalize = normalize
+        self.preencoder = preencoder
+        self.encoder = encoder
+        self.num_updates = 0
+
+    def forward(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+    ) -> Tuple[torch.Tensor, Dict[str, torch.Tensor], torch.Tensor]:
+        """Frontend + Encoder + Calc loss
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+        """
+        # Check that batch_size is unified
+        assert (
+                speech.shape[0]
+                == speech_lengths.shape[0]
+        ), (speech.shape, speech_lengths.shape)
+
+        self.encoder.set_num_updates(self.num_updates)
+
+        # 1. Encoder
+        encoder_out = self.encode(speech, speech_lengths)
+
+        losses = encoder_out["losses"]
+        loss = sum(losses.values())
+        sample_size = encoder_out["sample_size"]
+        loss = loss.sum() / sample_size
+
+        target_var = float(encoder_out["target_var"])
+        pred_var = float(encoder_out["pred_var"])
+        ema_decay = float(encoder_out["ema_decay"])
+
+        stats = dict(
+            loss=torch.clone(loss.detach()),
+            target_var=target_var,
+            pred_var=pred_var,
+            ema_decay=ema_decay,
+        )
+
+        loss, stats, weight = force_gatherable((loss, stats, sample_size), loss.device)
+        return loss, stats, weight
+
+    def collect_feats(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor
+    ) -> Dict[str, torch.Tensor]:
+        feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+        return {"feats": feats, "feats_lengths": feats_lengths}
+
+    def encode(
+            self,
+            speech: torch.Tensor,
+            speech_lengths: torch.Tensor,
+    ):
+        """Frontend + Encoder.
+
+        Args:
+            speech: (Batch, Length, ...)
+            speech_lengths: (Batch, )
+        """
+        with autocast(False):
+            # 1. Extract feats
+            feats, feats_lengths = self._extract_feats(speech, speech_lengths)
+
+            # 2. Data augmentation
+            if self.specaug is not None and self.training:
+                feats, feats_lengths = self.specaug(feats, feats_lengths)
+
+            # 3. Normalization for feature: e.g. Global-CMVN, Utterance-CMVN
+            if self.normalize is not None:
+                feats, feats_lengths = self.normalize(feats, feats_lengths)
+
+        # Pre-encoder, e.g. used for raw input data
+        if self.preencoder is not None:
+            feats, feats_lengths = self.preencoder(feats, feats_lengths)
+
+        # 4. Forward encoder
+        if min(speech_lengths) == max(speech_lengths):  # for clipping, set speech_lengths as None
+            speech_lengths = None
+        encoder_out = self.encoder(feats, speech_lengths, mask=True, features_only=False)
+
+        return encoder_out
+
+    def _extract_feats(
+            self, speech: torch.Tensor, speech_lengths: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        assert speech_lengths.dim() == 1, speech_lengths.shape
+
+        # for data-parallel
+        speech = speech[:, : speech_lengths.max()]
+
+        if self.frontend is not None:
+            # Frontend
+            #  e.g. STFT and Feature extract
+            #       data_loader may send time-domain signal in this case
+            # speech (Batch, NSamples) -> feats: (Batch, NFrames, Dim)
+            feats, feats_lengths = self.frontend(speech, speech_lengths)
+        else:
+            # No frontend and no feature extract
+            feats, feats_lengths = speech, speech_lengths
+        return feats, feats_lengths
+
+    def set_num_updates(self, num_updates):
+        self.num_updates = num_updates
+
+    def get_num_updates(self):
+        return self.num_updates
diff --git a/funasr/tasks/data2vec.py b/funasr/tasks/data2vec.py
new file mode 100644
index 0000000..9a64e1f
--- /dev/null
+++ b/funasr/tasks/data2vec.py
@@ -0,0 +1,376 @@
+import argparse
+from typing import Callable
+from typing import Collection
+from typing import Dict
+from typing import List
+from typing import Optional
+from typing import Tuple
+
+import numpy as np
+import torch
+from typeguard import check_argument_types
+from typeguard import check_return_type
+
+from funasr.datasets.collate_fn import CommonCollateFn
+from funasr.datasets.preprocessor import CommonPreprocessor
+from funasr.layers.abs_normalize import AbsNormalize
+from funasr.layers.global_mvn import GlobalMVN
+from funasr.layers.utterance_mvn import UtteranceMVN
+from funasr.models.data2vec import Data2VecPretrainModel
+from funasr.models.encoder.abs_encoder import AbsEncoder
+from funasr.models.encoder.data2vec_encoder import Data2VecEncoder
+from funasr.models.frontend.abs_frontend import AbsFrontend
+from funasr.models.frontend.default import DefaultFrontend
+from funasr.models.frontend.windowing import SlidingWindow
+from funasr.models.preencoder.abs_preencoder import AbsPreEncoder
+from funasr.models.preencoder.sinc import LightweightSincConvs
+from funasr.models.specaug.abs_specaug import AbsSpecAug
+from funasr.models.specaug.specaug import SpecAug
+from funasr.tasks.abs_task import AbsTask
+from funasr.text.phoneme_tokenizer import g2p_choices
+from funasr.torch_utils.initialize import initialize
+from funasr.train.class_choices import ClassChoices
+from funasr.train.trainer import Trainer
+from funasr.utils.types import float_or_none
+from funasr.utils.types import int_or_none
+from funasr.utils.types import str2bool
+from funasr.utils.types import str_or_none
+
+frontend_choices = ClassChoices(
+    name="frontend",
+    classes=dict(default=DefaultFrontend, sliding_window=SlidingWindow),
+    type_check=AbsFrontend,
+    default="default",
+)
+specaug_choices = ClassChoices(
+    name="specaug",
+    classes=dict(specaug=SpecAug),
+    type_check=AbsSpecAug,
+    default=None,
+    optional=True,
+)
+normalize_choices = ClassChoices(
+    "normalize",
+    classes=dict(
+        global_mvn=GlobalMVN,
+        utterance_mvn=UtteranceMVN,
+    ),
+    type_check=AbsNormalize,
+    default=None,
+    optional=True,
+)
+preencoder_choices = ClassChoices(
+    name="preencoder",
+    classes=dict(
+        sinc=LightweightSincConvs,
+    ),
+    type_check=AbsPreEncoder,
+    default=None,
+    optional=True,
+)
+encoder_choices = ClassChoices(
+    "encoder",
+    classes=dict(
+        data2vec_encoder=Data2VecEncoder,
+    ),
+    type_check=AbsEncoder,
+    default="data2vec_encoder",
+)
+model_choices = ClassChoices(
+    "model",
+    classes=dict(
+        data2vec=Data2VecPretrainModel,
+    ),
+    default="data2vec",
+)
+
+
+class Data2VecTask(AbsTask):
+    # If you need more than one optimizers, change this value
+    num_optimizers: int = 1
+
+    # Add variable objects configurations
+    class_choices_list = [
+        # --frontend and --frontend_conf
+        frontend_choices,
+        # --specaug and --specaug_conf
+        specaug_choices,
+        # --normalize and --normalize_conf
+        normalize_choices,
+        # --preencoder and --preencoder_conf
+        preencoder_choices,
+        # --encoder and --encoder_conf
+        encoder_choices,
+        # --model and --model_conf
+        model_choices,
+    ]
+
+    # If you need to modify train() or eval() procedures, change Trainer class here
+    trainer = Trainer
+
+    @classmethod
+    def add_task_arguments(cls, parser: argparse.ArgumentParser):
+        group = parser.add_argument_group(description="Task related")
+
+        # NOTE(kamo): add_arguments(..., required=True) can't be used
+        # to provide --print_config mode. Instead of it, do as
+        group.add_argument(
+            "--token_list",
+            type=str_or_none,
+            default=None,
+            help="A text mapping int-id to token",
+        )
+        group.add_argument(
+            "--init",
+            type=lambda x: str_or_none(x.lower()),
+            default=None,
+            help="The initialization method",
+            choices=[
+                "chainer",
+                "xavier_uniform",
+                "xavier_normal",
+                "kaiming_uniform",
+                "kaiming_normal",
+                None,
+            ],
+        )
+
+        group.add_argument(
+            "--input_size",
+            type=int_or_none,
+            default=None,
+            help="The number of input dimension of the feature",
+        )
+
+        group = parser.add_argument_group(description="Preprocess related")
+        group.add_argument(
+            "--use_preprocessor",
+            type=str2bool,
+            default=True,
+            help="Apply preprocessing to data or not",
+        )
+        group.add_argument(
+            "--token_type",
+            type=str,
+            default=None,
+            choices=["bpe", "char", "word", "phn"],
+            help="The text will be tokenized " "in the specified level token",
+        )
+
+        group.add_argument(
+            "--feats_type",
+            type=str,
+            default='fbank',
+            help="feats type, e.g. fbank, wav, ark_wav(needed to be scale normalization)",
+        )
+
+        group.add_argument(
+            "--bpemodel",
+            type=str_or_none,
+            default=None,
+            help="The model file of sentencepiece",
+        )
+        parser.add_argument(
+            "--non_linguistic_symbols",
+            type=str_or_none,
+            help="non_linguistic_symbols file path",
+        )
+        parser.add_argument(
+            "--cleaner",
+            type=str_or_none,
+            choices=[None, "tacotron", "jaconv", "vietnamese"],
+            default=None,
+            help="Apply text cleaning",
+        )
+        parser.add_argument(
+            "--g2p",
+            type=str_or_none,
+            choices=g2p_choices,
+            default=None,
+            help="Specify g2p method if --token_type=phn",
+        )
+        parser.add_argument(
+            "--speech_volume_normalize",
+            type=float_or_none,
+            default=None,
+            help="Scale the maximum amplitude to the given value.",
+        )
+        parser.add_argument(
+            "--rir_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of rir scp file.",
+        )
+        parser.add_argument(
+            "--rir_apply_prob",
+            type=float,
+            default=1.0,
+            help="THe probability for applying RIR convolution.",
+        )
+        parser.add_argument(
+            "--noise_scp",
+            type=str_or_none,
+            default=None,
+            help="The file path of noise scp file.",
+        )
+        parser.add_argument(
+            "--noise_apply_prob",
+            type=float,
+            default=1.0,
+            help="The probability applying Noise adding.",
+        )
+        parser.add_argument(
+            "--noise_db_range",
+            type=str,
+            default="13_15",
+            help="The range of noise decibel level.",
+        )
+        parser.add_argument(
+            "--pred_masked_weight",
+            type=float,
+            default=1.0,
+            help="weight for predictive loss for masked frames",
+        )
+        parser.add_argument(
+            "--pred_nomask_weight",
+            type=float,
+            default=0.0,
+            help="weight for predictive loss for unmasked frames",
+        )
+        parser.add_argument(
+            "--loss_weights",
+            type=float,
+            default=0.0,
+            help="weights for additional loss terms (not first one)",
+        )
+
+        for class_choices in cls.class_choices_list:
+            # Append --<name> and --<name>_conf.
+            # e.g. --encoder and --encoder_conf
+            class_choices.add_arguments(group)
+
+    @classmethod
+    def build_collate_fn(
+            cls, args: argparse.Namespace, train: bool
+    ) -> Callable[
+        [Collection[Tuple[str, Dict[str, np.ndarray]]]],
+        Tuple[List[str], Dict[str, torch.Tensor]],
+    ]:
+        assert check_argument_types()
+        return CommonCollateFn(clipping=True)
+
+    @classmethod
+    def build_preprocess_fn(
+            cls, args: argparse.Namespace, train: bool
+    ) -> Optional[Callable[[str, Dict[str, np.array]], Dict[str, np.ndarray]]]:
+        assert check_argument_types()
+        if args.use_preprocessor:
+            retval = CommonPreprocessor(
+                train=train,
+                bpemodel=args.bpemodel,
+                non_linguistic_symbols=args.non_linguistic_symbols,
+                text_cleaner=args.cleaner,
+                g2p_type=args.g2p,
+                # NOTE(kamo): Check attribute existence for backward compatibility
+                rir_scp=args.rir_scp if hasattr(args, "rir_scp") else None,
+                rir_apply_prob=args.rir_apply_prob
+                if hasattr(args, "rir_apply_prob")
+                else 1.0,
+                noise_scp=args.noise_scp if hasattr(args, "noise_scp") else None,
+                noise_apply_prob=args.noise_apply_prob
+                if hasattr(args, "noise_apply_prob")
+                else 1.0,
+                noise_db_range=args.noise_db_range
+                if hasattr(args, "noise_db_range")
+                else "13_15",
+                speech_volume_normalize=args.speech_volume_normalize
+                if hasattr(args, "rir_scp")
+                else None,
+            )
+        else:
+            retval = None
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def required_data_names(
+            cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        # for pre-training
+        retval = ("speech",)
+        return retval
+
+    @classmethod
+    def optional_data_names(
+            cls, train: bool = True, inference: bool = False
+    ) -> Tuple[str, ...]:
+        retval = ()
+        assert check_return_type(retval)
+        return retval
+
+    @classmethod
+    def build_model(cls, args: argparse.Namespace):
+        assert check_argument_types()
+
+        # 1. frontend
+        if args.input_size is None:
+            # Extract features in the model
+            frontend_class = frontend_choices.get_class(args.frontend)
+            frontend = frontend_class(**args.frontend_conf)
+            input_size = frontend.output_size()
+        else:
+            # Give features from data-loader
+            args.frontend = None
+            args.frontend_conf = {}
+            frontend = None
+            input_size = args.input_size
+
+        # 2. Data augmentation for spectrogram
+        if args.specaug is not None:
+            specaug_class = specaug_choices.get_class(args.specaug)
+            specaug = specaug_class(**args.specaug_conf)
+        else:
+            specaug = None
+
+        # 3. Normalization layer
+        if args.normalize is not None:
+            normalize_class = normalize_choices.get_class(args.normalize)
+            normalize = normalize_class(**args.normalize_conf)
+        else:
+            normalize = None
+
+        # 4. Pre-encoder input block
+        # NOTE(kan-bayashi): Use getattr to keep the compatibility
+        if getattr(args, "preencoder", None) is not None:
+            preencoder_class = preencoder_choices.get_class(args.preencoder)
+            preencoder = preencoder_class(**args.preencoder_conf)
+            input_size = preencoder.output_size()
+        else:
+            preencoder = None
+
+        # 5. Encoder
+        encoder_class = encoder_choices.get_class(args.encoder)
+        encoder = encoder_class(
+            input_size=input_size,
+            **args.encoder_conf,
+        )
+
+        # 6. Build model
+        try:
+            model_class = model_choices.get_class(args.model)
+        except AttributeError:
+            model_class = model_choices.get_class("data2vec")
+        model = model_class(
+            frontend=frontend,
+            specaug=specaug,
+            normalize=normalize,
+            preencoder=preencoder,
+            encoder=encoder,
+        )
+
+        # 7. Initialize
+        if args.init is not None:
+            initialize(model, args.init)
+
+        assert check_return_type(model)
+        return model

--
Gitblit v1.9.1